mirror of https://gitee.com/godoos/godoos.git
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
192 lines
5.8 KiB
192 lines
5.8 KiB
package server
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"godo/ai/search"
|
|
"godo/ai/types"
|
|
"godo/libs"
|
|
"godo/model"
|
|
"godo/office"
|
|
"log"
|
|
"net/http"
|
|
"time"
|
|
)
|
|
|
|
func ChatHandler(w http.ResponseWriter, r *http.Request) {
|
|
var url string
|
|
var req types.ChatRequest
|
|
err := json.NewDecoder(r.Body).Decode(&req)
|
|
if err != nil {
|
|
libs.ErrorMsg(w, "the chat request error:"+err.Error())
|
|
return
|
|
}
|
|
headers, url, err := GetHeadersAndUrl(req, "chat")
|
|
// log.Printf("url: %s", url)
|
|
// log.Printf("headers: %v", headers)
|
|
if err != nil {
|
|
libs.ErrorMsg(w, "the chat request header or url errors:"+err.Error())
|
|
return
|
|
}
|
|
if req.WebSearch {
|
|
searchRes, err := ChatWithWeb(&req)
|
|
if err != nil {
|
|
log.Printf("the chat with web error:%v", err)
|
|
libs.ErrorMsg(w, err.Error())
|
|
return
|
|
}
|
|
res, err := SendChat(w, r, req, url, headers)
|
|
if err != nil {
|
|
log.Printf("the chat with web error:%v", err)
|
|
libs.ErrorMsg(w, err.Error())
|
|
return
|
|
}
|
|
for _, s := range searchRes {
|
|
res.WebSearch = append(res.WebSearch, types.WebSearchResult{Title: s.Title, Content: s.Content, Link: s.Url})
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(res)
|
|
return
|
|
}
|
|
if req.FileContent != "" {
|
|
err = ChatWithFile(&req)
|
|
if err != nil {
|
|
log.Printf("the chat with file error:%v", err)
|
|
}
|
|
res, err := SendChat(w, r, req, url, headers)
|
|
if err != nil {
|
|
log.Printf("the chat with web error:%v", err)
|
|
libs.ErrorMsg(w, err.Error())
|
|
return
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(res)
|
|
return
|
|
}
|
|
if req.KnowledgeId != 0 {
|
|
resk, err := ChatWithKnowledge(&req)
|
|
if err != nil {
|
|
log.Printf("the chat with knowledge error:%v", err)
|
|
}
|
|
res, err := SendChat(w, r, req, url, headers)
|
|
if err != nil {
|
|
log.Printf("the chat with web error:%v", err)
|
|
libs.ErrorMsg(w, err.Error())
|
|
return
|
|
}
|
|
|
|
for _, s := range resk {
|
|
// s.FilePath = strings.TrimPrefix(s.FilePath, basePath)
|
|
res.Documents = append(res.Documents, types.AskDocResponse{Content: s.Content, Score: s.Score, FilePath: s.FilePath, FileName: s.FileName})
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(res)
|
|
return
|
|
}
|
|
|
|
ForwardHandler(w, r, req, url, headers, "POST")
|
|
}
|
|
func ChatWithFile(req *types.ChatRequest) error {
|
|
fileContent, err := office.ProcessBase64File(req.FileContent, req.FileName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
lastMessage, err := GetLastMessage(*req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
userQuestion := fmt.Sprintf("请对\n%s\n的内容进行分析,给出对用户输入的回答: %s", fileContent, lastMessage)
|
|
log.Printf("the search file is %v", userQuestion)
|
|
req.Messages = append([]types.Message{}, types.Message{Role: "user", Content: userQuestion})
|
|
return nil
|
|
}
|
|
func ChatWithKnowledge(req *types.ChatRequest) ([]types.AskDocResponse, error) {
|
|
var res []types.AskDocResponse
|
|
lastMessage, err := GetLastMessage(*req)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
askrequest := model.AskRequest{
|
|
ID: req.KnowledgeId,
|
|
Input: lastMessage,
|
|
}
|
|
var knowData model.VecList
|
|
if err := model.Db.First(&knowData, askrequest.ID).Error; err != nil {
|
|
return res, fmt.Errorf("the knowledge id is not exist")
|
|
}
|
|
//var filterDocs
|
|
filterDocs := []string{askrequest.Input}
|
|
// 获取嵌入向量
|
|
resList, err := GetEmbeddings(knowData.Engine, knowData.EmbeddingModel, filterDocs)
|
|
if err != nil {
|
|
return res, fmt.Errorf("the embeddings get error:%v", err)
|
|
}
|
|
resk, err := model.AskDocument(askrequest.ID, resList[0])
|
|
if err != nil {
|
|
return res, fmt.Errorf("the ask document error:%v", err)
|
|
}
|
|
msg := ""
|
|
for _, item := range resk {
|
|
msg += fmt.Sprintf("- %s\n", item.Content)
|
|
res = append(res, types.AskDocResponse{Content: item.Content, Score: item.Score, FilePath: item.FilePath, FileName: item.FileName})
|
|
}
|
|
prompt := fmt.Sprintf(`从文档\n\"\"\"\n%s\n\"\"\"\n中找问题\n\"\"\"\n%s\n\"\"\"\n的答案,找到答案就使用文档语句回答问题,找不到答案就用自身知识回答并且告诉用户该信息不是来自文档。\n不要复述问题,直接开始回答。`, msg, lastMessage)
|
|
req.Messages = append([]types.Message{}, types.Message{Role: "user", Content: prompt})
|
|
return res, nil
|
|
}
|
|
func ChatWithWeb(req *types.ChatRequest) ([]search.Entity, error) {
|
|
lastMessage, err := GetLastMessage(*req)
|
|
var searchRequest []search.Entity
|
|
if err != nil {
|
|
return searchRequest, err
|
|
}
|
|
searchRequest = search.SearchWeb(lastMessage)
|
|
if len(searchRequest) == 0 {
|
|
return searchRequest, fmt.Errorf("the search web is empty")
|
|
}
|
|
var inputPrompt string
|
|
for _, search := range searchRequest {
|
|
inputPrompt += fmt.Sprintf("- 标题: %s\n- 内容: %s\n", search.Title, search.Content)
|
|
}
|
|
currentDate := time.Now().Format("2006-01-02")
|
|
searchPrompt := fmt.Sprintf(`
|
|
# 以下是来自互联网的信息:
|
|
%s
|
|
|
|
# 当前日期: %s
|
|
|
|
# 要求:
|
|
根据最新发布的信息回答用户问题。
|
|
|
|
# 用户问题:%s
|
|
|
|
`, inputPrompt, currentDate, lastMessage)
|
|
//log.Printf("the search web is %v", searchPrompt)
|
|
// req.Messages = append([]Message{}, Message{Role: "assistant", Content: searchPrompt})
|
|
req.Messages = append([]types.Message{}, types.Message{Role: "user", Content: searchPrompt})
|
|
return searchRequest, nil
|
|
}
|
|
func EmbeddingHandler(w http.ResponseWriter, r *http.Request) {
|
|
var req types.ChatRequest
|
|
err := json.NewDecoder(r.Body).Decode(&req)
|
|
if err != nil {
|
|
libs.ErrorMsg(w, "the chat request error:"+err.Error())
|
|
return
|
|
}
|
|
headers, url, err := GetHeadersAndUrl(req, "embeddings")
|
|
if err != nil {
|
|
libs.ErrorMsg(w, err.Error())
|
|
return
|
|
}
|
|
ForwardHandler(w, r, req, url, headers, "POST")
|
|
}
|
|
func GetLastMessage(req types.ChatRequest) (string, error) {
|
|
if len(req.Messages) == 0 {
|
|
return "", fmt.Errorf("the messages is empty")
|
|
}
|
|
lastMessage := req.Messages[len(req.Messages)-1]
|
|
if lastMessage.Role != "user" {
|
|
return "", fmt.Errorf("the last message is not user")
|
|
}
|
|
return lastMessage.Content, nil
|
|
}
|
|
|