From 63a90f38e03a40b827504ee65f0b79377a4eee25 Mon Sep 17 00:00:00 2001 From: godo Date: Thu, 9 Jan 2025 17:53:29 +0800 Subject: [PATCH] add vector --- frontend/components.d.ts | 4 + frontend/src/components/builtin/FileList.vue | 2 +- godo/ai/server/embedding.go | 2 +- godo/ai/vector/files.go | 444 ++++++++++--- godo/ai/vector/vector.go | 642 ++++++++----------- godo/cmd/main.go | 3 + godo/files/fs.go | 52 +- godo/files/os.go | 46 +- godo/model/init.go | 8 + godo/model/vec_doc.go | 138 +++- godo/model/vec_list.go | 3 +- 11 files changed, 783 insertions(+), 561 deletions(-) diff --git a/frontend/components.d.ts b/frontend/components.d.ts index 18f8f32..3ccca52 100644 --- a/frontend/components.d.ts +++ b/frontend/components.d.ts @@ -75,9 +75,13 @@ declare module 'vue' { ElFormItem: typeof import('element-plus/es')['ElFormItem'] ElIcon: typeof import('element-plus/es')['ElIcon'] ElInput: typeof import('element-plus/es')['ElInput'] + ElOption: typeof import('element-plus/es')['ElOption'] ElPagination: typeof import('element-plus/es')['ElPagination'] ElProgress: typeof import('element-plus/es')['ElProgress'] ElRow: typeof import('element-plus/es')['ElRow'] + ElSelect: typeof import('element-plus/es')['ElSelect'] + ElSwitch: typeof import('element-plus/es')['ElSwitch'] + ElTag: typeof import('element-plus/es')['ElTag'] Error: typeof import('./src/components/taskbar/Error.vue')['default'] FileIcon: typeof import('./src/components/builtin/FileIcon.vue')['default'] FileIconImg: typeof import('./src/components/builtin/FileIconImg.vue')['default'] diff --git a/frontend/src/components/builtin/FileList.vue b/frontend/src/components/builtin/FileList.vue index f9fe5ec..f4ab658 100644 --- a/frontend/src/components/builtin/FileList.vue +++ b/frontend/src/components/builtin/FileList.vue @@ -347,7 +347,7 @@ function handleRightClick( menuArr.push({ label: "加入知识库", click: () => { - //console.log(item.path) + console.log(item) addKnowledge(item.path).then((res:any) => { console.log(res) if(res.code != 0){ diff --git a/godo/ai/server/embedding.go b/godo/ai/server/embedding.go index d396dcc..cf0215e 100644 --- a/godo/ai/server/embedding.go +++ b/godo/ai/server/embedding.go @@ -65,7 +65,7 @@ func getOllamaEmbedding(model string, text []string) ([][]float32, error) { if err != nil { return nil, fmt.Errorf("couldn't unmarshal response body: %w", err) } - log.Printf("Embedding: %v", embeddingResponse.Embeddings) + //log.Printf("Embedding: %v", embeddingResponse.Embeddings) // Return the embeddings directly. if len(embeddingResponse.Embeddings) == 0 { diff --git a/godo/ai/vector/files.go b/godo/ai/vector/files.go index 05c42c2..9e954b1 100644 --- a/godo/ai/vector/files.go +++ b/godo/ai/vector/files.go @@ -1,91 +1,337 @@ package vector import ( - "encoding/json" "fmt" - "godo/ai/server" "godo/libs" - "godo/office" "log" "os" "path/filepath" "strings" + "sync" + "time" "github.com/fsnotify/fsnotify" ) -var MapFilePathMonitors = map[string]uint{} +var ( + MapFilePathMonitors = map[string]uint{} + watcher *fsnotify.Watcher + fileQueue = make(chan string, 100) // 队列大小可以根据需要调整 + numWorkers = 3 // 工作协程的数量 + wg sync.WaitGroup + syncingKnowledgeIds = make(map[uint]syncingStats) // 记录正在同步的 knowledgeId 及其同步状态 + syncMutex sync.Mutex // 保护 syncingKnowledgeIds 的互斥锁 + renameMap = make(map[string]string) // 临时映射存储 Remove 事件的路径 + renameMutex sync.Mutex // 保护 renameMap 的互斥锁 + watcherMutex sync.Mutex // 保护 watcher 的互斥锁 +) -func FolderMonitor() { - basePath, err := libs.GetOsDir() +type syncingStats struct { + totalFiles int + processedFiles int +} + +func InitMonitor() { + var err error + watcherMutex.Lock() + watcher, err = fsnotify.NewWatcher() if err != nil { - log.Printf("Error getting base path: %s", err.Error()) - return + log.Fatalf("Error creating watcher: %s", err.Error()) } - watcher, err := fsnotify.NewWatcher() - if err != nil { - log.Printf("Error creating watcher: %s", err.Error()) - return - } - defer watcher.Close() + watcherMutex.Unlock() + go FolderMonitor() + go startWatching() - // 递归添加所有子目录 - addRecursive(basePath, watcher) + // 启动 worker + for i := 0; i < numWorkers; i++ { + wg.Add(1) + go worker() + } +} - // Start listening for events. - go func() { - for { - select { - case event, ok := <-watcher.Events: - if !ok { - log.Println("error:", err) - return +func startWatching() { + for { + select { + case event, ok := <-watcher.Events: + if !ok { + log.Println("error event") + return + } + filePath := filepath.Clean(event.Name) + result, exists := shouldProcess(filePath) + if result > 0 { + if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) { + log.Printf("Event: %v, File: %s", event.Op, filePath) + if isFileComplete(filePath) { + // 将文件路径放入队列 + fileQueue <- filePath + } } - //log.Println("event:", event) - filePath := event.Name - result, knowledgeId := shouldProcess(filePath) - //log.Printf("result:%d,knowledgeId:%d", result, knowledgeId) - if result > 0 { - info, err := os.Stat(filePath) - if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) { - log.Println("modified file:", filePath) - if !info.IsDir() { - handleGodoosFile(filePath, knowledgeId) + if event.Has(fsnotify.Create) { + if info, err := os.Stat(filePath); err == nil && info.IsDir() { + addRecursive(filePath, watcher) + } + // 检查是否是重命名事件 + handleRenameCreateEvent(event) + } + if event.Has(fsnotify.Remove) { + //log.Printf("Event: %v, File: %s,exists:%d", event.Op, filePath, exists) + isDir := true + newFileName := fmt.Sprintf(".godoos.%d.%s.json", result, filepath.Base(filePath)) + newFilePath := filepath.Join(filepath.Dir(filePath), newFileName) + if libs.PathExists(newFilePath) { + isDir = false + } + if isDir { + watcherMutex.Lock() + if watcher != nil { + watcher.Remove(filePath) } + watcherMutex.Unlock() } - if event.Has(fsnotify.Create) || event.Has(fsnotify.Rename) { - // 处理创建或重命名事件,添加新目录 - if err == nil && info.IsDir() { - addRecursive(filePath, watcher) + if exists == 1 { + err := DeleteVector(result) + if err != nil { + log.Printf("Error deleting vector %d: %v", result, err) } } - if event.Has(fsnotify.Remove) { - // 处理删除事件,移除目录 - if err == nil && info.IsDir() { - watcher.Remove(filePath) + if exists == 2 && !isDir { + err := DeleteVectorFile(result, filePath) + if err != nil { + log.Printf("Error deleting vector file %d: %v", result, err) } } + //存储 Remove 事件的路径 + handleRenameRemoveEvent(event) + } + } + case err, ok := <-watcher.Errors: + if !ok { + return + } + log.Println("error:", err) + } + } +} + +func handleRenameRemoveEvent(event fsnotify.Event) { + renameMutex.Lock() + defer renameMutex.Unlock() + //log.Printf("handleRenameRemoveEvent: %v, File: %s", event.Op, event.Name) + renameMap[event.Name] = event.Name +} + +func handleRenameCreateEvent(event fsnotify.Event) { + renameMutex.Lock() + defer renameMutex.Unlock() + //log.Printf("handleRenameCreateEvent: %v, File: %s", event.Op, event.Name) + // 规范化路径 + newPath := filepath.Clean(event.Name) + + // 检查是否是重命名事件 + for oldPath := range renameMap { + if oldPath != "" { + // 找到对应的 Remove 事件 + oldPathClean := filepath.Clean(oldPath) + if oldPathClean == newPath { + //log.Printf("File renamed from %s to %s", oldPath, newPath) + + // 更新 MapFilePathMonitors + for path, id := range MapFilePathMonitors { + if path == oldPathClean { + delete(MapFilePathMonitors, path) + MapFilePathMonitors[newPath] = id + log.Printf("Updated MapFilePathMonitors: %s -> %s", oldPathClean, newPath) + break + } } - case err, ok := <-watcher.Errors: - if !ok { - return + + // 更新 watcher + watcherMutex.Lock() + if watcher != nil { + if err := watcher.Remove(oldPathClean); err != nil { + log.Printf("Error removing old path %s from watcher: %v", oldPathClean, err) + } + if err := watcher.Add(newPath); err != nil { + log.Printf("Error adding new path %s to watcher: %v", newPath, err) + } } - log.Println("error:", err) + watcherMutex.Unlock() + + // 如果是目录,递归更新子目录 + if info, err := os.Stat(newPath); err == nil && info.IsDir() { + addRecursive(newPath, watcher) + } + + // 清除临时映射中的路径 + delete(renameMap, oldPath) + break } } - }() + } +} - // Add a path. - err = watcher.Add(basePath) +func worker() { + defer wg.Done() + for filePath := range fileQueue { + knowledgeId, exists := shouldProcess(filePath) + if exists == 0 { + log.Printf("File path %s is not being monitored", filePath) + continue + } + + // 更新已处理文件数 + syncMutex.Lock() + if stats, ok := syncingKnowledgeIds[knowledgeId]; ok { + stats.processedFiles++ + syncingKnowledgeIds[knowledgeId] = stats + } + syncMutex.Unlock() + + err := handleGodoosFile(filePath, knowledgeId) + if err != nil { + log.Printf("Error handling file %s: %v", filePath, err) + } + } +} + +func FolderMonitor() { + basePath, err := libs.GetOsDir() if err != nil { - log.Fatal(err) + log.Printf("Error getting base path: %s", err.Error()) + return } + // 递归添加所有子目录 + addRecursive(basePath, watcher) + + // Add a path. + watcherMutex.Lock() + if watcher != nil { + err = watcher.Add(basePath) + if err != nil { + log.Fatal(err) + } + } + watcherMutex.Unlock() + // Block main goroutine forever. <-make(chan struct{}) } -func shouldProcess(filePath string) (int, uint) { +func AddWatchFolder(folderPath string, knowledgeId uint, callback func()) error { + if watcher == nil { + InitMonitor() + } + // 规范化路径 + folderPath = filepath.Clean(folderPath) + + // 检查文件夹是否存在 + if !libs.PathExists(folderPath) { + return fmt.Errorf("folder path does not exist: %s", folderPath) + } + + // 检查文件夹是否已经存在于监视器中 + if _, exists := MapFilePathMonitors[folderPath]; exists { + return fmt.Errorf("folder path is already being monitored: %s", folderPath) + } + + // 递归添加所有子目录 + addRecursive(folderPath, watcher) + + // 计算总文件数 + totalFiles, err := countFiles(folderPath) + if err != nil { + return fmt.Errorf("failed to count files in folder path: %w", err) + } + + // 更新 syncingKnowledgeIds + syncMutex.Lock() + syncingKnowledgeIds[knowledgeId] = syncingStats{ + totalFiles: totalFiles, + processedFiles: 0, + } + syncMutex.Unlock() + + // 更新 MapFilePathMonitors + MapFilePathMonitors[folderPath] = knowledgeId + + // 添加文件夹路径到监视器 + err = watcher.Add(folderPath) + if err != nil { + return fmt.Errorf("failed to add folder path to watcher: %w", err) + } + + // 调用回调函数 + if callback != nil { + callback() + } + + log.Printf("Added folder path %s to watcher with knowledgeId %d", folderPath, knowledgeId) + return nil +} + +// RemoveWatchFolder 根据路径删除观察文件夹 +func RemoveWatchFolder(folderPath string) error { + // 规范化路径 + folderPath = filepath.Clean(folderPath) + + // 检查文件夹是否存在于监视器中 + knowledgeId, exists := MapFilePathMonitors[folderPath] + if !exists { + return fmt.Errorf("folder path is not being monitored: %s", folderPath) + } + + // 从 watcher 中移除路径 + watcherMutex.Lock() + if watcher != nil { + err := watcher.Remove(folderPath) + if err != nil { + return fmt.Errorf("failed to remove folder path from watcher: %w", err) + } + } + watcherMutex.Unlock() + + // 递归移除所有子目录 + err := filepath.Walk(folderPath, func(path string, info os.FileInfo, err error) error { + if err != nil { + log.Printf("Error walking path %s: %v", path, err) + return err + } + if info.IsDir() { + result, _ := shouldProcess(path) + if result > 0 { + // 从 watcher 中移除路径 + watcherMutex.Lock() + if watcher != nil { + err := watcher.Remove(path) + if err != nil { + log.Printf("Error removing path %s from watcher: %v", path, err) + return err + } + } + watcherMutex.Unlock() + } + } + return nil + }) + if err != nil { + return fmt.Errorf("failed to remove folder path from watcher: %w", err) + } + + // 从 MapFilePathMonitors 中删除条目 + delete(MapFilePathMonitors, folderPath) + + // 从 syncingKnowledgeIds 中删除条目 + syncMutex.Lock() + delete(syncingKnowledgeIds, knowledgeId) + syncMutex.Unlock() + + log.Printf("Removed folder path %s from watcher with knowledgeId %d", folderPath, knowledgeId) + return nil +} + +func shouldProcess(filePath string) (uint, int) { // 规范化路径 filePath = filepath.Clean(filePath) @@ -96,10 +342,10 @@ func shouldProcess(filePath string) (int, uint) { } path = filepath.Clean(path) if filePath == path { - return 1, id // 完全相等 + return id, 1 // 完全相等 } if strings.HasPrefix(filePath, path+string(filepath.Separator)) { - return 2, id // 包含 + return id, 2 // 包含 } } return 0, 0 // 不存在 @@ -120,7 +366,6 @@ func addRecursive(path string, watcher *fsnotify.Watcher) { } log.Printf("Added path %s to watcher", path) } - } return nil }) @@ -129,53 +374,62 @@ func addRecursive(path string, watcher *fsnotify.Watcher) { } } -func handleGodoosFile(filePath string, knowledgeId uint) error { - log.Printf("========Handling .godoos file: %s", filePath) - baseName := filepath.Base(filePath) - if baseName[:8] != ".godoos." { - if baseName[:1] != "." { - office.ProcessFile(filePath, knowledgeId) +// countFiles 递归计算文件夹中的文件数 +func countFiles(folderPath string) (int, error) { + var count int + err := filepath.Walk(folderPath, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if !info.IsDir() { + count++ } return nil - } - var doc office.Document - content, err := os.ReadFile(filePath) + }) if err != nil { - return err + return 0, err } - err = json.Unmarshal(content, &doc) - if err != nil { - return err + return count, nil +} + +// GetSyncPercentage 计算并返回同步百分比 +func GetSyncPercentage(knowledgeId uint) float64 { + syncMutex.Lock() + defer syncMutex.Unlock() + if stats, ok := syncingKnowledgeIds[knowledgeId]; ok { + if stats.totalFiles == 0 { + return 0.0 + } + return float64(stats.processedFiles) / float64(stats.totalFiles) * 100 } - if len(doc.Split) == 0 { - return fmt.Errorf("invalid .godoos file: %s", filePath) + return 0.0 +} + +// isFileComplete 检查文件是否已经完全创建 +func isFileComplete(filePath string) bool { + // 等待一段时间确保文件已经完全创建 + time.Sleep(100 * time.Millisecond) + + // 检查文件是否存在 + if _, err := os.Stat(filePath); err != nil { + log.Printf("File %s does not exist: %v", filePath, err) + return false } - knowData, err := GetVector(knowledgeId) + + // 检查文件大小是否达到预期 + fileInfo, err := os.Stat(filePath) if err != nil { - return err + log.Printf("Error stat file %s: %v", filePath, err) + return false } - resList, err := server.GetEmbeddings(knowData.Engine, knowData.EmbeddingModel, doc.Split) - if err != nil { - return err - } - if len(resList) != len(doc.Split) { - return fmt.Errorf("invalid file len: %s, expected %d embeddings, got %d", filePath, len(doc.Split), len(resList)) - } - // var vectordocs []model.Vectordoc - // for i, res := range resList { - // //log.Printf("res: %v", res) - // vectordoc := model.Vectordoc{ - // Content: doc.Split[i], - // Embed: res, - // FilePath: filePath, - // KnowledgeID: knowledgeId, - // Pos: fmt.Sprintf("%d", i), - // } - // vectordocs = append(vectordocs, vectordoc) - // } - // result := vectorListDb.Create(&vectordocs) - // if result.Error != nil { - // return result.Error - // } - return nil + // 例如,检查文件大小是否大于某个阈值 + if fileInfo.Size() == 0 { + log.Printf("File %s is empty", filePath) + return false + } + if fileInfo.IsDir() { + log.Printf("File %s is a directory", filePath) + return false + } + return true } diff --git a/godo/ai/vector/vector.go b/godo/ai/vector/vector.go index 47b9fcd..38034c9 100644 --- a/godo/ai/vector/vector.go +++ b/godo/ai/vector/vector.go @@ -3,13 +3,42 @@ package vector import ( "encoding/json" "fmt" + "godo/ai/server" "godo/libs" "godo/model" "godo/office" + "log" "net/http" + "os" "path/filepath" + "strconv" + "strings" ) +func InitMonitorVector() { + // 确保数据库连接正常 + // 确保数据库连接正常 + db, err := model.Db.DB() + if err != nil { + log.Fatalf("Failed to get database connection: %v", err) + } + if err = db.Ping(); err != nil { + log.Fatalf("Failed to connect to database: %v", err) + } + list, err := GetVectorList() + if err != nil { + fmt.Println("GetVectorList error:", err) + return + } + if len(list) == 0 { + log.Println("no vector db found, creating a new one") + } + //log.Printf("init monitor:%v", list) + for _, v := range list { + MapFilePathMonitors[v.FilePath] = v.ID + } + go InitMonitor() +} func HandlerCreateKnowledge(w http.ResponseWriter, r *http.Request) { var req model.VecList err := json.NewDecoder(r.Body).Decode(&req) @@ -27,7 +56,24 @@ func HandlerCreateKnowledge(w http.ResponseWriter, r *http.Request) { return } req.FilePath = filepath.Join(basePath, req.FilePath) - + if !libs.PathExists(req.FilePath) { + libs.ErrorMsg(w, "the knowledge path is not exists") + return + } + fileInfo, err := os.Stat(req.FilePath) + if err != nil { + libs.ErrorMsg(w, "get vector db path error:"+err.Error()) + return + } + if !fileInfo.IsDir() { + libs.ErrorMsg(w, "the knowledge path is not dir") + return + } + knowledgeFilePath := filepath.Join(req.FilePath, ".knowledge") + if libs.PathExists(knowledgeFilePath) { + libs.ErrorMsg(w, "the knowledgeId already exists") + return + } id, err := CreateVector(req) if err != nil { libs.ErrorMsg(w, err.Error()) @@ -35,6 +81,62 @@ func HandlerCreateKnowledge(w http.ResponseWriter, r *http.Request) { } libs.SuccessMsg(w, id, "create vector success") } +func HandlerAskKnowledge(w http.ResponseWriter, r *http.Request) { + var req struct { + ID uint `json:"id"` + Input string `json:"input"` + } + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + libs.ErrorMsg(w, "the chat request error:"+err.Error()) + return + } + if req.ID == 0 { + libs.ErrorMsg(w, "knowledgeId is empty") + return + } + var knowData model.VecList + if err := model.Db.First(&knowData, req.ID).Error; err != nil { + libs.ErrorMsg(w, err.Error()) + return + } + var filterDocs []string + filterDocs = append(filterDocs, req.Input) + // 获取嵌入向量 + resList, err := server.GetEmbeddings(knowData.Engine, knowData.EmbeddingModel, filterDocs) + if err != nil { + libs.ErrorMsg(w, err.Error()) + return + } + res, err := model.AskDocument(req.ID, resList[0]) + if err != nil { + libs.ErrorMsg(w, err.Error()) + return + } + libs.SuccessMsg(w, res, "ask knowledge success") + +} +func HandlerDelKnowledge(w http.ResponseWriter, r *http.Request) { + idStr := r.URL.Query().Get("id") + if idStr == "" { + libs.ErrorMsg(w, "knowledgeId is empty") + return + } + id, err := strconv.Atoi(idStr) + if err != nil { + libs.ErrorMsg(w, "knowledgeId is not number") + return + } + if id == 0 { + libs.ErrorMsg(w, "knowledgeId is not number") + return + } + if err := DeleteVector(uint(id)); err != nil { + libs.ErrorMsg(w, err.Error()) + return + } + libs.SuccessMsg(w, nil, "delete knowledge success") +} // CreateVector 创建一个新的 VectorList 记录 func CreateVector(data model.VecList) (uint, error) { @@ -59,18 +161,72 @@ func CreateVector(data model.VecList) (uint, error) { if result.Error != nil { return 0, fmt.Errorf("failed to create vector list: %w", result.Error) } + // 创建 .knowledge 文件并写入 knowledgeId + knowledgeFilePath := filepath.Join(data.FilePath, ".knowledge") + err := os.WriteFile(knowledgeFilePath, []byte(fmt.Sprintf("%d", data.ID)), 0644) + if err != nil { + return 0, fmt.Errorf("failed to write knowledgeId to .knowledge file: %w", err) + } - // Start background tasks - go office.SetDocument(data.FilePath, uint(data.ID)) + // // 等待 AddWatchFolder 完成 + go AddWatchFolder(data.FilePath, data.ID, func() { + go office.SetDocument(data.FilePath, data.ID) + }) - return uint(data.ID), nil + return data.ID, nil } // DeleteVector 删除指定id的 VectorList 记录 -func DeleteVector(id int) error { +func DeleteVector(id uint) error { + var vectorList model.VecList + if err := model.Db.First(&vectorList, id).Error; err != nil { + return fmt.Errorf("failed to find vector list: %w", err) + } + // Delete .knowledge file + knowledgeFilePath := filepath.Join(vectorList.FilePath, ".knowledge") + if err := os.Remove(knowledgeFilePath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to delete .knowledge file: %w", err) + } + // Delete all .godoos. files in the directory and its subdirectories + if err := deleteGodoosFiles(vectorList.FilePath); err != nil { + return fmt.Errorf("failed to delete .godoos. files: %w", err) + } + //delete(MapFilePathMonitors, vectorList.FilePath) + RemoveWatchFolder(vectorList.FilePath) + return model.Db.Delete(&model.VecList{}, id).Error +} +func DeleteVectorFile(id uint, filePath string) error { + var vectorList model.VecList + if err := model.Db.First(&vectorList, id).Error; err != nil { + return fmt.Errorf("failed to find vector list: %w", err) + } + // Delete file in database + if err := model.Deletedocument(id, filePath); err != nil { + return fmt.Errorf("failed to delete .godoos. files: %w", err) + } + newFileName := fmt.Sprintf(".godoos.%d.%s.json", id, filepath.Base(filePath)) + newFilePath := filepath.Join(filepath.Dir(filePath), newFileName) + if libs.PathExists(newFilePath) { + if err := os.Remove(newFilePath); err != nil { + return fmt.Errorf("failed to delete .godoos. files: %w", err) + } + } return model.Db.Delete(&model.VecList{}, id).Error } +func deleteGodoosFiles(dir string) error { + return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if !info.IsDir() && filepath.Base(path)[:7] == ".godoos." { + if err := os.Remove(path); err != nil { + return err + } + } + return nil + }) +} // RenameVectorDb 更改指定名称的 VectorList 的数据库名称 func RenameVectorDb(oldName string, newName string) error { @@ -91,7 +247,11 @@ func RenameVectorDb(oldName string, newName string) error { if err := model.Db.Model(&model.VecList{}).Where("id = ?", oldList.ID).Update("file_path", newPath).Error; err != nil { return fmt.Errorf("failed to update vector list: %w", err) } - + // Update MapFilePathMonitors + //delete(MapFilePathMonitors, oldPath) + go RemoveWatchFolder(oldPath) + //MapFilePathMonitors[newPath] = oldList.ID + go AddWatchFolder(newPath, oldList.ID, nil) return nil } @@ -111,392 +271,90 @@ func GetVector(id uint) (model.VecList, error) { return vectorList, nil } -// func SimilaritySearch(query string, numDocuments int, collection string, where map[string]string) ([]vs.Document, error) { -// ef := v.embeddingFunc -// if embeddingFunc != nil { -// ef = embeddingFunc -// } - -// q, err := ef(ctx, query) -// if err != nil { -// return nil, fmt.Errorf("failed to compute embedding: %w", err) -// } - -// qv, err := sqlitevec.SerializeFloat32(q) -// if err != nil { -// return nil, fmt.Errorf("failed to serialize query embedding: %w", err) -// } - -// var docs []vs.Document -// err = v.db.Transaction(func(tx *gorm.DB) error { -// // Query matching document IDs and distances -// rows, err := tx.Raw(fmt.Sprintf(` -// SELECT document_id, distance -// FROM [%s_vec] -// WHERE embedding MATCH ? -// ORDER BY distance -// LIMIT ? -// `, collection), qv, numDocuments).Rows() -// if err != nil { -// return fmt.Errorf("failed to query vector table: %w", err) -// } -// defer rows.Close() - -// for rows.Next() { -// var docID string -// var distance float32 -// if err := rows.Scan(&docID, &distance); err != nil { -// return fmt.Errorf("failed to scan row: %w", err) -// } -// docs = append(docs, vs.Document{ -// ID: docID, -// SimilarityScore: 1 - distance, // Higher score means closer match -// }) -// } - -// // Fetch content and metadata for each document -// for i, doc := range docs { -// var content string -// var metadataJSON []byte -// err := tx.Raw(fmt.Sprintf(` -// SELECT content, metadata -// FROM [%s] -// WHERE id = ? -// `, v.embeddingsTableName), doc.ID).Row().Scan(&content, &metadataJSON) -// if err != nil { -// return fmt.Errorf("failed to query embeddings table for document %s: %w", doc.ID, err) -// } - -// var metadata map[string]interface{} -// if err := json.Unmarshal(metadataJSON, &metadata); err != nil { -// return fmt.Errorf("failed to parse metadata for document %s: %w", doc.ID, err) -// } - -// docs[i].Content = content -// docs[i].Metadata = metadata -// } - -// return nil -// }) - -// if err != nil { -// return nil, err -// } - -// return docs, nil -// } - -// func AddDocuments(docs []VectorDoc, collection string) ([]string, error) { -// ids := make([]string, len(docs)) - -// err := VecDb.Transaction(func(tx *gorm.DB) error { -// if len(docs) > 0 { -// valuePlaceholders := make([]string, len(docs)) -// args := make([]interface{}, 0, len(docs)*2) // 2 args per doc: document_id and embedding - -// for i, doc := range docs { -// emb, err := v.embeddingFunc(ctx, doc.Content) -// if err != nil { -// return fmt.Errorf("failed to compute embedding for document %s: %w", doc.ID, err) -// } - -// serializedEmb, err := sqlitevec.SerializeFloat32(emb) -// if err != nil { -// return fmt.Errorf("failed to serialize embedding for document %s: %w", doc.ID, err) -// } - -// valuePlaceholders[i] = "(?, ?)" -// args = append(args, doc.ID, serializedEmb) - -// ids[i] = doc.ID -// } - -// // Raw query for *_vec as gorm doesn't support virtual tables -// query := fmt.Sprintf(` -// INSERT INTO [%s_vec] (document_id, embedding) -// VALUES %s -// `, collection, strings.Join(valuePlaceholders, ", ")) - -// if err := tx.Exec(query, args...).Error; err != nil { -// return fmt.Errorf("failed to batch insert into vector table: %w", err) -// } -// } - -// embs := make([]map[string]interface{}, len(docs)) -// for i, doc := range docs { -// metadataJson, err := json.Marshal(doc.Metadata) -// if err != nil { -// return fmt.Errorf("failed to marshal metadata for document %s: %w", doc.ID, err) -// } -// embs[i] = map[string]interface{}{ -// "id": doc.ID, -// "collection_id": collection, -// "content": doc.Content, -// "metadata": metadataJson, -// } -// } - -// // Use GORM's Create for the embeddings table -// if err := tx.Table(v.embeddingsTableName).Create(embs).Error; err != nil { -// return fmt.Errorf("failed to batch insert into embeddings table: %w", err) -// } - -// return nil -// }) - -// if err != nil { -// return nil, err -// } - -// return ids, nil -// } -//func init() { - -// dbPath := libs.GetVectorDb() -// sqlite_vec.Auto() - -// db, err := sqlx.Connect("sqlite3", dbPath) -// if err != nil { -// fmt.Println("Failed to open SQLite database:", err) -// return -// } -// defer db.Close() -// VecDb = db -// dsn := "file:" + dbPath -// db, err := sqlite3.Open(dsn) -// //db, err := sqlite3.Open(":memory:") -// if err != nil { -// fmt.Println("Failed to open SQLite database:", err) -// return -// } -// stmt, _, err := db.Prepare(`SELECT vec_version()`) -// if err != nil { -// log.Fatal(err) -// } - -// stmt.Step() -// log.Printf("vec_version=%s\n", stmt.ColumnText(0)) -// stmt.Close() -// _, err = db.Exec("CREATE TABLE IF NOT EXISTS vec_list (id INTEGER PRIMARY KEY AUTOINCREMENT,file_path TEXT NOT NULL,engine TEXT NOT NULL,embedding_model TEXT NOT NULL)") -// if err != nil { -// log.Fatal(err) -// } -// _, err = db.Exec("CREATE TABLE IF NOT EXISTS vec_doc (id INTEGER PRIMARY KEY AUTOINCREMENT,list_id INTEGER DEFAULT 0,file_path TEXT,content TEXT)") -// if err != nil { -// log.Fatal(err) -// } -// _, err = db.Exec("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[768])") -// if err != nil { -// log.Fatal(err) -// } -// VecDb = db - -//InitMonitor() -//} - -// func HandlerCreateKnowledge(w http.ResponseWriter, r *http.Request) { -// var req VectorList -// err := json.NewDecoder(r.Body).Decode(&req) -// if err != nil { -// libs.ErrorMsg(w, "the chat request error:"+err.Error()) -// return -// } -// if req.FilePath == "" { -// libs.ErrorMsg(w, "file path is empty") -// return -// } -// basePath, err := libs.GetOsDir() -// if err != nil { -// libs.ErrorMsg(w, "get vector db path error:"+err.Error()) -// return -// } -// req.FilePath = filepath.Join(basePath, req.FilePath) - -// // id, err := CreateVector(req) -// // if err != nil { -// // libs.ErrorMsg(w, err.Error()) -// // return -// // } -// // libs.SuccessMsg(w, id, "create vector success") -// } - -// CreateVector 创建一个新的 VectorList 记录 -// func CreateVector(data VectorList) (uint, error) { -// if data.FilePath == "" { -// return 0, fmt.Errorf("file path is empty") -// } -// if data.Engine == "" { -// return 0, fmt.Errorf("engine is empty") -// } - -// if !libs.PathExists(data.FilePath) { -// return 0, fmt.Errorf("file path does not exist") -// } -// if data.EmbeddingModel == "" { -// return 0, fmt.Errorf("embedding model is empty") -// } - -// // Create the new VectorList -// if err := tx.Table("vec_list").Create(data).Error; err != nil { -// return fmt.Errorf("failed to batch insert into embeddings table: %w", err) -// } -// // Get the ID of the newly created VectorList -// vectorID, err := result.LastInsertId() -// if err != nil { -// return 0, err -// } - -// // Start background tasks -// go office.SetDocument(data.FilePath, uint(vectorID)) - -// return uint(vectorID), nil -// } - -// // DeleteVector 删除指定id的 VectorList 记录 -// func DeleteVector(id int) error { -// tx, err := VecDb.Begin() -// if err != nil { -// return err -// } -// defer tx.Rollback() - -// // Delete from vec_doc first -// _, err = tx.Exec("DELETE FROM vec_doc WHERE list_id = ?)", id) -// if err != nil { -// return err -// } - -// // Delete from vec_list -// result, err := tx.Exec("DELETE FROM vec_list WHERE id = ?", id) -// if err != nil { -// return err -// } - -// rowsAffected, err := result.RowsAffected() -// if err != nil { -// return err -// } -// if rowsAffected == 0 { -// return fmt.Errorf("vector list not found") -// } - -// return tx.Commit() -// } - -// // RenameVectorDb 更改指定名称的 VectorList 的数据库名称 -// func RenameVectorDb(oldName string, newName string) error { -// basePath, err := libs.GetOsDir() -// if err != nil { -// return fmt.Errorf("failed to find old vector list: %w", err) -// } - -// // 2. 获取旧的 VectorList 记录 -// var oldList VectorList -// oldPath := filepath.Join(basePath, oldName) -// err = VecDb.QueryRow("SELECT id FROM vec_list WHERE file_path = ?", oldPath).Scan(&oldList.ID) -// if err != nil { -// return fmt.Errorf("failed to find old vector list: %w", err) -// } -// MapFilePathMonitors[oldPath] = 0 - -// // 5. 更新 VectorList 记录中的 DbPath 和 Name -// newPath := filepath.Join(basePath, newName) -// _, err = VecDb.Exec("UPDATE vec_list SET file_path = ? WHERE id = ?", newPath, oldList.ID) -// if err != nil { -// return fmt.Errorf("failed to update vector list: %w", err) -// } -// MapFilePathMonitors[newPath] = oldList.ID - -// return nil -// } -// func InsertVectorDoc(data []VectorDoc, embedlist [][]float32) error { -// rowIds := map[int][]float32{} -// for i, v := range data { -// err := VecDb.Exec("INSERT INTO vec_doc (list_id, file_path, content) VALUES (?, ?, ?)", v.ListID, v.FilePath, v.Content) -// if err != nil { -// return err -// } -// rowID, err := result.LastInsertRowID() -// if err != nil { -// return err -// } -// rowid := int(rowID) -// rowIds[rowid] = embedlist[i] -// } -// stmt, err := VecDb.Prepare("INSERT INTO vec_items(rowid, embedding) VALUES (?, ?)") -// if err != nil { -// log.Fatal(err) -// } -// defer stmt.Close() - -// for id, values := range rowIds { -// v, err := sqlite_vec.SerializeFloat32(values) -// if err != nil { -// log.Fatal(err) -// } -// err = stmt.BindInt64(1, int64(id)) -// if err != nil { -// log.Fatal(err) -// } -// err = stmt.BindBlob(2, v) -// if err != nil { -// log.Fatal(err) -// } -// err = stmt.Exec() -// if err != nil { -// log.Fatal(err) -// } -// stmt.Reset() -// } +// handleGodoosFile 处理 .godoos 文件 +func handleGodoosFile(filePath string, knowledgeId uint) error { + //log.Printf("========Handling .godoos file: %s", filePath) + baseName := filepath.Base(filePath) + // 检查文件后缀是否为 .exe + if strings.HasSuffix(baseName, ".exe") { + log.Printf("Skipping .exe file: %s", filePath) + return nil + } + // 检查是否为 .godoos 文件 + if strings.HasPrefix(baseName, ".godoos.") { + // 去掉 .godoos. 前缀和 .json 后缀 + //fileName := strings.TrimSuffix(strings.TrimPrefix(baseName, ".godoos."), ".json") + // 提取实际文件名部分 + //actualFileName := extractFileName(fileName) + + // 读取文件内容 + content, err := os.ReadFile(filePath) + if err != nil { + return err + } + + // 解析 JSON 内容 + var doc office.Document + err = json.Unmarshal(content, &doc) + if err != nil { + return err + } + + // 检查 Split 是否为空 + if len(doc.Split) == 0 { + return fmt.Errorf("invalid .godoos file: %s", filePath) + } + + // 获取向量数据 + knowData, err := GetVector(knowledgeId) + if err != nil { + return err + } + // 拼接文件名和内容 + // var filterDocs []string + // for _, res := range doc.Split { + // filterDocs = append(filterDocs, fmt.Sprintf("%s %s", actualFileName, res)) + // } + // 获取嵌入向量 + resList, err := server.GetEmbeddings(knowData.Engine, knowData.EmbeddingModel, doc.Split) + if err != nil { + return err + } + + // 检查嵌入向量长度是否匹配 + if len(resList) != len(doc.Split) { + return fmt.Errorf("invalid file len: %s, expected %d embeddings, got %d", filePath, len(doc.Split), len(resList)) + } + + // 创建 VecDoc 列表 + var vectordocs []model.VecDoc + for _, res := range doc.Split { + //log.Printf("Adding document: %s", res) + vectordoc := model.VecDoc{ + Content: res, + FilePath: filePath, + ListID: knowledgeId, + } + vectordocs = append(vectordocs, vectordoc) + } + + // 添加文档 + err = model.Adddocument(knowledgeId, vectordocs, resList) + return err + } else { + // 处理非 .godoos 文件 + if baseName[:1] != "." { + office.ProcessFile(filePath, knowledgeId) + } + return nil + } +} -// return nil -// } -// func InitMonitor() { -// list := GetVectorList() -// for _, v := range list { -// MapFilePathMonitors[v.FilePath] = v.ID +// func extractFileName(fileName string) string { +// // 假设文件名格式为:21.GodoOS企业版介绍 +// parts := strings.SplitN(fileName, ".", 3) +// if len(parts) < 2 { +// return fileName // } -// FolderMonitor() -// } - -// func GetVectorList() []VectorList { -// var vectorList []VectorList -// // stmt, _, err := VecDb.Prepare("SELECT id, file_path, engine, embedding_model FROM vec_list") -// // if err != nil { -// // fmt.Println("Failed to get vector list:", err) -// // return vectorList -// // } -// // stmt.Step() -// // log.Printf("vec_version=%s\n", stmt.ColumnText(0)) -// // stmt.Close() -// // defer rows.Close() - -// // for rows.Next() { -// // var v VectorList -// // err := rows.Scan(&v.ID, &v.FilePath, &v.Engine, &v.EmbeddingModel) -// // if err != nil { -// // fmt.Println("Failed to scan vector list row:", err) -// // continue -// // } -// // vectorList = append(vectorList, v) -// // } - -// return vectorList -// } -// func GetVector(id uint) VectorList { -// var vectorList VectorList -// // sql := "SELECT id, file_path, engine, embedding_model FROM vec_list WHERE id = " + fmt.Sprintf("%d", id) -// // stmt, _, err := VecDb.Prepare(sql) -// // if err != nil { -// // fmt.Println("Failed to get vector list:", err) -// // return vectorList -// // } -// // stmt.Step() -// // log.Printf("vec_version=%s\n", stmt.ColumnText(0)) -// // stmt.Close() -// // err := VecDb.QueryRow("SELECT id, file_path, engine, embedding_model FROM vec_list WHERE id = ?", id).Scan(&vectorList.ID, &vectorList.FilePath, &vectorList.Engine, &vectorList.EmbeddingModel) -// // if err != nil { -// // fmt.Println("Failed to get vector:", err) -// // } -// return vectorList +// return parts[1] // } diff --git a/godo/cmd/main.go b/godo/cmd/main.go index de5b368..a8851ed 100644 --- a/godo/cmd/main.go +++ b/godo/cmd/main.go @@ -52,6 +52,7 @@ func OsStart() { } db.InitDB() proxy.InitProxyHandlers() + vector.InitMonitorVector() webdav.InitWebdav() router := mux.NewRouter() router.Use(recoverMiddleware) @@ -151,6 +152,8 @@ func OsStart() { aiRouter.HandleFunc("/embeddings", model.EmbeddingHandler).Methods(http.MethodPost) aiRouter.HandleFunc("/searchweb", search.SearchWebhandler).Methods(http.MethodGet) aiRouter.HandleFunc("/addknowledge", vector.HandlerCreateKnowledge).Methods(http.MethodPost) + aiRouter.HandleFunc("/askknowledge", vector.HandlerAskKnowledge).Methods(http.MethodPost) + aiRouter.HandleFunc("/deleteknowledge", vector.HandlerDelKnowledge).Methods(http.MethodGet) //注册浏览器路由 ieRouter := router.PathPrefix("/ie").Subrouter() ieRouter.HandleFunc("/navigate", store.HandleNavigate).Methods(http.MethodGet) diff --git a/godo/files/fs.go b/godo/files/fs.go index e3875ce..664197b 100644 --- a/godo/files/fs.go +++ b/godo/files/fs.go @@ -68,32 +68,6 @@ func HandleReadDir(w http.ResponseWriter, r *http.Request) { return } - // 如果是文件,读取内容 - if osFileInfo.IsFile { - file, err := os.Open(filepath.Join(basePath, osFileInfo.Path)) - if err != nil { - libs.HTTPError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to open file: %v", err)) - return - } - defer file.Close() - - content, err := io.ReadAll(file) - if err != nil { - libs.HTTPError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to read file content: %v", err)) - return - } - - osFileInfo.Content = string(content) - // 检查文件内容是否以"link::"开头 - if strings.HasPrefix(osFileInfo.Content, "link::") { - osFileInfo.IsSymlink = true - } else { - osFileInfo.Content = "" - } - - osFileInfo.IsPwd = IsPwdFile(content) - } - osFileInfos = append(osFileInfos, *osFileInfo) } // 按照 ModTime 进行降序排序 @@ -126,7 +100,6 @@ func HandleStat(w http.ResponseWriter, r *http.Request) { libs.HTTPError(w, http.StatusNotFound, err.Error()) return } - // fmt.Printf("basePath: %+s", basePath) if osFileInfo.IsFile { // 是否为加密文件 file, err := os.Open(filepath.Join(basePath, path)) @@ -140,19 +113,12 @@ func HandleStat(w http.ResponseWriter, r *http.Request) { if err == io.EOF { // EOF说明文件大小小于34字节 osFileInfo.IsPwd = false - res := libs.APIResponse{ - Message: "File information retrieved successfully.", - Data: osFileInfo, - } - json.NewEncoder(w).Encode(res) - return } - libs.HTTPError(w, http.StatusInternalServerError, err.Error()) - return + } else { + osFileInfo.IsPwd = IsPwdFile(buffer) } - osFileInfo.IsPwd = IsPwdFile(buffer) - } + } res := libs.APIResponse{ Message: "File information retrieved successfully.", Data: osFileInfo, @@ -217,16 +183,7 @@ func HandleUnlink(w http.ResponseWriter, r *http.Request) { // HandleClear removes the entire filesystem (Caution: Use with care!) func HandleClear(w http.ResponseWriter, r *http.Request) { - // basePath, err := libs.GetOsDir() - // if err != nil { - // libs.HTTPError(w, http.StatusInternalServerError, err.Error()) - // return - // } - // err = Clear(basePath) - // if err != nil { - // libs.HTTPError(w, http.StatusConflict, err.Error()) - // return - // } + err := RecoverOsSystem() if err != nil { libs.HTTPError(w, http.StatusInternalServerError, err.Error()) @@ -475,6 +432,7 @@ func parseMode(modeStr string) (os.FileMode, error) { return os.FileMode(mode), nil } func HandleDesktop(w http.ResponseWriter, r *http.Request) { + //log.Printf("=====Received request: %v", r) rootInfo, err := GetDesktop() if err != nil { libs.HTTPError(w, http.StatusInternalServerError, err.Error()) diff --git a/godo/files/os.go b/godo/files/os.go index 348f2a2..6523039 100644 --- a/godo/files/os.go +++ b/godo/files/os.go @@ -33,23 +33,24 @@ import ( // Common response structure type OsFileInfo struct { - IsFile bool `json:"isFile"` - IsDir bool `json:"isDirectory"` - IsSymlink bool `json:"isSymlink"` - Size int64 `json:"size"` - ModTime time.Time `json:"modTime"` - AccessTime time.Time `json:"atime"` - CreateTime time.Time `json:"birthtime"` - Mode os.FileMode `json:"mode"` - Name string `json:"name"` // 文件名 - Path string `json:"path"` // 文件路径 - OldPath string `json:"oldPath"` // 旧的文件路径 - ParentPath string `json:"parentPath"` // 父目录路径 - Content string `json:"content"` // 文件内容 - Ext string `json:"ext"` // 文件扩展名 - Title string `json:"title"` // 文件名(不包含扩展名) - ID int `json:"id,omitempty"` // 文件ID(可选) - IsPwd bool `json:"isPwd"` // 是否加密 + IsFile bool `json:"isFile"` + IsDir bool `json:"isDirectory"` + IsSymlink bool `json:"isSymlink"` + IsKnowledge bool `json:"isKnowledge"` + Size int64 `json:"size"` + ModTime time.Time `json:"modTime"` + AccessTime time.Time `json:"atime"` + CreateTime time.Time `json:"birthtime"` + Mode os.FileMode `json:"mode"` + Name string `json:"name"` // 文件名 + Path string `json:"path"` // 文件路径 + OldPath string `json:"oldPath"` // 旧的文件路径 + ParentPath string `json:"parentPath"` // 父目录路径 + Content string `json:"content"` // 文件内容 + Ext string `json:"ext"` // 文件扩展名 + Title string `json:"title"` // 文件名(不包含扩展名) + ID int `json:"id,omitempty"` // 文件ID(可选) + IsPwd bool `json:"isPwd"` // 是否加密 } // validateFilePath 验证路径不为空 @@ -272,6 +273,12 @@ func GetFileInfo(entry interface{}, basePath, parentPath string) (*OsFileInfo, e osFileInfo.ParentPath = parentPath osFileInfo.Title = strings.TrimSuffix(osFileInfo.Name, filepath.Ext(osFileInfo.Name)) osFileInfo.Ext = strings.TrimPrefix(filepath.Ext(osFileInfo.Name), ".") + if osFileInfo.IsDir { + knowledgeFilePath := filepath.Join(filePath, ".knowledge") + if libs.PathExists(knowledgeFilePath) { + osFileInfo.IsKnowledge = true + } + } return osFileInfo, nil } @@ -342,8 +349,5 @@ func IsPwdFile(fileData []byte) bool { // 使用正则表达式验证中间的字符串是否为 MD5 加密的字符串 md5Regex := regexp.MustCompile(`^[a-fA-F0-9]{32}$`) - if !md5Regex.MatchString(middleStr) { - return false - } - return true + return md5Regex.MatchString(middleStr) } diff --git a/godo/model/init.go b/godo/model/init.go index ae4706a..411132e 100644 --- a/godo/model/init.go +++ b/godo/model/init.go @@ -17,6 +17,14 @@ func InitDB() { if err != nil { return } + // Enable PRAGMAs + // - busy_timeout (ms) to prevent db lockups as we're accessing the DB from multiple separate processes in otto8 + tx := db.Exec(` +PRAGMA busy_timeout = 10000; +`) + if tx.Error != nil { + return + } Db = db // 自动迁移模式 db.AutoMigrate(&SysDisk{}) diff --git a/godo/model/vec_doc.go b/godo/model/vec_doc.go index ad82c98..ea82c64 100644 --- a/godo/model/vec_doc.go +++ b/godo/model/vec_doc.go @@ -1,14 +1,148 @@ package model -import "gorm.io/gorm" +import ( + "fmt" + "sort" + + sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/ncruces" + "gorm.io/gorm" +) type VecDoc struct { gorm.Model Content string `json:"content"` FilePath string `json:"file_path" gorm:"not null"` - ListID int `json:"list_id"` + ListID uint `json:"list_id"` } func (VecDoc) TableName() string { return "vec_doc" } + +func Adddocument(listId uint, docs []VecDoc, embeds [][]float32) error { + // 批量删除具有相同 file_path 的 VecDoc 数据 + filePath := docs[0].FilePath + // 删除向量表中的数据 + var existingDocs []VecDoc + if err := Db.Where("file_path = ?", filePath).Find(&existingDocs).Error; err != nil { + return fmt.Errorf("failed to find existing documents: %v", err) + } + + for _, existingDoc := range existingDocs { + documentID := fmt.Sprintf("%d", existingDoc.ID) + result := Db.Exec(fmt.Sprintf("DELETE FROM [%d_vec] WHERE document_id = ?", listId), documentID) + if result.Error != nil { + return fmt.Errorf("failed to delete vector data: %v", result.Error) + } + } + + // 删除 vec_doc 中的数据(硬删除) + if err := Db.Unscoped().Where("file_path = ?", filePath).Delete(&VecDoc{}).Error; err != nil { + return fmt.Errorf("failed to delete vec_doc data: %v", err) + } + // 批量插入新的 vec_doc 数据 + if err := Db.CreateInBatches(docs, 100).Error; err != nil { + return fmt.Errorf("failed to create vec_doc data: %v", err) + } + + // 批量插入向量数据到虚拟表 + for i, doc := range docs { + v, err := sqlite_vec.SerializeFloat32(embeds[i]) + if err != nil { + return fmt.Errorf("failed to serialize vector: %v", err) + } + + documentID := fmt.Sprintf("%d", doc.ID) + result := Db.Exec(fmt.Sprintf("INSERT INTO [%d_vec] (document_id, embedding) VALUES (?, ?)", listId), documentID, v) + if result.Error != nil { + return fmt.Errorf("failed to insert vector data: %v", result.Error) + } + } + + return nil +} + +func Deletedocument(listId uint, filePath string) error { + var existingDocs []VecDoc + if err := Db.Where("file_path = ?", filePath).Find(&existingDocs).Error; err != nil { + return fmt.Errorf("failed to find existing documents: %v", err) + } + + for _, existingDoc := range existingDocs { + documentID := fmt.Sprintf("%d", existingDoc.ID) + result := Db.Exec(fmt.Sprintf("DELETE FROM [%d_vec] WHERE document_id = ?", listId), documentID) + if result.Error != nil { + return fmt.Errorf("failed to delete vector data: %v", result.Error) + } + } + + // 删除 vec_doc 中的数据(硬删除) + if err := Db.Unscoped().Where("file_path = ?", filePath).Delete(&VecDoc{}).Error; err != nil { + return fmt.Errorf("failed to delete vec_doc data: %v", err) + } + + return nil +} + +type AskDocResponse struct { + Content string `json:"content"` + Score float32 `json:"score"` + FilePath string `json:"file_path"` +} + +func AskDocument(listId uint, query []float32) ([]AskDocResponse, error) { + // 序列化查询向量 + queryVec, err := sqlite_vec.SerializeFloat32(query) + if err != nil { + return []AskDocResponse{}, fmt.Errorf("failed to serialize query vector: %v", err) + } + + // 查询最相似的文档 + var results []struct { + DocumentID uint `gorm:"column:document_id"` + Distance float32 `gorm:"column:distance"` + } + result := Db.Raw(fmt.Sprintf(` + SELECT + document_id, + distance + FROM [%d_vec] + WHERE embedding MATCH ? + ORDER BY distance + LIMIT 10 + `, listId), queryVec).Scan(&results) + + if result.Error != nil { + return nil, fmt.Errorf("failed to query vector data: %v", result.Error) + } + + if len(results) == 0 { + return nil, fmt.Errorf("no matching documents found") + } + + // 获取最相似的文档 + var docs []VecDoc + var docIDs []uint + for _, res := range results { + docIDs = append(docIDs, res.DocumentID) + } + + if err := Db.Where("id IN ?", docIDs).Find(&docs).Error; err != nil { + return nil, fmt.Errorf("failed to find documents: %v", err) + } + + // 构建响应 + var responses []AskDocResponse + for i, doc := range docs { + responses = append(responses, AskDocResponse{ + Content: doc.Content, + Score: results[i].Distance, + FilePath: doc.FilePath, + }) + } + // 按 Score 降序排序 + sort.Slice(responses, func(i, j int) bool { + return responses[i].Score > responses[j].Score + }) + return responses, nil +} diff --git a/godo/model/vec_list.go b/godo/model/vec_list.go index b28aba6..be58a92 100644 --- a/godo/model/vec_list.go +++ b/godo/model/vec_list.go @@ -2,7 +2,6 @@ package model import ( "fmt" - "log" "gorm.io/gorm" ) @@ -54,7 +53,7 @@ func CreateVirtualTable(db *gorm.DB, vectorID uint, embeddingSize int) error { embedding float[%d] distance_metric=cosine ) `, vectorID, embeddingSize) - log.Printf("sql: %s", sql) + //log.Printf("sql: %s", sql) return db.Exec(sql).Error }