mirror of https://gitee.com/godoos/godoos.git
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
396 lines
10 KiB
396 lines
10 KiB
package model
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/sha256"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"godo/libs"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"regexp"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
type OllamaModelsList struct {
|
|
Models []OllamaModelsInfo `json:"models"`
|
|
}
|
|
|
|
type OllamaDetails struct {
|
|
ParameterSize string `json:"parameter_size"`
|
|
QuantizationLevel string `json:"quantization_level"`
|
|
}
|
|
type OllamaModelsInfo struct {
|
|
Model string `json:"model"`
|
|
Details OllamaDetails `json:"details"`
|
|
Size int64 `json:"size"`
|
|
}
|
|
type OllamaModelDetail struct {
|
|
Parameters string `json:"parameters"`
|
|
Template string `json:"template"`
|
|
Details map[string]interface{} `json:"details"`
|
|
ModelInfo map[string]interface{} `json:"model_info"`
|
|
}
|
|
type ResModelInfo struct {
|
|
Parameters string `json:"parameters"`
|
|
Template string `json:"template"`
|
|
ContextLength int64 `json:"context_length"`
|
|
EmbeddingLength int64 `json:"embedding_length"`
|
|
Size string `json:"size"`
|
|
Quant string `json:"quant"`
|
|
Desk string `json:"desk"`
|
|
Cpu string `json:"cpu"`
|
|
Gpu string `json:"gpu"`
|
|
}
|
|
|
|
type Layer struct {
|
|
MediaType string `json:"mediaType"`
|
|
Digest string `json:"digest"`
|
|
Size int64 `json:"size"`
|
|
From string `json:"from,omitempty"`
|
|
status string
|
|
}
|
|
type ManifestV2 struct {
|
|
SchemaVersion int `json:"schemaVersion"`
|
|
MediaType string `json:"mediaType"`
|
|
Config *Layer `json:"config"`
|
|
Layers []*Layer `json:"layers"`
|
|
}
|
|
type OmodelPath struct {
|
|
Space string
|
|
LibPath string
|
|
Name string
|
|
Tag string
|
|
}
|
|
|
|
const (
|
|
KB = 1 << (10 * iota)
|
|
MB
|
|
GB
|
|
)
|
|
const (
|
|
CPU_8GB = "8GB"
|
|
CPU_16GB = "16GB"
|
|
CPU_32GB = "32GB"
|
|
GPU_6GB = "6GB"
|
|
GPU_8GB = "8GB"
|
|
GPU_12GB = "12GB"
|
|
)
|
|
|
|
func humanReadableSize(size int64) string {
|
|
units := []string{"B", "KB", "MB", "GB"}
|
|
unitIndex := 0 // Start with Bytes
|
|
for size >= 1000 && unitIndex < len(units)-1 {
|
|
size /= 1000
|
|
unitIndex++
|
|
}
|
|
|
|
switch unitIndex {
|
|
case 0, 1, 2, 3: // For B, KB, and MB, keep decimal points
|
|
return fmt.Sprintf("%d%s", size, units[unitIndex])
|
|
default:
|
|
return fmt.Sprintf("%dB", size) // Fallback for sizes less than 1B or unhandled cases
|
|
}
|
|
}
|
|
|
|
func Tagshandler(w http.ResponseWriter, r *http.Request) {
|
|
err := LoadConfig()
|
|
if err != nil {
|
|
libs.ErrorMsg(w, "Load config error")
|
|
return
|
|
}
|
|
var reqBodies []ReqBody
|
|
reqBodyMap.Range(func(key, value interface{}) bool {
|
|
rb, ok := value.(ReqBody)
|
|
if ok {
|
|
reqBodies = append(reqBodies, rb)
|
|
|
|
}
|
|
return true // 继续遍历
|
|
})
|
|
// 对reqBodies按CreatedAt降序排列
|
|
sort.Slice(reqBodies, func(i, j int) bool {
|
|
return reqBodies[i].CreatedAt.After(reqBodies[j].CreatedAt) // 降序排列
|
|
})
|
|
// 设置响应内容类型为JSON
|
|
w.Header().Set("Content-Type", "application/json")
|
|
|
|
// 使用json.NewEncoder将reqBodies编码为JSON并写入响应体
|
|
if err := json.NewEncoder(w).Encode(reqBodies); err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
func ShowHandler(w http.ResponseWriter, r *http.Request) {
|
|
err := LoadConfig()
|
|
if err != nil {
|
|
libs.ErrorMsg(w, "Load config error")
|
|
return
|
|
}
|
|
model := r.URL.Query().Get("model")
|
|
if model == "" {
|
|
libs.ErrorMsg(w, "Model name is empty")
|
|
return
|
|
}
|
|
//log.Printf("ShowHandler: %s", model)
|
|
var reqBodies ReqBody
|
|
reqBodyMap.Range(func(key, value interface{}) bool {
|
|
rb, ok := value.(ReqBody)
|
|
if ok && rb.Model == model {
|
|
reqBodies = rb
|
|
return false
|
|
}
|
|
return true
|
|
})
|
|
//log.Printf("ShowHandler: %s", reqBodies)
|
|
// 设置响应内容类型为JSON
|
|
w.Header().Set("Content-Type", "application/json")
|
|
|
|
// 使用json.NewEncoder将reqBodies编码为JSON并写入响应体
|
|
if err := json.NewEncoder(w).Encode(reqBodies); err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
func extractParameterSize(sizeStr string, model string) (float64, bool) {
|
|
// log.Printf("extractParameterSize: %s", sizeStr)
|
|
// log.Printf("extractParameterModel: %s", model)
|
|
// 尝试直接从原始sizeStr中提取数字,包括小数
|
|
if size, err := strconv.ParseFloat(strings.TrimSuffix(sizeStr, "B"), 64); err == nil {
|
|
return size, true
|
|
}
|
|
|
|
if parts := strings.Split(model, ":"); len(parts) > 1 {
|
|
// 确保移除 "b" 或 "B" 后缀,并尝试转换为浮点数
|
|
cleanedPart := strings.TrimSuffix(strings.ToLower(parts[1]), "b")
|
|
if size, err := strconv.ParseFloat(cleanedPart, 64); err == nil {
|
|
return size, true
|
|
}
|
|
}
|
|
|
|
return 0, false
|
|
}
|
|
|
|
func parseOllamaInfo(info OllamaModelsInfo) ResModelInfo {
|
|
res := ResModelInfo{
|
|
Size: humanReadableSize(info.Size),
|
|
Quant: info.Details.QuantizationLevel,
|
|
}
|
|
res.Desk = res.Size
|
|
paramSize, ok := extractParameterSize(info.Details.ParameterSize, info.Model)
|
|
if !ok {
|
|
res.Cpu = CPU_8GB
|
|
res.Gpu = GPU_6GB
|
|
return res
|
|
}
|
|
|
|
switch {
|
|
case paramSize < 3:
|
|
res.Cpu = CPU_8GB
|
|
res.Gpu = GPU_6GB
|
|
case paramSize < 9:
|
|
res.Cpu = CPU_16GB
|
|
res.Gpu = GPU_8GB
|
|
default:
|
|
res.Cpu = CPU_32GB
|
|
res.Gpu = GPU_12GB
|
|
}
|
|
|
|
return res
|
|
}
|
|
func getOllamaModels() ([]OllamaModelsInfo, error) {
|
|
req, err := http.Get(GetOllamaUrl() + "/api/tags")
|
|
res := []OllamaModelsInfo{}
|
|
if err != nil {
|
|
return res, fmt.Errorf("failed to create request")
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
defer req.Body.Close()
|
|
body, err := io.ReadAll(req.Body)
|
|
if err != nil {
|
|
return res, fmt.Errorf("failed to read response body")
|
|
}
|
|
rest := OllamaModelsList{}
|
|
if err := json.Unmarshal(body, &rest); err != nil {
|
|
return res, fmt.Errorf("failed to unmarshal response body")
|
|
}
|
|
return rest.Models, nil
|
|
|
|
}
|
|
func setOllamaInfo(w http.ResponseWriter, r *http.Request, reqBody ReqBody) {
|
|
model := reqBody.Model
|
|
postQuery := map[string]interface{}{
|
|
"model": model,
|
|
}
|
|
url := GetOllamaUrl() + "/api/pull"
|
|
ForwardHandler(w, r, postQuery, url, "POST")
|
|
details, err := getOllamaInfo(r, model)
|
|
//log.Printf("details is %v", details)
|
|
if err != nil {
|
|
libs.ErrorMsg(w, "get ollama info error: ")
|
|
return
|
|
}
|
|
|
|
modelList, err := getOllamaModels()
|
|
if err != nil {
|
|
libs.ErrorMsg(w, "Load ollama error: ")
|
|
return
|
|
}
|
|
if len(modelList) < 1 {
|
|
libs.ErrorMsg(w, "Load ollama error: ")
|
|
return
|
|
}
|
|
for _, model := range modelList {
|
|
if model.Model == reqBody.Model {
|
|
oinfo := parseOllamaInfo(model)
|
|
architecture := details.ModelInfo["general.architecture"].(string)
|
|
contextLength := details.ModelInfo[architecture+".context_length"]
|
|
embeddingLength := details.ModelInfo[architecture+".embedding_length"]
|
|
info := map[string]interface{}{
|
|
"size": oinfo.Size,
|
|
"quant": oinfo.Quant,
|
|
"Desk": oinfo.Desk,
|
|
"cpu": oinfo.Cpu,
|
|
"gpu": oinfo.Gpu,
|
|
"pb": model.Details.ParameterSize,
|
|
"template": details.Template,
|
|
"parameters": details.Parameters,
|
|
"context_length": contextLength,
|
|
"embedding_length": embeddingLength,
|
|
}
|
|
paths, err := getManifests(model.Model)
|
|
if err != nil {
|
|
log.Printf("Error parsing Manifests: %v", err)
|
|
continue
|
|
}
|
|
|
|
reqBody.Info = info
|
|
reqBody.Paths = paths
|
|
reqBody.Status = "success"
|
|
reqBody.CreatedAt = time.Now()
|
|
if err := SetModel(reqBody); err != nil {
|
|
libs.ErrorMsg(w, "Set model error")
|
|
return
|
|
}
|
|
return
|
|
}
|
|
}
|
|
}
|
|
func getOllamaInfo(r *http.Request, model string) (OllamaModelDetail, error) {
|
|
infoQuery := map[string]interface{}{
|
|
"name": model,
|
|
}
|
|
res := OllamaModelDetail{}
|
|
url := GetOllamaUrl() + "/api/show"
|
|
payloadBytes, err := json.Marshal(infoQuery)
|
|
if err != nil {
|
|
return res, fmt.Errorf("json payload error: %w", err)
|
|
}
|
|
// 创建POST请求,复用原始请求的上下文(如Cookies)
|
|
req, err := http.NewRequestWithContext(r.Context(), "POST", url, bytes.NewBuffer(payloadBytes))
|
|
if err != nil {
|
|
return res, fmt.Errorf("couldn't create req context: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
// 发送请求
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
defer resp.Body.Close()
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return res, fmt.Errorf("couldn't read response body: %w", err)
|
|
}
|
|
|
|
if err := json.Unmarshal(body, &res); err != nil {
|
|
return res, fmt.Errorf("failed to unmarshal response body")
|
|
}
|
|
return res, nil
|
|
}
|
|
func getOpName(model string) OmodelPath {
|
|
libPath := "library"
|
|
modelName := model
|
|
modelTags := "latest"
|
|
if strings.Contains(modelName, ":") {
|
|
names := strings.Split(model, ":")
|
|
modelName = names[0]
|
|
modelTags = names[1]
|
|
}
|
|
|
|
if strings.Contains(modelName, "/") {
|
|
names := strings.Split(modelName, "/")
|
|
libPath = names[0]
|
|
modelName = names[1]
|
|
}
|
|
return OmodelPath{
|
|
Space: "registry.ollama.ai",
|
|
LibPath: libPath,
|
|
Name: modelName,
|
|
Tag: modelTags,
|
|
}
|
|
}
|
|
func getManifests(model string) ([]string, error) {
|
|
res := []string{}
|
|
opName := getOpName(model)
|
|
modelsDir, err := getOModelsDir()
|
|
if err != nil {
|
|
return res, fmt.Errorf("failed to get user home directory: %w", err)
|
|
}
|
|
manifestsFile := filepath.Join(modelsDir, "manifests", opName.Space, opName.LibPath, opName.Name, opName.Tag)
|
|
if !libs.PathExists(manifestsFile) {
|
|
return res, fmt.Errorf("failed to get manifests file: %w", err)
|
|
}
|
|
res = append(res, manifestsFile)
|
|
var manifest ManifestV2
|
|
f, err := os.Open(manifestsFile)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
defer f.Close()
|
|
|
|
sha256sum := sha256.New()
|
|
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&manifest); err != nil {
|
|
return res, err
|
|
}
|
|
filename, err := GetBlobsPath(manifest.Config.Digest)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
res = append(res, filename)
|
|
for _, layer := range manifest.Layers {
|
|
filename, err := GetBlobsPath(layer.Digest)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
res = append(res, filename)
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
func GetBlobsPath(digest string) (string, error) {
|
|
dir, err := getOModelsDir()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
// only accept actual sha256 digests
|
|
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
|
|
re := regexp.MustCompile(pattern)
|
|
|
|
if digest != "" && !re.MatchString(digest) {
|
|
return "", errors.New("invalid digest format")
|
|
}
|
|
|
|
digest = strings.ReplaceAll(digest, ":", "-")
|
|
path := filepath.Join(dir, "blobs", digest)
|
|
return path, nil
|
|
}
|
|
|