diff --git a/frontend/components.d.ts b/frontend/components.d.ts index b532055..737355f 100644 --- a/frontend/components.d.ts +++ b/frontend/components.d.ts @@ -66,37 +66,24 @@ declare module 'vue' { EditType: typeof import('./src/components/builtin/EditType.vue')['default'] ElAside: typeof import('element-plus/es')['ElAside'] ElAvatar: typeof import('element-plus/es')['ElAvatar'] - ElB: typeof import('element-plus/es')['ElB'] ElBadge: typeof import('element-plus/es')['ElBadge'] ElButton: typeof import('element-plus/es')['ElButton'] ElCard: typeof import('element-plus/es')['ElCard'] ElCarousel: typeof import('element-plus/es')['ElCarousel'] ElCarouselItem: typeof import('element-plus/es')['ElCarouselItem'] - ElCheckbox: typeof import('element-plus/es')['ElCheckbox'] ElCol: typeof import('element-plus/es')['ElCol'] ElCollapse: typeof import('element-plus/es')['ElCollapse'] ElCollapseItem: typeof import('element-plus/es')['ElCollapseItem'] - ElColorPicker: typeof import('element-plus/es')['ElColorPicker'] ElContainer: typeof import('element-plus/es')['ElContainer'] ElDialog: typeof import('element-plus/es')['ElDialog'] - ElDrawer: typeof import('element-plus/es')['ElDrawer'] - ElDropdown: typeof import('element-plus/es')['ElDropdown'] - ElDropdownItem: typeof import('element-plus/es')['ElDropdownItem'] - ElDropdownMenu: typeof import('element-plus/es')['ElDropdownMenu'] ElEmpty: typeof import('element-plus/es')['ElEmpty'] - ElFooter: typeof import('element-plus/es')['ElFooter'] ElForm: typeof import('element-plus/es')['ElForm'] ElFormItem: typeof import('element-plus/es')['ElFormItem'] ElHeader: typeof import('element-plus/es')['ElHeader'] ElIcon: typeof import('element-plus/es')['ElIcon'] ElImage: typeof import('element-plus/es')['ElImage'] ElInput: typeof import('element-plus/es')['ElInput'] - ElItem: typeof import('element-plus/es')['ElItem'] - ElMain: typeof import('element-plus/es')['ElMain'] - ElMenu: typeof import('element-plus/es')['ElMenu'] - ElMenuItem: typeof import('element-plus/es')['ElMenuItem'] ElOption: typeof import('element-plus/es')['ElOption'] - ElPageHeader: typeof import('element-plus/es')['ElPageHeader'] ElPagination: typeof import('element-plus/es')['ElPagination'] ElPopover: typeof import('element-plus/es')['ElPopover'] ElProgress: typeof import('element-plus/es')['ElProgress'] @@ -104,17 +91,10 @@ declare module 'vue' { ElScrollbar: typeof import('element-plus/es')['ElScrollbar'] ElSelect: typeof import('element-plus/es')['ElSelect'] ElSlider: typeof import('element-plus/es')['ElSlider'] - ElSpace: typeof import('element-plus/es')['ElSpace'] - ElSwitch: typeof import('element-plus/es')['ElSwitch'] - ElTable: typeof import('element-plus/es')['ElTable'] - ElTableColumn: typeof import('element-plus/es')['ElTableColumn'] ElTabPane: typeof import('element-plus/es')['ElTabPane'] ElTabs: typeof import('element-plus/es')['ElTabs'] - ElTag: typeof import('element-plus/es')['ElTag'] ElText: typeof import('element-plus/es')['ElText'] ElTooltip: typeof import('element-plus/es')['ElTooltip'] - ElTransfer: typeof import('element-plus/es')['ElTransfer'] - ElTree: typeof import('element-plus/es')['ElTree'] Error: typeof import('./src/components/taskbar/Error.vue')['default'] FileIcon: typeof import('./src/components/builtin/FileIcon.vue')['default'] FileIconImg: typeof import('./src/components/builtin/FileIconImg.vue')['default'] diff --git a/frontend/src/components/localchat/AiChatMain.vue b/frontend/src/components/localchat/AiChatMain.vue index e9f9ef7..72cbd50 100644 --- a/frontend/src/components/localchat/AiChatMain.vue +++ b/frontend/src/components/localchat/AiChatMain.vue @@ -6,12 +6,15 @@ import { notifyError } from "@/util/msg.ts"; import { ElScrollbar } from "element-plus"; import { getSystemConfig } from "@/system/config"; import { Vue3Lottie } from "vue3-lottie"; +import { file } from "jszip"; const chatStore = useAiChatStore(); const modelStore = useModelStore(); const isPadding = ref(false); //是否发送中 const webSearch = ref(false); const imageInput: any = ref(null); let imageData = ref(""); +let fileContent = ref(""); +let fileName = ref(""); const messageContainerRef = ref>(); const messageInnerRef = ref(); // User Input Message @@ -86,6 +89,8 @@ const createCompletion = async () => { engine: chatStore.chatInfo.engine, stream: false, webSearch: webSearch.value, + fileContent: fileContent.value, + fileName: fileName.value, options: chatConfig, }; if (imageData.value != "") { @@ -111,10 +116,12 @@ const createCompletion = async () => { method: "POST", body: JSON.stringify(postMsg), }; - + //console.log(postData) const completion = await fetch(config.apiUrl + '/ai/chat', postData); //const completion:any = await modelStore.getModel(postData) imageData.value = ""; + fileContent.value = ""; + fileName.value = ""; if (!completion.ok) { const errorData = await completion.json(); notifyError(errorData.error.message); @@ -167,12 +174,6 @@ const handleKeydown = (e: any) => { } }; const selectImage = async () => { - const img2txtModel = await modelStore.getModel("img2txt"); - if (!img2txtModel) { - notifyError(t("aichat.notEyeModel")); - return; - } - imageInput.value.click(); }; const uploadImage = async (event: any) => { @@ -180,9 +181,24 @@ const uploadImage = async (event: any) => { if (!file) { return; } + //console.log(file) + if (file.type.startsWith('image/')) { + const img2txtModel = await modelStore.getModel("img2txt"); + if (!img2txtModel) { + notifyError(t("aichat.notEyeModel")); + return; + } + } const reader = new FileReader(); reader.onload = (e: any) => { - imageData.value = e.target.result.split(",")[1]; + const fileData = e.target.result.split(",")[1]; + if (file.type.startsWith('image/')) { + imageData.value = fileData; + } else { + fileContent.value = fileData; + fileName.value = file.name; + } + //console.log(fileContent.value) }; reader.readAsDataURL(file); @@ -214,15 +230,19 @@ const uploadImage = async (event: any) => {
- - + + - + + @keydown="handleKeydown" autofocus class="ai-input-area" /> diff --git a/godo/.gitignore b/godo/.gitignore index 854a55b..5ae554c 100644 --- a/godo/.gitignore +++ b/godo/.gitignore @@ -2,3 +2,4 @@ tmp deps/dist deps/*.zip godo +testdata diff --git a/godo/ai/server/chat.go b/godo/ai/server/chat.go index 547aca5..ae4ad31 100644 --- a/godo/ai/server/chat.go +++ b/godo/ai/server/chat.go @@ -5,6 +5,8 @@ import ( "fmt" "godo/ai/search" "godo/libs" + "godo/office" + "log" "net/http" "time" ) @@ -15,6 +17,7 @@ type ChatRequest struct { Stream bool `json:"stream"` WebSearch bool `json:"webSearch"` FileContent string `json:"fileContent"` + FileName string `json:"fileName"` Options map[string]interface{} `json:"options"` Messages []Message `json:"messages"` } @@ -40,6 +43,13 @@ func ChatHandler(w http.ResponseWriter, r *http.Request) { return } } + if req.FileContent != "" { + err = ChatWithFile(&req) + if err != nil { + libs.ErrorMsg(w, err.Error()) + return + } + } headers, url, err := GetHeadersAndUrl(req, "chat") // log.Printf("url: %s", url) // log.Printf("headers: %v", headers) @@ -49,15 +59,26 @@ func ChatHandler(w http.ResponseWriter, r *http.Request) { } ForwardHandler(w, r, req, url, headers, "POST") } -func ChatWithWeb(req *ChatRequest) error { - if len(req.Messages) == 0 { - return fmt.Errorf("the messages is empty") +func ChatWithFile(req *ChatRequest) error { + fileContent, err := office.ProcessBase64File(req.FileContent, req.FileName) + if err != nil { + return err } - lastMessage := req.Messages[len(req.Messages)-1] - if lastMessage.Role != "user" { - return fmt.Errorf("the last message is not user") + lastMessage, err := GetLastMessage(*req) + if err != nil { + return err } - searchRequest := search.SearchWeb(lastMessage.Content) + userQuestion := fmt.Sprintf("请对\n%s\n的内容进行分析,给出对用户输入的回答: %s", fileContent, lastMessage) + log.Printf("the search file is %v", userQuestion) + req.Messages = append(req.Messages, Message{Role: "user", Content: userQuestion}) + return nil +} +func ChatWithWeb(req *ChatRequest) error { + lastMessage, err := GetLastMessage(*req) + if err != nil { + return err + } + searchRequest := search.SearchWeb(lastMessage) if len(searchRequest) == 0 { return fmt.Errorf("the search web is empty") } @@ -77,9 +98,8 @@ func ChatWithWeb(req *ChatRequest) error { # 用户问题:%s -`, inputPrompt, currentDate, lastMessage.Content) +`, inputPrompt, currentDate, lastMessage) //log.Printf("the search web is %v", searchPrompt) - // userQuestion := fmt.Sprintf("问:%s,答:", lastMessage.Content) // req.Messages = append([]Message{}, Message{Role: "assistant", Content: searchPrompt}) req.Messages = append([]Message{}, Message{Role: "user", Content: searchPrompt}) return nil @@ -98,3 +118,13 @@ func EmbeddingHandler(w http.ResponseWriter, r *http.Request) { } ForwardHandler(w, r, req, url, headers, "POST") } +func GetLastMessage(req ChatRequest) (string, error) { + if len(req.Messages) == 0 { + return "", fmt.Errorf("the messages is empty") + } + lastMessage := req.Messages[len(req.Messages)-1] + if lastMessage.Role != "user" { + return "", fmt.Errorf("the last message is not user") + } + return lastMessage.Content, nil +} diff --git a/godo/office/document.go b/godo/office/document.go index d8cb148..bf76fcb 100644 --- a/godo/office/document.go +++ b/godo/office/document.go @@ -20,11 +20,13 @@ package office import ( "archive/zip" "bufio" + "encoding/base64" "encoding/json" "encoding/xml" "errors" "fmt" "godo/libs" + "log" "os" "path" "path/filepath" @@ -113,6 +115,49 @@ func ProcessFile(filePath string) DocResult { return DocResult{filePath: filePath, newFilePath: newFilePath, err: nil} } + +// ProcessBase64File 处理解码后的文件并提取文本信息 +func ProcessBase64File(base64String string, fileName string) (string, error) { + // 解码 Base64 字符串 + decodedBytes, err := base64.StdEncoding.DecodeString(base64String) + if err != nil { + return "", fmt.Errorf("failed to decode base64 string: %v", err) + } + // 获取文件后缀并转换为小写 + // fileExt := strings.ToLower(filepath.Ext(fileName)) + // if fileExt == "" { + // return "", fmt.Errorf("file extension not found in filename: %s", fileName) + // } + // log.Printf("File type: %s\n", fileExt) + cacheDir := libs.GetCacheDir() + tempFilePath := filepath.Join(cacheDir, fileName) + + // 创建临时文件 + tempFile, err := os.Create(tempFilePath) + if err != nil { + return "", fmt.Errorf("failed to create temp file: %v", err) + } + defer tempFile.Close() + + // 将解码后的数据写入临时文件 + _, err = tempFile.Write(decodedBytes) + if err != nil { + return "", fmt.Errorf("failed to write to temp file: %v", err) + } + + // 获取文档内容 + doc, err := GetDocument(tempFilePath) + if err != nil { + return "", fmt.Errorf("failed to get document: %v", err) + } + log.Printf("Document content: %s\n", doc.Content) + + // 删除临时文件 + defer os.Remove(tempFilePath) + + // 提取文本内容 + return doc.Content, nil +} func GetDocument(pathname string) (*Document, error) { if !libs.PathExists(pathname) { return nil, fmt.Errorf("file does not exist: %s", pathname) diff --git a/godo/office/document_test.go b/godo/office/document_test.go new file mode 100644 index 0000000..4f86228 --- /dev/null +++ b/godo/office/document_test.go @@ -0,0 +1,38 @@ +package office + +import ( + "log" + "os" + "path/filepath" + "testing" +) + +func TestGetDocument(t *testing.T) { + // Get the absolute path to the testdata directory + testdataDir, err := filepath.Abs("testdata") + if err != nil { + t.Fatalf("Failed to get absolute path to testdata directory: %v", err) + } + + // Read all files in the testdata directory + files, err := os.ReadDir(testdataDir) + if err != nil { + t.Fatalf("Failed to read testdata directory: %v", err) + } + + // Iterate over each file and test GetDocument + for _, file := range files { + if !file.IsDir() { + filePath := filepath.Join(testdataDir, file.Name()) + t.Run(file.Name(), func(t *testing.T) { + doc, err := GetDocument(filePath) + if err != nil { + t.Errorf("Failed to get document for %s: %v", file.Name(), err) + } else { + log.Printf("Document file.Name: %s\ncontent: %d\n", file.Name(), len(doc.Content)) + //t.Logf("Document file.Name: %s\n", file.Name()) + } + }) + } + } +}