mirror of https://gitee.com/godoos/godoos.git
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
154 lines
4.3 KiB
154 lines
4.3 KiB
package model
|
|
|
|
import (
|
|
"fmt"
|
|
"sort"
|
|
|
|
sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/ncruces"
|
|
)
|
|
|
|
type VecDoc struct {
|
|
BaseModel
|
|
Content string `json:"content"`
|
|
FilePath string `json:"file_path" gorm:"not null"`
|
|
FileName string `json:"file_name"`
|
|
ListID uint `json:"list_id"`
|
|
}
|
|
|
|
func (VecDoc) TableName() string {
|
|
return "vec_doc"
|
|
}
|
|
|
|
func Adddocument(listId uint, docs []VecDoc, embeds [][]float32) error {
|
|
// 批量删除具有相同 file_path 的 VecDoc 数据
|
|
filePath := docs[0].FilePath
|
|
// 删除向量表中的数据
|
|
var existingDocs []VecDoc
|
|
if err := Db.Where("file_path = ?", filePath).Find(&existingDocs).Error; err != nil {
|
|
return fmt.Errorf("failed to find existing documents: %v", err)
|
|
}
|
|
|
|
for _, existingDoc := range existingDocs {
|
|
documentID := fmt.Sprintf("%d", existingDoc.ID)
|
|
result := Db.Exec(fmt.Sprintf("DELETE FROM [%d_vec] WHERE document_id = ?", listId), documentID)
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to delete vector data: %v", result.Error)
|
|
}
|
|
}
|
|
|
|
// 删除 vec_doc 中的数据(硬删除)
|
|
if err := Db.Unscoped().Where("file_path = ?", filePath).Delete(&VecDoc{}).Error; err != nil {
|
|
return fmt.Errorf("failed to delete vec_doc data: %v", err)
|
|
}
|
|
// 批量插入新的 vec_doc 数据
|
|
if err := Db.CreateInBatches(docs, 100).Error; err != nil {
|
|
return fmt.Errorf("failed to create vec_doc data: %v", err)
|
|
}
|
|
|
|
// 批量插入向量数据到虚拟表
|
|
for i, doc := range docs {
|
|
v, err := sqlite_vec.SerializeFloat32(embeds[i])
|
|
if err != nil {
|
|
return fmt.Errorf("failed to serialize vector: %v", err)
|
|
}
|
|
|
|
documentID := fmt.Sprintf("%d", doc.ID)
|
|
result := Db.Exec(fmt.Sprintf("INSERT INTO [%d_vec] (document_id, embedding) VALUES (?, ?)", listId), documentID, v)
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to insert vector data: %v", result.Error)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func Deletedocument(listId uint, filePath string) error {
|
|
var existingDocs []VecDoc
|
|
if err := Db.Where("file_path = ?", filePath).Find(&existingDocs).Error; err != nil {
|
|
return fmt.Errorf("failed to find existing documents: %v", err)
|
|
}
|
|
|
|
for _, existingDoc := range existingDocs {
|
|
documentID := fmt.Sprintf("%d", existingDoc.ID)
|
|
result := Db.Exec(fmt.Sprintf("DELETE FROM [%d_vec] WHERE document_id = ?", listId), documentID)
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to delete vector data: %v", result.Error)
|
|
}
|
|
}
|
|
|
|
// 删除 vec_doc 中的数据(硬删除)
|
|
if err := Db.Unscoped().Where("file_path = ?", filePath).Delete(&VecDoc{}).Error; err != nil {
|
|
return fmt.Errorf("failed to delete vec_doc data: %v", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type (
	// AskDocResponse is one chunk returned by a similarity search. Score
	// carries the distance reported by the vector table, not a similarity.
	AskDocResponse struct {
		Content  string  `json:"content"`
		Score    float32 `json:"score"`
		FilePath string  `json:"file_path"`
		FileName string  `json:"file_name"`
	}

	// AskRequest is the JSON payload for a document question. ID presumably
	// selects the list to search (TODO confirm against the handler), and
	// Input is the question text.
	AskRequest struct {
		ID    uint   `json:"id"`
		Input string `json:"input"`
	}
)
|
|
|
|
func AskDocument(listId uint, query []float32) ([]AskDocResponse, error) {
|
|
// 序列化查询向量
|
|
queryVec, err := sqlite_vec.SerializeFloat32(query)
|
|
if err != nil {
|
|
return []AskDocResponse{}, fmt.Errorf("failed to serialize query vector: %v", err)
|
|
}
|
|
|
|
// 查询最相似的文档
|
|
var results []struct {
|
|
DocumentID uint `gorm:"column:document_id"`
|
|
Distance float32 `gorm:"column:distance"`
|
|
}
|
|
result := Db.Raw(fmt.Sprintf(`
|
|
SELECT
|
|
document_id,
|
|
distance
|
|
FROM [%d_vec]
|
|
WHERE embedding MATCH ?
|
|
ORDER BY distance
|
|
LIMIT 10
|
|
`, listId), queryVec).Scan(&results)
|
|
|
|
if result.Error != nil {
|
|
return nil, fmt.Errorf("failed to query vector data: %v", result.Error)
|
|
}
|
|
|
|
if len(results) == 0 {
|
|
return nil, fmt.Errorf("no matching documents found")
|
|
}
|
|
|
|
// 获取最相似的文档
|
|
var docs []VecDoc
|
|
var docIDs []uint
|
|
for _, res := range results {
|
|
docIDs = append(docIDs, res.DocumentID)
|
|
}
|
|
|
|
if err := Db.Where("id IN ?", docIDs).Find(&docs).Error; err != nil {
|
|
return nil, fmt.Errorf("failed to find documents: %v", err)
|
|
}
|
|
|
|
// 构建响应
|
|
var responses []AskDocResponse
|
|
for i, doc := range docs {
|
|
responses = append(responses, AskDocResponse{
|
|
Content: doc.Content,
|
|
Score: results[i].Distance,
|
|
FilePath: doc.FilePath,
|
|
FileName: doc.FileName,
|
|
})
|
|
}
|
|
// 按 Score 降序排序
|
|
sort.Slice(responses, func(i, j int) bool {
|
|
return responses[i].Score > responses[j].Score
|
|
})
|
|
return responses, nil
|
|
}
|
|
|