From b2b3bab4f3d3466f052c7ea812aadfc7688ef580 Mon Sep 17 00:00:00 2001 From: godo Date: Sat, 30 Nov 2024 23:28:53 +0800 Subject: [PATCH] change ai --- frontend/src/system/config.ts | 21 ++---- godo/ai/api/openai.go | 30 --------- godo/ai/{api => llms}/gitee.go | 2 +- godo/ai/llms/ollama.go | 8 +++ godo/ai/llms/openai.go | 114 +++++++++++++++++++++++++++++++++ godo/ai/server/chat.go | 52 +++++++++++++-- godo/ai/server/common.go | 7 +- godo/ai/server/down.go | 2 +- godo/ai/server/ollama.go | 4 +- 9 files changed, 183 insertions(+), 57 deletions(-) delete mode 100644 godo/ai/api/openai.go rename godo/ai/{api => llms}/gitee.go (97%) create mode 100644 godo/ai/llms/ollama.go create mode 100644 godo/ai/llms/openai.go diff --git a/frontend/src/system/config.ts b/frontend/src/system/config.ts index 9ca9cbc..d7e0e19 100644 --- a/frontend/src/system/config.ts +++ b/frontend/src/system/config.ts @@ -160,22 +160,11 @@ export const getSystemConfig = (ifset = false) => { if (!config.openaiUrl) { config.openaiUrl = 'https://api.openai.com/v1' } - if (!config.aiKey) { - config.aiKey = { - "openai": "", - "gitee": "", - // "google": "", - // "baidu": "", - // "ali": "", - // "tencent": "", - // "bigmodel": "", - // "xai": "", - // "azure": "", - // "stability": "", - // "claude": "", - // "groq": "" - - } + if (!config.openaiSecret) { + config.openaiSecret = "" + } + if(!config.giteeSecret){ + config.giteeSecret = "" } // 初始化桌面快捷方式列表,若本地存储中已存在则不进行覆盖 if (!config.desktopList) { diff --git a/godo/ai/api/openai.go b/godo/ai/api/openai.go deleted file mode 100644 index 578fe17..0000000 --- a/godo/ai/api/openai.go +++ /dev/null @@ -1,30 +0,0 @@ -package api - -import ( - "fmt" - "godo/libs" -) - -// 获取 OpenAI 聊天 API 的 URL -func GetOpenAIChatUrl() string { - return "https://api.openai.com/v1/chat/completions" -} - -// 获取 OpenAI 文本嵌入 API 的 URL -func GetOpenAIEmbeddingUrl() string { - return "https://api.openai.com/v1/embeddings" -} - -// 获取 OpenAI 文本转图像 API 的 URL -func GetOpenAIText2ImgUrl() string { - return "https://api.openai.com/v1/images/generations" -} - -// 获取 OpenAI 密钥 -func GetOpenAISecret() (string, error) { - secret, has := libs.GetConfig("openaiSecret") - if !has { - return "", fmt.Errorf("the openai secret is not set") - } - return secret.(string), nil -} diff --git a/godo/ai/api/gitee.go b/godo/ai/llms/gitee.go similarity index 97% rename from godo/ai/api/gitee.go rename to godo/ai/llms/gitee.go index af9f911..9ba6d5a 100644 --- a/godo/ai/api/gitee.go +++ b/godo/ai/llms/gitee.go @@ -1,4 +1,4 @@ -package api +package llms import ( "fmt" diff --git a/godo/ai/llms/ollama.go b/godo/ai/llms/ollama.go new file mode 100644 index 0000000..713d316 --- /dev/null +++ b/godo/ai/llms/ollama.go @@ -0,0 +1,8 @@ +package llms + +func GetOllamaChatUrl(url string) string { + return url + "/v1/chat/completions" +} +func GetOllamaEmbeddingUrl(url string) string { + return url + "/api/embeddings" +} diff --git a/godo/ai/llms/openai.go b/godo/ai/llms/openai.go new file mode 100644 index 0000000..8770aa2 --- /dev/null +++ b/godo/ai/llms/openai.go @@ -0,0 +1,114 @@ +package llms + +import ( + "fmt" + "godo/libs" +) + +// ollama openai deepseek bigmodel alibaba 01ai cloudflare groq mistral anthropic llamafamily +var OpenAIApiMaps = map[string]string{ + //"openai": GetOpenAIUrl(), + "deepseek": "https://api.deepseek.com/v1", + "bigmodel": "https://open.bigmodel.cn/api/paas/v4", + "alibaba": "https://dashscope.aliyuncs.com/compatible-mode/v1", + "01ai": "https://api.lingyiwanwu.com/v1", + "groq": "https://api.groq.com/openai/v1", + "mistral": "https://api.mistral.ai/v1", + "anthropic": "https://api.anthropic.com/v1", + "llamafamily": "https://api.atomecho.cn/v1", +} + +// 获取 OpenAI 聊天 API 的 URL +func GetOpenAIChatUrl(types string) (map[string]string, string, error) { + aiUrl, err := GetAIUrl(types) + if err != nil { + return nil, "", err + } + headers, err := GetOpenAIHeaders(types) + if err != nil { + return nil, "", err + } + return headers, aiUrl + "/chat/completions", nil +} + +// 获取 OpenAI 文本嵌入 API 的 URL +func GetOpenAIEmbeddingUrl(types string) (map[string]string, string, error) { + aiUrl, err := GetAIUrl(types) + if err != nil { + return nil, "", err + } + headers, err := GetOpenAIHeaders(types) + if err != nil { + return nil, "", err + } + return headers, aiUrl + "/embeddings", nil +} + +// 获取 OpenAI 文本转图像 API 的 URL +func GetOpenAIText2ImgUrl(types string) (map[string]string, string, error) { + aiUrl, err := GetAIUrl(types) + if err != nil { + return nil, "", err + } + headers, err := GetOpenAIHeaders(types) + if err != nil { + return nil, "", err + } + return headers, aiUrl + "/images/generations", nil +} + +func GetAIUrl(types string) (string, error) { + if types == "openai" { + return GetOpenAIUrl(), nil + } else if types == "cloudflare" { + return GetCloudflareUrl() + } else { + url, exists := OpenAIApiMaps[types] + if !exists { + return "", fmt.Errorf("the " + types + " url is not set") + } else { + return url, nil + } + } +} +func GetOpenAIHeaders(types string) (map[string]string, error) { + secret, err := GetOpenAISecret(types) + if err != nil { + return nil, err + } + return map[string]string{ + "Content-Type": "application/json", + "Authorization": "Bearer " + secret, + }, nil +} + +// 获取 OpenAI 密钥 +func GetOpenAISecret(types string) (string, error) { + secret, has := libs.GetConfig(types + "Secret") + if !has { + return "", fmt.Errorf("the " + types + " secret is not set") + } + return secret.(string), nil +} +func GetOpenAIUserId(types string) (string, error) { + userId, has := libs.GetConfig(types + "UserId") + if !has { + return "", fmt.Errorf("the " + types + " user id is not set") + } + return userId.(string), nil +} +func GetOpenAIUrl() string { + openaiUrl, ok := libs.GetConfig("openaiUrl") + if ok { + return openaiUrl.(string) + } else { + return "https://api.openai.com/v1" + } +} +func GetCloudflareUrl() (string, error) { + userId, err := GetOpenAIUserId("cloudflare") + if err != nil { + return "", err + } + return "https://api.cloudflare.com/client/v4/accounts/" + userId + "/ai/v1", nil +} diff --git a/godo/ai/server/chat.go b/godo/ai/server/chat.go index 43bef73..c5f6a29 100644 --- a/godo/ai/server/chat.go +++ b/godo/ai/server/chat.go @@ -2,20 +2,62 @@ package server import ( "encoding/json" + "godo/ai/llms" "godo/libs" "net/http" ) func ChatHandler(w http.ResponseWriter, r *http.Request) { - url := GetOllamaUrl() + "/v1/chat/completions" - var request interface{} - err := json.NewDecoder(r.Body).Decode(&request) + // url := GetOllamaUrl() + "/v1/chat/completions" + var url string + var req map[string]interface{} + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { libs.ErrorMsg(w, err.Error()) return } - ForwardHandler(w, r, request, url, "POST") + engine, ok := req["engine"].(string) + if !ok { + libs.ErrorMsg(w, "Invalid engine field in request") + return + } + model, ok := req["model"].(string) + if !ok { + libs.ErrorMsg(w, "Invalid model field in request") + return + } + var headers map[string]string + switch engine { + case "ollama": + ollamaUrl := GetOllamaUrl() + url = llms.GetOllamaChatUrl(ollamaUrl) + headers = map[string]string{ + "Content-Type": "application/json", + } + case "gitee": + url = llms.GetGiteeChatUrl(model) + case "openai": + headers, url, err = llms.GetOpenAIChatUrl("openai") + if err != nil { + libs.ErrorMsg(w, err.Error()) + return + } + case "cloudflare": + headers, url, err = llms.GetOpenAIChatUrl("cloudflare") + if err != nil { + libs.ErrorMsg(w, err.Error()) + return + } + default: + headers, url, err = llms.GetOpenAIChatUrl("openai") // 默认URL + if err != nil { + libs.ErrorMsg(w, err.Error()) + return + } + } + ForwardHandler(w, r, req, url, headers, "POST") } + func EmbeddingHandler(w http.ResponseWriter, r *http.Request) { url := GetOllamaUrl() + "/api/embeddings" var request interface{} @@ -24,5 +66,5 @@ func EmbeddingHandler(w http.ResponseWriter, r *http.Request) { libs.ErrorMsg(w, err.Error()) return } - ForwardHandler(w, r, request, url, "POST") + ForwardHandler(w, r, request, url, nil, "POST") } diff --git a/godo/ai/server/common.go b/godo/ai/server/common.go index f60f979..8fe9751 100644 --- a/godo/ai/server/common.go +++ b/godo/ai/server/common.go @@ -9,7 +9,7 @@ import ( "net/http" ) -func ForwardHandler(w http.ResponseWriter, r *http.Request, reqBody interface{}, url string, method string) { +func ForwardHandler(w http.ResponseWriter, r *http.Request, reqBody interface{}, url string, headers map[string]string, method string) { payloadBytes, err := json.Marshal(reqBody) if err != nil { libs.ErrorMsg(w, "Error marshaling payload") @@ -21,7 +21,10 @@ func ForwardHandler(w http.ResponseWriter, r *http.Request, reqBody interface{}, libs.ErrorMsg(w, "Failed to create request") return } - req.Header.Set("Content-Type", "application/json") + for key, value := range headers { + req.Header.Set(key, value) + } + //req.Header.Set("Content-Type", "application/json") // 发送请求 client := &http.Client{} diff --git a/godo/ai/server/down.go b/godo/ai/server/down.go index efe8251..52ee937 100644 --- a/godo/ai/server/down.go +++ b/godo/ai/server/down.go @@ -223,7 +223,7 @@ func DeleteFileHandle(w http.ResponseWriter, r *http.Request) { postQuery := map[string]interface{}{"name": reqBody.Model} url := GetOllamaUrl() + "/api/delete" - ForwardHandler(w, r, postQuery, url, "DELETE") + ForwardHandler(w, r, postQuery, url, nil, "DELETE") return } delUrls(reqBody.Info.URL) diff --git a/godo/ai/server/ollama.go b/godo/ai/server/ollama.go index b157251..a7955e4 100644 --- a/godo/ai/server/ollama.go +++ b/godo/ai/server/ollama.go @@ -177,7 +177,7 @@ func setOllamaInfo(w http.ResponseWriter, r *http.Request, reqBody types.ReqBody "model": model, } url := GetOllamaUrl() + "/api/pull" - ForwardHandler(w, r, postQuery, url, "POST") + ForwardHandler(w, r, postQuery, url, nil, "POST") details, err := getOllamaInfo(r, model) //log.Printf("details is %v", details) if err != nil { @@ -366,7 +366,7 @@ func ConvertOllama(w http.ResponseWriter, r *http.Request, req types.ReqBody) { "name": req.Model, "modelfile": modelFile, } - ForwardHandler(w, r, postParams, url, "POST") + ForwardHandler(w, r, postParams, url, nil, "POST") modelDir, err := config.GetModelDir(req.Model) if err != nil { libs.ErrorMsg(w, "GetModelDir")