From 877198da9ae6e7a97d9b6dc52e2ca352395a7c91 Mon Sep 17 00:00:00 2001 From: godo Date: Sat, 11 Jan 2025 14:36:20 +0800 Subject: [PATCH] add res --- godo/ai/server/chat.go | 120 +++++++++++++++++++++++--------------- godo/ai/server/common.go | 31 ++++++++++ godo/ai/server/llms.go | 3 +- godo/ai/types/response.go | 102 ++++++++++++++++++++++++++++++++ godo/ai/vector/vector.go | 23 ++++---- godo/model/vec_doc.go | 3 + godo/office/document.go | 24 ++++++++ 7 files changed, 247 insertions(+), 59 deletions(-) create mode 100644 godo/ai/types/response.go diff --git a/godo/ai/server/chat.go b/godo/ai/server/chat.go index d601060..2c7f1dd 100644 --- a/godo/ai/server/chat.go +++ b/godo/ai/server/chat.go @@ -4,68 +4,91 @@ import ( "encoding/json" "fmt" "godo/ai/search" + "godo/ai/types" "godo/libs" "godo/model" "godo/office" "log" "net/http" + "strings" "time" ) -type ChatRequest struct { - Model string `json:"model"` - Engine string `json:"engine"` - Stream bool `json:"stream"` - WebSearch bool `json:"webSearch"` - FileContent string `json:"fileContent"` - FileName string `json:"fileName"` - Options map[string]interface{} `json:"options"` - Messages []Message `json:"messages"` - KnowledgeId uint `json:"knowledgeId"` -} - -type Message struct { - Role string `json:"role"` - Content string `json:"content"` - Images []string `json:"images"` -} - func ChatHandler(w http.ResponseWriter, r *http.Request) { var url string - var req ChatRequest + var req types.ChatRequest err := json.NewDecoder(r.Body).Decode(&req) if err != nil { libs.ErrorMsg(w, "the chat request error:"+err.Error()) return } + headers, url, err := GetHeadersAndUrl(req, "chat") + // log.Printf("url: %s", url) + // log.Printf("headers: %v", headers) + if err != nil { + libs.ErrorMsg(w, "the chat request header or url errors:"+err.Error()) + return + } if req.WebSearch { - err = ChatWithWeb(&req) + searchRes, err := ChatWithWeb(&req) if err != nil { log.Printf("the chat with web error:%v", err) + libs.ErrorMsg(w, err.Error()) + return + } + res, err := SendChat(w, r, req, url, headers) + if err != nil { + log.Printf("the chat with web error:%v", err) + libs.ErrorMsg(w, err.Error()) + return + } + for _, s := range searchRes { + res.WebSearch = append(res.WebSearch, types.WebSearchResult{Title: s.Title, Content: s.Content, Link: s.Url}) } + libs.SuccessMsg(w, res, "") + return } if req.FileContent != "" { err = ChatWithFile(&req) if err != nil { log.Printf("the chat with file error:%v", err) } + res, err := SendChat(w, r, req, url, headers) + if err != nil { + log.Printf("the chat with web error:%v", err) + libs.ErrorMsg(w, err.Error()) + return + } + libs.SuccessMsg(w, res, "") + return } if req.KnowledgeId != 0 { - err = ChatWithKnowledge(&req) + resk, err := ChatWithKnowledge(&req) if err != nil { log.Printf("the chat with knowledge error:%v", err) } - } - headers, url, err := GetHeadersAndUrl(req, "chat") - // log.Printf("url: %s", url) - // log.Printf("headers: %v", headers) - if err != nil { - libs.ErrorMsg(w, "the chat request header or url errors:"+err.Error()) + res, err := SendChat(w, r, req, url, headers) + if err != nil { + log.Printf("the chat with web error:%v", err) + libs.ErrorMsg(w, err.Error()) + return + } + basePath, err := libs.GetOsDir() + if err != nil { + libs.ErrorMsg(w, "get vector db path error:"+err.Error()) + return + } + for _, s := range resk { + s.FilePath = strings.TrimPrefix(s.FilePath, basePath) + res.Documents = append(res.Documents, types.AskDocResponse{Content: s.Content, Score: s.Score, FilePath: s.FilePath, FileName: s.FileName}) + } + libs.SuccessMsg(w, res, "") return } + ForwardHandler(w, r, req, url, headers, "POST") } -func ChatWithFile(req *ChatRequest) error { +func ChatWithFile(req *types.ChatRequest) error { fileContent, err := office.ProcessBase64File(req.FileContent, req.FileName) if err != nil { return err @@ -76,13 +99,14 @@ func ChatWithFile(req *ChatRequest) error { } userQuestion := fmt.Sprintf("请对\n%s\n的内容进行分析,给出对用户输入的回答: %s", fileContent, lastMessage) log.Printf("the search file is %v", userQuestion) - req.Messages = append([]Message{}, Message{Role: "user", Content: userQuestion}) + req.Messages = append([]types.Message{}, types.Message{Role: "user", Content: userQuestion}) return nil } -func ChatWithKnowledge(req *ChatRequest) error { +func ChatWithKnowledge(req *types.ChatRequest) ([]types.AskDocResponse, error) { + var res []types.AskDocResponse lastMessage, err := GetLastMessage(*req) if err != nil { - return err + return res, err } askrequest := model.AskRequest{ ID: req.KnowledgeId, @@ -90,35 +114,37 @@ func ChatWithKnowledge(req *ChatRequest) error { } var knowData model.VecList if err := model.Db.First(&knowData, askrequest.ID).Error; err != nil { - return fmt.Errorf("the knowledge id is not exist") + return res, fmt.Errorf("the knowledge id is not exist") } //var filterDocs filterDocs := []string{askrequest.Input} // 获取嵌入向量 resList, err := GetEmbeddings(knowData.Engine, knowData.EmbeddingModel, filterDocs) if err != nil { - return fmt.Errorf("the embeddings get error:%v", err) + return res, fmt.Errorf("the embeddings get error:%v", err) } - res, err := model.AskDocument(askrequest.ID, resList[0]) + resk, err := model.AskDocument(askrequest.ID, resList[0]) if err != nil { - return fmt.Errorf("the ask document error:%v", err) + return res, fmt.Errorf("the ask document error:%v", err) } msg := "" - for _, res := range res { - msg += fmt.Sprintf("- %s\n", res.Content) + for _, item := range resk { + msg += fmt.Sprintf("- %s\n", item.Content) + res = append(res, types.AskDocResponse{Content: item.Content, Score: item.Score, FilePath: item.FilePath, FileName: item.FileName}) } prompt := fmt.Sprintf(`从文档\n\"\"\"\n%s\n\"\"\"\n中找问题\n\"\"\"\n%s\n\"\"\"\n的答案,找到答案就使用文档语句回答问题,找不到答案就用自身知识回答并且告诉用户该信息不是来自文档。\n不要复述问题,直接开始回答。`, msg, lastMessage) - req.Messages = append([]Message{}, Message{Role: "user", Content: prompt}) - return nil + req.Messages = append([]types.Message{}, types.Message{Role: "user", Content: prompt}) + return res, nil } -func ChatWithWeb(req *ChatRequest) error { +func ChatWithWeb(req *types.ChatRequest) ([]search.Entity, error) { lastMessage, err := GetLastMessage(*req) + var searchRequest []search.Entity if err != nil { - return err + return searchRequest, err } - searchRequest := search.SearchWeb(lastMessage) + searchRequest = search.SearchWeb(lastMessage) if len(searchRequest) == 0 { - return fmt.Errorf("the search web is empty") + return searchRequest, fmt.Errorf("the search web is empty") } var inputPrompt string for _, search := range searchRequest { @@ -139,11 +165,11 @@ func ChatWithWeb(req *ChatRequest) error { `, inputPrompt, currentDate, lastMessage) //log.Printf("the search web is %v", searchPrompt) // req.Messages = append([]Message{}, Message{Role: "assistant", Content: searchPrompt}) - req.Messages = append([]Message{}, Message{Role: "user", Content: searchPrompt}) - return nil + req.Messages = append([]types.Message{}, types.Message{Role: "user", Content: searchPrompt}) + return searchRequest, nil } func EmbeddingHandler(w http.ResponseWriter, r *http.Request) { - var req ChatRequest + var req types.ChatRequest err := json.NewDecoder(r.Body).Decode(&req) if err != nil { libs.ErrorMsg(w, "the chat request error:"+err.Error()) @@ -156,7 +182,7 @@ func EmbeddingHandler(w http.ResponseWriter, r *http.Request) { } ForwardHandler(w, r, req, url, headers, "POST") } -func GetLastMessage(req ChatRequest) (string, error) { +func GetLastMessage(req types.ChatRequest) (string, error) { if len(req.Messages) == 0 { return "", fmt.Errorf("the messages is empty") } diff --git a/godo/ai/server/common.go b/godo/ai/server/common.go index 8fe9751..417b85c 100644 --- a/godo/ai/server/common.go +++ b/godo/ai/server/common.go @@ -3,12 +3,43 @@ package server import ( "bytes" "encoding/json" + "fmt" + "godo/ai/types" "godo/libs" "io" "log" "net/http" ) +func SendChat(w http.ResponseWriter, r *http.Request, reqBody interface{}, url string, headers map[string]string) (types.OpenAIResponse, error) { + var res types.OpenAIResponse + payloadBytes, err := json.Marshal(reqBody) + if err != nil { + return res, fmt.Errorf("Error marshaling payload") + } + // 创建POST请求,复用原始请求的上下文(如Cookies) + req, err := http.NewRequestWithContext(r.Context(), "POST", url, bytes.NewBuffer(payloadBytes)) + if err != nil { + return res, fmt.Errorf("Failed to create request") + } + for key, value := range headers { + req.Header.Set(key, value) + } + //req.Header.Set("Content-Type", "application/json") + + // 发送请求 + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return res, fmt.Errorf("Failed to send request") + } + defer resp.Body.Close() + + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + return res, fmt.Errorf("Failed to decode response body") + } + return res, nil +} func ForwardHandler(w http.ResponseWriter, r *http.Request, reqBody interface{}, url string, headers map[string]string, method string) { payloadBytes, err := json.Marshal(reqBody) if err != nil { diff --git a/godo/ai/server/llms.go b/godo/ai/server/llms.go index 805d155..6d8ba5b 100644 --- a/godo/ai/server/llms.go +++ b/godo/ai/server/llms.go @@ -2,6 +2,7 @@ package server import ( "fmt" + "godo/ai/types" "godo/libs" "log" ) @@ -24,7 +25,7 @@ var OpenAIApiMaps = map[string]string{ "siliconflow": "https://api.siliconflow.cn/v1", } -func GetHeadersAndUrl(req ChatRequest, chattype string) (map[string]string, string, error) { +func GetHeadersAndUrl(req types.ChatRequest, chattype string) (map[string]string, string, error) { // engine, ok := req["engine"].(string) // if !ok { // return nil, "", fmt.Errorf("invalid engine field in request") diff --git a/godo/ai/types/response.go b/godo/ai/types/response.go new file mode 100644 index 0000000..e5a247e --- /dev/null +++ b/godo/ai/types/response.go @@ -0,0 +1,102 @@ +package types + +type ChatRequest struct { + Model string `json:"model"` + Engine string `json:"engine"` + Stream bool `json:"stream"` + WebSearch bool `json:"webSearch"` + FileContent string `json:"fileContent"` + FileName string `json:"fileName"` + Options map[string]interface{} `json:"options"` + Messages []Message `json:"messages"` + KnowledgeId uint `json:"knowledgeId"` +} + +type InvokeResp struct { + RequestID string `json:"requestId"` + Content string `json:"content"` + Problems []string `json:"problems"` + DocumentSlices []struct { + Document Document `json:"document"` + SliceInfo []Slice `json:"slice_info"` + HidePositions bool `json:"hide_positions"` + Images []Image `json:"images"` + } `json:"documents"` +} +type AskDocResponse struct { + Content string `json:"content"` + Score float32 `json:"score"` + FilePath string `json:"file_path"` + FileName string `json:"file_name"` +} +type Document struct { + ID string `json:"id"` + Name string `json:"name"` + URL string `json:"url"` + Dtype int `json:"dtype"` +} + +type Slice struct { + DocumentID string `json:"document_id"` + Position *Position `json:"position,omitempty"` + Line int `json:"line,omitempty"` + SheetName string `json:"sheet_name,omitempty"` + Text string `json:"text"` +} +type Position struct { + X0 float64 `json:"x0"` + X1 float64 `json:"x1"` + Top float64 `json:"top"` + Bottom float64 `json:"bottom"` + Page int `json:"page"` + Height float64 `json:"height"` + Width float64 `json:"width"` +} +type Image struct { + Text string `json:"text"` + CosURL string `json:"cos_url"` +} +type OpenAIResponse struct { + ID string `json:"id"` + Created int64 `json:"created"` + RequestID string `json:"request_id"` + Model string `json:"model"` + Choices []Choice `json:"choices"` + Usage Usage `json:"usage"` + WebSearch []WebSearchResult `json:"web_search,omitempty"` + Documents []AskDocResponse `json:"documents"` +} +type Choice struct { + Index int `json:"index"` + FinishReason string `json:"finish_reason"` + Message Message `json:"message"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` +} +type Message struct { + Role string `json:"role"` + Content string `json:"content"` + Images []string `json:"images"` +} + +type ToolCall struct { + Function FunctionRes `json:"function"` + ID string `json:"id"` + Type string `json:"type"` +} +type FunctionRes struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +type WebSearchResult struct { + Icon string `json:"icon"` + Title string `json:"title"` + Link string `json:"link"` + Media string `json:"media"` + Content string `json:"content"` +} +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} diff --git a/godo/ai/vector/vector.go b/godo/ai/vector/vector.go index 3f621aa..4a0135a 100644 --- a/godo/ai/vector/vector.go +++ b/godo/ai/vector/vector.go @@ -280,9 +280,9 @@ func handleGodoosFile(filePath string, knowledgeId uint) error { // 检查是否为 .godoos 文件 if strings.HasPrefix(baseName, ".godoos.") { // 去掉 .godoos. 前缀和 .json 后缀 - //fileName := strings.TrimSuffix(strings.TrimPrefix(baseName, ".godoos."), ".json") - // 提取实际文件名部分 - //actualFileName := extractFileName(fileName) + fileName := strings.TrimSuffix(strings.TrimPrefix(baseName, ".godoos."), ".json") + //提取实际文件名部分 + actualFileName := extractFileName(fileName) // 读取文件内容 content, err := os.ReadFile(filePath) @@ -330,6 +330,7 @@ func handleGodoosFile(filePath string, knowledgeId uint) error { vectordoc := model.VecDoc{ Content: res, FilePath: filePath, + FileName: actualFileName, ListID: knowledgeId, } vectordocs = append(vectordocs, vectordoc) @@ -347,11 +348,11 @@ func handleGodoosFile(filePath string, knowledgeId uint) error { } } -// func extractFileName(fileName string) string { -// // 假设文件名格式为:21.GodoOS企业版介绍 -// parts := strings.SplitN(fileName, ".", 3) -// if len(parts) < 2 { -// return fileName -// } -// return parts[1] -// } +func extractFileName(fileName string) string { + // 假设文件名格式为:21.GodoOS企业版介绍 + parts := strings.SplitN(fileName, ".", 3) + if len(parts) < 2 { + return fileName + } + return parts[1] +} diff --git a/godo/model/vec_doc.go b/godo/model/vec_doc.go index fe41a1f..98bd791 100644 --- a/godo/model/vec_doc.go +++ b/godo/model/vec_doc.go @@ -12,6 +12,7 @@ type VecDoc struct { gorm.Model Content string `json:"content"` FilePath string `json:"file_path" gorm:"not null"` + FileName string `json:"file_name"` ListID uint `json:"list_id"` } @@ -88,6 +89,7 @@ type AskDocResponse struct { Content string `json:"content"` Score float32 `json:"score"` FilePath string `json:"file_path"` + FileName string `json:"file_name"` } type AskRequest struct { ID uint `json:"id"` @@ -142,6 +144,7 @@ func AskDocument(listId uint, query []float32) ([]AskDocResponse, error) { Content: doc.Content, Score: results[i].Distance, FilePath: doc.FilePath, + FileName: doc.FileName, }) } // 按 Score 降序排序 diff --git a/godo/office/document.go b/godo/office/document.go index d8eaa14..423f334 100644 --- a/godo/office/document.go +++ b/godo/office/document.go @@ -218,7 +218,31 @@ func GetDocument(pathname string) (*Document, error) { _, err = getContentData(&data, html2txt) case ".json": _, err = getContentData(&data, json2txt) + // case ".md", ".txt", ".py", ".java", ".c", ".cpp", ".h", ".hpp", ".js", ".ts", ".go", ".rb", ".php", ".swift", ".kt", ".scala", ".rust", ".perl", ".bash", ".sh", ".lua", ".dart", ".r", ".matlab", ".pl", ".pm", ".tcl", ".sql", ".groovy", ".cs", ".vb", ".fs", ".hs", ".erl", ".elixir", ".crystal", ".nim", ".d", ".coffeescript", ".typescript", ".vue", ".svelte", ".jsx", ".tsx", ".html", ".css", ".scss", ".less", ".json", ".xml", ".yaml", ".yml", ".toml", ".ini", ".makefile", ".dockerfile", ".gitignore", ".editorconfig", ".prettierrc", ".eslintrc", ".babelrc", ".jsonp", ".graphql", ".proto", ".plist", ".edn": + // _, err = getContentData(&data, text2txt) + default: + _, err = getContentData(&data, text2txt) + } + if err != nil { + return &data, err + } + return &data, nil +} +func GetTxtDoc(pathname string) (*Document, error) { + if !libs.PathExists(pathname) { + return nil, fmt.Errorf("file does not exist: %s", pathname) + } + abPath, err := filepath.Abs(pathname) + if err != nil { + return nil, err + } + filename := path.Base(pathname) + data := Document{path: pathname, RePath: abPath, Title: filename} + _, err = getFileInfoData(&data) + if err != nil { + return &data, err } + _, err = getContentData(&data, text2txt) if err != nil { return &data, err }