mirror of https://gitee.com/godoos/godoos.git
21 changed files with 2455 additions and 453 deletions
@ -0,0 +1,409 @@ |
|||
package vector |
|||
|
|||
import ( |
|||
"context" |
|||
"errors" |
|||
"fmt" |
|||
"path/filepath" |
|||
"slices" |
|||
"sync" |
|||
) |
|||
|
|||
// Collection 表示一个文档集合。
|
|||
// 它还包含一个配置好的嵌入函数,当添加没有嵌入的文档时会使用该函数。
|
|||
type Collection struct { |
|||
Name string |
|||
|
|||
metadata map[string]string |
|||
documents map[string]*Document |
|||
documentsLock sync.RWMutex |
|||
embed EmbeddingFunc |
|||
|
|||
persistDirectory string |
|||
compress bool |
|||
|
|||
// ⚠️ 当添加字段时,请考虑在 [DB.Export] 和 [DB.Import] 的持久化结构中添加相应的字段
|
|||
} |
|||
|
|||
// 我们不导出这个函数,以保持 API 表面最小。
|
|||
// 用户通过 [Client.CreateCollection] 创建集合。
|
|||
func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, dbDir string, compress bool) (*Collection, error) { |
|||
// 复制元数据以避免在创建集合后调用者修改元数据时发生数据竞争。
|
|||
m := make(map[string]string, len(metadata)) |
|||
for k, v := range metadata { |
|||
m[k] = v |
|||
} |
|||
|
|||
c := &Collection{ |
|||
Name: name, |
|||
|
|||
metadata: m, |
|||
documents: make(map[string]*Document), |
|||
embed: embed, |
|||
} |
|||
|
|||
// 持久化
|
|||
if dbDir != "" { |
|||
safeName := hash2hex(name) |
|||
c.persistDirectory = filepath.Join(dbDir, safeName) |
|||
c.compress = compress |
|||
// 持久化名称和元数据
|
|||
metadataPath := filepath.Join(c.persistDirectory, metadataFileName) |
|||
metadataPath += ".gob" |
|||
if c.compress { |
|||
metadataPath += ".gz" |
|||
} |
|||
pc := struct { |
|||
Name string |
|||
Metadata map[string]string |
|||
}{ |
|||
Name: name, |
|||
Metadata: m, |
|||
} |
|||
err := persistToFile(metadataPath, pc, compress, "") |
|||
if err != nil { |
|||
return nil, fmt.Errorf("无法持久化集合元数据: %w", err) |
|||
} |
|||
} |
|||
|
|||
return c, nil |
|||
} |
|||
|
|||
// 添加嵌入到数据存储中。
|
|||
//
|
|||
// - ids: 要添加的嵌入的 ID
|
|||
// - embeddings: 要添加的嵌入。如果为 nil,则基于内容使用集合的嵌入函数计算嵌入。可选。
|
|||
// - metadatas: 与嵌入关联的元数据。查询时可以过滤这些元数据。可选。
|
|||
// - contents: 与嵌入关联的内容。
|
|||
//
|
|||
// 这是一个类似于 Chroma 的方法。对于更符合 Go 风格的方法,请参见 [AddDocuments]。
|
|||
func (c *Collection) Add(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string) error { |
|||
return c.AddConcurrently(ctx, ids, embeddings, metadatas, contents, 1) |
|||
} |
|||
|
|||
// AddConcurrently 类似于 Add,但并发地添加嵌入。
|
|||
// 这在没有传递任何嵌入时特别有用,因为需要创建嵌入。
|
|||
// 出现错误时,取消所有并发操作并返回错误。
|
|||
//
|
|||
// 这是一个类似于 Chroma 的方法。对于更符合 Go 风格的方法,请参见 [AddDocuments]。
|
|||
func (c *Collection) AddConcurrently(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string, concurrency int) error { |
|||
if len(ids) == 0 { |
|||
return errors.New("ids 为空") |
|||
} |
|||
if len(embeddings) == 0 && len(contents) == 0 { |
|||
return errors.New("必须填写 embeddings 或 contents") |
|||
} |
|||
if len(embeddings) != 0 { |
|||
if len(embeddings) != len(ids) { |
|||
return errors.New("ids 和 embeddings 的长度必须相同") |
|||
} |
|||
} else { |
|||
// 分配空切片以便稍后通过索引访问
|
|||
embeddings = make([][]float32, len(ids)) |
|||
} |
|||
if len(metadatas) != 0 { |
|||
if len(ids) != len(metadatas) { |
|||
return errors.New("当 metadatas 不为空时,其长度必须与 ids 相同") |
|||
} |
|||
} else { |
|||
// 分配空切片以便稍后通过索引访问
|
|||
metadatas = make([]map[string]string, len(ids)) |
|||
} |
|||
if len(contents) != 0 { |
|||
if len(contents) != len(ids) { |
|||
return errors.New("ids 和 contents 的长度必须相同") |
|||
} |
|||
} else { |
|||
// 分配空切片以便稍后通过索引访问
|
|||
contents = make([]string, len(ids)) |
|||
} |
|||
if concurrency < 1 { |
|||
return errors.New("并发数必须至少为 1") |
|||
} |
|||
|
|||
// 将 Chroma 风格的参数转换为文档切片
|
|||
docs := make([]Document, 0, len(ids)) |
|||
for i, id := range ids { |
|||
docs = append(docs, Document{ |
|||
ID: id, |
|||
Metadata: metadatas[i], |
|||
Embedding: embeddings[i], |
|||
Content: contents[i], |
|||
}) |
|||
} |
|||
|
|||
return c.AddDocuments(ctx, docs, concurrency) |
|||
} |
|||
|
|||
// AddDocuments 使用指定的并发数将文档添加到集合中。
|
|||
// 如果文档没有嵌入,则使用集合的嵌入函数创建嵌入。
|
|||
// 出现错误时,取消所有并发操作并返回错误。
|
|||
func (c *Collection) AddDocuments(ctx context.Context, documents []Document, concurrency int) error { |
|||
if len(documents) == 0 { |
|||
// TODO: 这是否应为无操作(no-op)?
|
|||
return errors.New("documents 切片为空") |
|||
} |
|||
if concurrency < 1 { |
|||
return errors.New("并发数必须至少为 1") |
|||
} |
|||
// 对于其他验证,我们依赖于 AddDocument。
|
|||
|
|||
var sharedErr error |
|||
sharedErrLock := sync.Mutex{} |
|||
ctx, cancel := context.WithCancelCause(ctx) |
|||
defer cancel(nil) |
|||
setSharedErr := func(err error) { |
|||
sharedErrLock.Lock() |
|||
defer sharedErrLock.Unlock() |
|||
// 另一个 goroutine 可能已经设置了错误。
|
|||
if sharedErr == nil { |
|||
sharedErr = err |
|||
// 取消所有其他 goroutine 的操作。
|
|||
cancel(sharedErr) |
|||
} |
|||
} |
|||
|
|||
var wg sync.WaitGroup |
|||
semaphore := make(chan struct{}, concurrency) |
|||
for _, doc := range documents { |
|||
wg.Add(1) |
|||
go func(doc Document) { |
|||
defer wg.Done() |
|||
|
|||
// 如果另一个 goroutine 已经失败,则不开始。
|
|||
if ctx.Err() != nil { |
|||
return |
|||
} |
|||
|
|||
// 等待直到 $concurrency 个其他 goroutine 正在创建文档。
|
|||
semaphore <- struct{}{} |
|||
defer func() { <-semaphore }() |
|||
|
|||
err := c.AddDocument(ctx, doc) |
|||
if err != nil { |
|||
setSharedErr(fmt.Errorf("无法添加文档 '%s': %w", doc.ID, err)) |
|||
return |
|||
} |
|||
}(doc) |
|||
} |
|||
|
|||
wg.Wait() |
|||
|
|||
return sharedErr |
|||
} |
|||
|
|||
// AddDocument 将文档添加到集合中。
|
|||
// 如果文档没有嵌入,则使用集合的嵌入函数创建嵌入。
|
|||
func (c *Collection) AddDocument(ctx context.Context, doc Document) error { |
|||
if doc.ID == "" { |
|||
return errors.New("文档 ID 为空") |
|||
} |
|||
if len(doc.Embedding) == 0 && doc.Content == "" { |
|||
return errors.New("必须填写文档的 embedding 或 content") |
|||
} |
|||
|
|||
// 复制元数据以避免在创建文档后调用者修改元数据时发生数据竞争。
|
|||
m := make(map[string]string, len(doc.Metadata)) |
|||
for k, v := range doc.Metadata { |
|||
m[k] = v |
|||
} |
|||
|
|||
// 如果嵌入不存在,则创建嵌入,否则如果需要则规范化
|
|||
if len(doc.Embedding) == 0 { |
|||
embedding, err := c.embed(ctx, doc.Content) |
|||
if err != nil { |
|||
return fmt.Errorf("无法创建文档的嵌入: %w", err) |
|||
} |
|||
doc.Embedding = embedding |
|||
} else { |
|||
if !isNormalized(doc.Embedding) { |
|||
doc.Embedding = normalizeVector(doc.Embedding) |
|||
} |
|||
} |
|||
|
|||
c.documentsLock.Lock() |
|||
// 我们不使用 defer 解锁,因为我们希望尽早解锁。
|
|||
c.documents[doc.ID] = &doc |
|||
c.documentsLock.Unlock() |
|||
|
|||
// 持久化文档
|
|||
if c.persistDirectory != "" { |
|||
docPath := c.getDocPath(doc.ID) |
|||
err := persistToFile(docPath, doc, c.compress, "") |
|||
if err != nil { |
|||
return fmt.Errorf("无法将文档持久化到 %q: %w", docPath, err) |
|||
} |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// Delete 从集合中删除文档。
|
|||
//
|
|||
// - where: 元数据的条件过滤。可选。
|
|||
// - whereDocument: 文档的条件过滤。可选。
|
|||
// - ids: 要删除的文档的 ID。如果为空,则删除所有文档。
|
|||
func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]string, ids ...string) error { |
|||
// 必须至少有一个 where、whereDocument 或 ids
|
|||
if len(where) == 0 && len(whereDocument) == 0 && len(ids) == 0 { |
|||
return fmt.Errorf("必须至少有一个 where、whereDocument 或 ids") |
|||
} |
|||
|
|||
if len(c.documents) == 0 { |
|||
return nil |
|||
} |
|||
|
|||
for k := range whereDocument { |
|||
if !slices.Contains(supportedFilters, k) { |
|||
return errors.New("不支持的 whereDocument 操作符") |
|||
} |
|||
} |
|||
|
|||
var docIDs []string |
|||
|
|||
c.documentsLock.Lock() |
|||
defer c.documentsLock.Unlock() |
|||
|
|||
if where != nil || whereDocument != nil { |
|||
// 元数据 + 内容过滤
|
|||
filteredDocs := filterDocs(c.documents, where, whereDocument) |
|||
for _, doc := range filteredDocs { |
|||
docIDs = append(docIDs, doc.ID) |
|||
} |
|||
} else { |
|||
docIDs = ids |
|||
} |
|||
|
|||
// 如果没有剩余的文档,则不执行操作
|
|||
if len(docIDs) == 0 { |
|||
return nil |
|||
} |
|||
|
|||
for _, docID := range docIDs { |
|||
delete(c.documents, docID) |
|||
|
|||
// 从磁盘删除文档
|
|||
if c.persistDirectory != "" { |
|||
docPath := c.getDocPath(docID) |
|||
err := removeFile(docPath) |
|||
if err != nil { |
|||
return fmt.Errorf("无法删除文档 %q: %w", docPath, err) |
|||
} |
|||
} |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// Count 返回集合中的文档数量。
|
|||
func (c *Collection) Count() int { |
|||
c.documentsLock.RLock() |
|||
defer c.documentsLock.RUnlock() |
|||
return len(c.documents) |
|||
} |
|||
|
|||
// Result represents a single result of a query.
type Result struct {
	ID        string
	Metadata  map[string]string
	Embedding []float32
	Content   string

	// Similarity is the cosine similarity between the query and the document.
	// The higher the value, the more similar the document is to the query.
	// The value ranges from -1 to 1.
	Similarity float32
}
|||
|
|||
// 在集合上执行详尽的最近邻搜索。
|
|||
//
|
|||
// - queryText: 要搜索的文本。其嵌入将使用集合的嵌入函数创建。
|
|||
// - nResults: 要返回的结果数量。必须大于 0。
|
|||
// - where: 元数据的条件过滤。可选。
|
|||
// - whereDocument: 文档的条件过滤。可选。
|
|||
func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where, whereDocument map[string]string) ([]Result, error) { |
|||
if queryText == "" { |
|||
return nil, errors.New("queryText 为空") |
|||
} |
|||
|
|||
queryVectors, err := c.embed(ctx, queryText) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("无法创建查询的嵌入: %w", err) |
|||
} |
|||
|
|||
return c.QueryEmbedding(ctx, queryVectors, nResults, where, whereDocument) |
|||
} |
|||
|
|||
// 在集合上执行详尽的最近邻搜索。
|
|||
//
|
|||
// - queryEmbedding: 要搜索的查询的嵌入。必须使用与集合中文档嵌入相同的嵌入模型创建。
|
|||
// - nResults: 要返回的结果数量。必须大于 0。
|
|||
// - where: 元数据的条件过滤。可选。
|
|||
// - whereDocument: 文档的条件过滤。可选。
|
|||
func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float32, nResults int, where, whereDocument map[string]string) ([]Result, error) { |
|||
if len(queryEmbedding) == 0 { |
|||
return nil, errors.New("queryEmbedding 为空") |
|||
} |
|||
if nResults <= 0 { |
|||
return nil, errors.New("nResults 必须大于 0") |
|||
} |
|||
c.documentsLock.RLock() |
|||
defer c.documentsLock.RUnlock() |
|||
// if nResults > len(c.documents) {
|
|||
// return nil, errors.New("nResults 必须小于或等于集合中的文档数量")
|
|||
// }
|
|||
|
|||
if len(c.documents) == 0 { |
|||
return nil, nil |
|||
} |
|||
|
|||
// 验证 whereDocument 操作符
|
|||
for k := range whereDocument { |
|||
if !slices.Contains(supportedFilters, k) { |
|||
return nil, errors.New("不支持的操作符") |
|||
} |
|||
} |
|||
|
|||
// 根据元数据和内容过滤文档
|
|||
filteredDocs := filterDocs(c.documents, where, whereDocument) |
|||
|
|||
// 如果过滤器删除了所有文档,则不继续
|
|||
if len(filteredDocs) == 0 { |
|||
return nil, nil |
|||
} |
|||
|
|||
// 对于剩余的文档,获取最相似的文档。
|
|||
nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, filteredDocs, nResults) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("无法获取最相似的文档: %w", err) |
|||
} |
|||
length := len(nMaxDocs) |
|||
if length > nResults { |
|||
length = nResults |
|||
} |
|||
res := make([]Result, 0, length) |
|||
for i := 0; i < length; i++ { |
|||
doc := c.documents[nMaxDocs[i].docID] |
|||
res = append(res, Result{ |
|||
ID: nMaxDocs[i].docID, |
|||
Metadata: doc.Metadata, |
|||
Embedding: doc.Embedding, |
|||
Content: doc.Content, |
|||
Similarity: nMaxDocs[i].similarity, |
|||
}) |
|||
} |
|||
|
|||
// 返回前 nResults 个结果
|
|||
return res, nil |
|||
} |
|||
|
|||
// getDocPath 生成文档文件的路径。
|
|||
func (c *Collection) getDocPath(docID string) string { |
|||
safeID := hash2hex(docID) |
|||
docPath := filepath.Join(c.persistDirectory, safeID) |
|||
docPath += ".gob" |
|||
if c.compress { |
|||
docPath += ".gz" |
|||
} |
|||
return docPath |
|||
} |
@ -0,0 +1,412 @@ |
|||
package vector |
|||
|
|||
import ( |
|||
"context" |
|||
"errors" |
|||
"fmt" |
|||
"godo/libs" |
|||
"io" |
|||
"io/fs" |
|||
"os" |
|||
"path/filepath" |
|||
"strings" |
|||
"sync" |
|||
) |
|||
|
|||
// EmbeddingFunc 是一个为给定文本创建嵌入的函数。
|
|||
// 默认使用 OpenAI 的 "text-embedding-3-small" 模型。
|
|||
// 该函数必须返回一个已归一化的向量。
|
|||
type EmbeddingFunc func(ctx context.Context, text string) ([]float32, error) |
|||
|
|||
// DB 包含多个集合,每个集合包含多个文档。
|
|||
type DB struct { |
|||
collections map[string]*Collection |
|||
collectionsLock sync.RWMutex |
|||
|
|||
persistDirectory string |
|||
compress bool |
|||
} |
|||
|
|||
// NewDB 创建一个新的内存中的数据库。
|
|||
func NewDB() *DB { |
|||
return &DB{ |
|||
collections: make(map[string]*Collection), |
|||
} |
|||
} |
|||
|
|||
// NewPersistentDB 创建一个新的持久化的数据库。
|
|||
// 如果路径为空,默认为 "./godoos/data/godoDB"。
|
|||
// 如果 compress 为 true,则文件将使用 gzip 压缩。
|
|||
func NewPersistentDB(path string, compress bool) (*DB, error) { |
|||
homeDir, err := libs.GetAppDir() |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if path == "" { |
|||
path = filepath.Join(homeDir, "data", "godoDB") |
|||
} else { |
|||
path = filepath.Clean(path) |
|||
} |
|||
|
|||
ext := ".gob" |
|||
if compress { |
|||
ext += ".gz" |
|||
} |
|||
|
|||
db := &DB{ |
|||
collections: make(map[string]*Collection), |
|||
persistDirectory: path, |
|||
compress: compress, |
|||
} |
|||
|
|||
fi, err := os.Stat(path) |
|||
if err != nil { |
|||
if errors.Is(err, fs.ErrNotExist) { |
|||
err := os.MkdirAll(path, 0o700) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("无法创建持久化目录: %w", err) |
|||
} |
|||
return db, nil |
|||
} |
|||
return nil, fmt.Errorf("无法获取持久化目录信息: %w", err) |
|||
} else if !fi.IsDir() { |
|||
return nil, fmt.Errorf("路径不是目录: %s", path) |
|||
} |
|||
|
|||
dirEntries, err := os.ReadDir(path) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("无法读取持久化目录: %w", err) |
|||
} |
|||
for _, dirEntry := range dirEntries { |
|||
if !dirEntry.IsDir() { |
|||
continue |
|||
} |
|||
collectionPath := filepath.Join(path, dirEntry.Name()) |
|||
collectionDirEntries, err := os.ReadDir(collectionPath) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("无法读取集合目录: %w", err) |
|||
} |
|||
c := &Collection{ |
|||
documents: make(map[string]*Document), |
|||
persistDirectory: collectionPath, |
|||
compress: compress, |
|||
} |
|||
for _, collectionDirEntry := range collectionDirEntries { |
|||
if collectionDirEntry.IsDir() { |
|||
continue |
|||
} |
|||
fPath := filepath.Join(collectionPath, collectionDirEntry.Name()) |
|||
if collectionDirEntry.Name() == metadataFileName+ext { |
|||
pc := struct { |
|||
Name string |
|||
Metadata map[string]string |
|||
}{} |
|||
err := readFromFile(fPath, &pc, "") |
|||
if err != nil { |
|||
return nil, fmt.Errorf("无法读取集合元数据: %w", err) |
|||
} |
|||
c.Name = pc.Name |
|||
c.metadata = pc.Metadata |
|||
} else if strings.HasSuffix(collectionDirEntry.Name(), ext) { |
|||
d := &Document{} |
|||
err := readFromFile(fPath, d, "") |
|||
if err != nil { |
|||
return nil, fmt.Errorf("无法读取文档: %w", err) |
|||
} |
|||
c.documents[d.ID] = d |
|||
} |
|||
} |
|||
if c.Name == "" && len(c.documents) == 0 { |
|||
continue |
|||
} |
|||
if c.Name == "" { |
|||
return nil, fmt.Errorf("未找到集合元数据文件: %s", collectionPath) |
|||
} |
|||
db.collections[c.Name] = c |
|||
} |
|||
|
|||
return db, nil |
|||
} |
|||
|
|||
// ImportFromFile 从给定路径的文件导入数据库。
|
|||
func (db *DB) ImportFromFile(filePath string, encryptionKey string) error { |
|||
if filePath == "" { |
|||
return fmt.Errorf("文件路径为空") |
|||
} |
|||
if encryptionKey != "" && len(encryptionKey) != 32 { |
|||
return errors.New("加密密钥必须为 32 字节长") |
|||
} |
|||
|
|||
fi, err := os.Stat(filePath) |
|||
if err != nil { |
|||
if errors.Is(err, fs.ErrNotExist) { |
|||
return fmt.Errorf("文件不存在: %s", filePath) |
|||
} |
|||
return fmt.Errorf("无法获取文件信息: %w", err) |
|||
} else if fi.IsDir() { |
|||
return fmt.Errorf("路径是目录: %s", filePath) |
|||
} |
|||
|
|||
type persistenceCollection struct { |
|||
Name string |
|||
Metadata map[string]string |
|||
Documents map[string]*Document |
|||
} |
|||
persistenceDB := struct { |
|||
Collections map[string]*persistenceCollection |
|||
}{ |
|||
Collections: make(map[string]*persistenceCollection, len(db.collections)), |
|||
} |
|||
|
|||
db.collectionsLock.Lock() |
|||
defer db.collectionsLock.Unlock() |
|||
|
|||
err = readFromFile(filePath, &persistenceDB, encryptionKey) |
|||
if err != nil { |
|||
return fmt.Errorf("无法读取文件: %w", err) |
|||
} |
|||
|
|||
for _, pc := range persistenceDB.Collections { |
|||
c := &Collection{ |
|||
Name: pc.Name, |
|||
metadata: pc.Metadata, |
|||
documents: pc.Documents, |
|||
} |
|||
if db.persistDirectory != "" { |
|||
c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name)) |
|||
c.compress = db.compress |
|||
} |
|||
db.collections[c.Name] = c |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// ImportFromReader 从 reader 导入数据库。
|
|||
func (db *DB) ImportFromReader(reader io.ReadSeeker, encryptionKey string) error { |
|||
if encryptionKey != "" && len(encryptionKey) != 32 { |
|||
return errors.New("加密密钥必须为 32 字节长") |
|||
} |
|||
|
|||
type persistenceCollection struct { |
|||
Name string |
|||
Metadata map[string]string |
|||
Documents map[string]*Document |
|||
} |
|||
persistenceDB := struct { |
|||
Collections map[string]*persistenceCollection |
|||
}{ |
|||
Collections: make(map[string]*persistenceCollection, len(db.collections)), |
|||
} |
|||
|
|||
db.collectionsLock.Lock() |
|||
defer db.collectionsLock.Unlock() |
|||
|
|||
err := readFromReader(reader, &persistenceDB, encryptionKey) |
|||
if err != nil { |
|||
return fmt.Errorf("无法读取流: %w", err) |
|||
} |
|||
|
|||
for _, pc := range persistenceDB.Collections { |
|||
c := &Collection{ |
|||
Name: pc.Name, |
|||
metadata: pc.Metadata, |
|||
documents: pc.Documents, |
|||
} |
|||
if db.persistDirectory != "" { |
|||
c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name)) |
|||
c.compress = db.compress |
|||
} |
|||
db.collections[c.Name] = c |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// ExportToFile 将数据库导出到给定路径的文件。
|
|||
func (db *DB) ExportToFile(filePath string, compress bool, encryptionKey string) error { |
|||
if filePath == "" { |
|||
filePath = "./gododb.gob" |
|||
if compress { |
|||
filePath += ".gz" |
|||
} |
|||
if encryptionKey != "" { |
|||
filePath += ".enc" |
|||
} |
|||
} |
|||
if encryptionKey != "" && len(encryptionKey) != 32 { |
|||
return errors.New("加密密钥必须为 32 字节长") |
|||
} |
|||
|
|||
type persistenceCollection struct { |
|||
Name string |
|||
Metadata map[string]string |
|||
Documents map[string]*Document |
|||
} |
|||
persistenceDB := struct { |
|||
Collections map[string]*persistenceCollection |
|||
}{ |
|||
Collections: make(map[string]*persistenceCollection, len(db.collections)), |
|||
} |
|||
|
|||
db.collectionsLock.RLock() |
|||
defer db.collectionsLock.RUnlock() |
|||
|
|||
for k, v := range db.collections { |
|||
persistenceDB.Collections[k] = &persistenceCollection{ |
|||
Name: v.Name, |
|||
Metadata: v.metadata, |
|||
Documents: v.documents, |
|||
} |
|||
} |
|||
|
|||
err := persistToFile(filePath, persistenceDB, compress, encryptionKey) |
|||
if err != nil { |
|||
return fmt.Errorf("无法导出数据库: %w", err) |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// ExportToWriter 将数据库导出到 writer。
|
|||
func (db *DB) ExportToWriter(writer io.Writer, compress bool, encryptionKey string) error { |
|||
if encryptionKey != "" && len(encryptionKey) != 32 { |
|||
return errors.New("加密密钥必须为 32 字节长") |
|||
} |
|||
|
|||
type persistenceCollection struct { |
|||
Name string |
|||
Metadata map[string]string |
|||
Documents map[string]*Document |
|||
} |
|||
persistenceDB := struct { |
|||
Collections map[string]*persistenceCollection |
|||
}{ |
|||
Collections: make(map[string]*persistenceCollection, len(db.collections)), |
|||
} |
|||
|
|||
db.collectionsLock.RLock() |
|||
defer db.collectionsLock.RUnlock() |
|||
|
|||
for k, v := range db.collections { |
|||
persistenceDB.Collections[k] = &persistenceCollection{ |
|||
Name: v.Name, |
|||
Metadata: v.metadata, |
|||
Documents: v.documents, |
|||
} |
|||
} |
|||
|
|||
err := persistToWriter(writer, persistenceDB, compress, encryptionKey) |
|||
if err != nil { |
|||
return fmt.Errorf("无法导出数据库: %w", err) |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// CreateCollection 创建具有给定名称和元数据的新集合。
|
|||
func (db *DB) CreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) { |
|||
if name == "" { |
|||
return nil, errors.New("集合名称为空") |
|||
} |
|||
if embeddingFunc == nil { |
|||
embeddingFunc = NewEmbeddingFuncDefault() |
|||
} |
|||
collection, err := newCollection(name, metadata, embeddingFunc, db.persistDirectory, db.compress) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("无法创建集合: %w", err) |
|||
} |
|||
|
|||
db.collectionsLock.Lock() |
|||
defer db.collectionsLock.Unlock() |
|||
db.collections[name] = collection |
|||
return collection, nil |
|||
} |
|||
|
|||
// ListCollections 返回数据库中的所有集合。
|
|||
func (db *DB) ListCollections() map[string]*Collection { |
|||
db.collectionsLock.RLock() |
|||
defer db.collectionsLock.RUnlock() |
|||
|
|||
res := make(map[string]*Collection, len(db.collections)) |
|||
for k, v := range db.collections { |
|||
res[k] = v |
|||
} |
|||
|
|||
return res |
|||
} |
|||
|
|||
// GetCollection 返回具有给定名称的集合。
|
|||
func (db *DB) GetCollection(name string, embeddingFunc EmbeddingFunc) *Collection { |
|||
db.collectionsLock.RLock() |
|||
defer db.collectionsLock.RUnlock() |
|||
|
|||
c, ok := db.collections[name] |
|||
if !ok { |
|||
return nil |
|||
} |
|||
|
|||
if c.embed == nil { |
|||
if embeddingFunc == nil { |
|||
c.embed = NewEmbeddingFuncDefault() |
|||
} else { |
|||
c.embed = embeddingFunc |
|||
} |
|||
} |
|||
return c |
|||
} |
|||
|
|||
// GetOrCreateCollection 返回数据库中已有的集合,或创建一个新的集合。
|
|||
func (db *DB) GetOrCreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) { |
|||
collection := db.GetCollection(name, embeddingFunc) |
|||
if collection == nil { |
|||
var err error |
|||
collection, err = db.CreateCollection(name, metadata, embeddingFunc) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("无法创建集合: %w", err) |
|||
} |
|||
} |
|||
return collection, nil |
|||
} |
|||
|
|||
// DeleteCollection 删除具有给定名称的集合。
|
|||
func (db *DB) DeleteCollection(name string) error { |
|||
db.collectionsLock.Lock() |
|||
defer db.collectionsLock.Unlock() |
|||
|
|||
col, ok := db.collections[name] |
|||
if !ok { |
|||
return nil |
|||
} |
|||
|
|||
if db.persistDirectory != "" { |
|||
collectionPath := col.persistDirectory |
|||
err := os.RemoveAll(collectionPath) |
|||
if err != nil { |
|||
return fmt.Errorf("无法删除集合目录: %w", err) |
|||
} |
|||
} |
|||
|
|||
delete(db.collections, name) |
|||
return nil |
|||
} |
|||
|
|||
// Reset 从数据库中移除所有集合。
|
|||
func (db *DB) Reset() error { |
|||
db.collectionsLock.Lock() |
|||
defer db.collectionsLock.Unlock() |
|||
|
|||
if db.persistDirectory != "" { |
|||
err := os.RemoveAll(db.persistDirectory) |
|||
if err != nil { |
|||
return fmt.Errorf("无法删除持久化目录: %w", err) |
|||
} |
|||
err = os.MkdirAll(db.persistDirectory, 0o700) |
|||
if err != nil { |
|||
return fmt.Errorf("无法重新创建持久化目录: %w", err) |
|||
} |
|||
} |
|||
|
|||
db.collections = make(map[string]*Collection) |
|||
return nil |
|||
} |
@ -1,178 +1,52 @@ |
|||
package vector |
|||
|
|||
import ( |
|||
"encoding/json" |
|||
"context" |
|||
"errors" |
|||
"fmt" |
|||
"godo/ai/server" |
|||
"godo/libs" |
|||
"godo/office" |
|||
"log" |
|||
"os" |
|||
"path/filepath" |
|||
"strings" |
|||
|
|||
"github.com/fsnotify/fsnotify" |
|||
) |
|||
|
|||
var MapFilePathMonitors = map[string]uint{} |
|||
|
|||
func FolderMonitor() { |
|||
basePath, err := libs.GetOsDir() |
|||
if err != nil { |
|||
log.Printf("Error getting base path: %s", err.Error()) |
|||
return |
|||
} |
|||
watcher, err := fsnotify.NewWatcher() |
|||
if err != nil { |
|||
log.Printf("Error creating watcher: %s", err.Error()) |
|||
return |
|||
} |
|||
defer watcher.Close() |
|||
|
|||
// 递归添加所有子目录
|
|||
addRecursive(basePath, watcher) |
|||
|
|||
// Start listening for events.
|
|||
go func() { |
|||
for { |
|||
select { |
|||
case event, ok := <-watcher.Events: |
|||
if !ok { |
|||
log.Println("error:", err) |
|||
return |
|||
} |
|||
//log.Println("event:", event)
|
|||
filePath := event.Name |
|||
result, knowledgeId := shouldProcess(filePath) |
|||
//log.Printf("result:%d,knowledgeId:%d", result, knowledgeId)
|
|||
if result > 0 { |
|||
info, err := os.Stat(filePath) |
|||
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) { |
|||
log.Println("modified file:", filePath) |
|||
if !info.IsDir() { |
|||
handleGodoosFile(filePath, knowledgeId) |
|||
} |
|||
} |
|||
if event.Has(fsnotify.Create) || event.Has(fsnotify.Rename) { |
|||
// 处理创建或重命名事件,添加新目录
|
|||
if err == nil && info.IsDir() { |
|||
addRecursive(filePath, watcher) |
|||
} |
|||
} |
|||
if event.Has(fsnotify.Remove) { |
|||
// 处理删除事件,移除目录
|
|||
if err == nil && info.IsDir() { |
|||
watcher.Remove(filePath) |
|||
} |
|||
} |
|||
} |
|||
case err, ok := <-watcher.Errors: |
|||
if !ok { |
|||
return |
|||
} |
|||
log.Println("error:", err) |
|||
} |
|||
} |
|||
}() |
|||
// Document 表示单个文档。
|
|||
type Document struct { |
|||
ID string // 文档的唯一标识符
|
|||
Metadata map[string]string // 文档的元数据
|
|||
Embedding []float32 // 文档的嵌入向量
|
|||
Content string // 文档的内容
|
|||
|
|||
// Add a path.
|
|||
err = watcher.Add(basePath) |
|||
if err != nil { |
|||
log.Fatal(err) |
|||
} |
|||
|
|||
// Block main goroutine forever.
|
|||
<-make(chan struct{}) |
|||
// ⚠️ 当在此处添加未导出字段时,请考虑在 [DB.Export] 和 [DB.Import] 中添加一个持久化结构版本。
|
|||
} |
|||
|
|||
func shouldProcess(filePath string) (int, uint) { |
|||
// 规范化路径
|
|||
filePath = filepath.Clean(filePath) |
|||
|
|||
// 检查文件路径是否在 MapFilePathMonitors 中
|
|||
for path, id := range MapFilePathMonitors { |
|||
if id < 1 { |
|||
return 0, 0 |
|||
} |
|||
path = filepath.Clean(path) |
|||
if filePath == path { |
|||
return 1, id // 完全相等
|
|||
} |
|||
if strings.HasPrefix(filePath, path+string(filepath.Separator)) { |
|||
return 2, id // 包含
|
|||
} |
|||
// NewDocument 创建一个新的文档,包括其嵌入向量。
|
|||
// 元数据是可选的。
|
|||
// 如果未提供嵌入向量,则使用嵌入函数创建。
|
|||
// 如果内容为空但需要存储嵌入向量,可以仅提供嵌入向量。
|
|||
// 如果 embeddingFunc 为 nil,则使用默认的嵌入函数。
|
|||
//
|
|||
// 如果你想创建没有嵌入向量的文档,例如让 [Collection.AddDocuments] 并发创建它们,
|
|||
// 可以使用 `chromem.Document{...}` 而不是这个构造函数。
|
|||
func NewDocument(ctx context.Context, id string, metadata map[string]string, embedding []float32, content string, embeddingFunc EmbeddingFunc) (Document, error) { |
|||
if id == "" { |
|||
return Document{}, errors.New("ID 不能为空") |
|||
} |
|||
if len(embedding) == 0 && content == "" { |
|||
return Document{}, errors.New("嵌入向量或内容必须至少有一个非空") |
|||
} |
|||
if embeddingFunc == nil { |
|||
embeddingFunc = NewEmbeddingFuncDefault() |
|||
} |
|||
return 0, 0 // 不存在
|
|||
} |
|||
|
|||
func addRecursive(path string, watcher *fsnotify.Watcher) { |
|||
err := filepath.Walk(path, func(path string, info os.FileInfo, err error) error { |
|||
if len(embedding) == 0 { |
|||
var err error |
|||
embedding, err = embeddingFunc(ctx, content) |
|||
if err != nil { |
|||
log.Printf("Error walking path %s: %v", path, err) |
|||
return err |
|||
return Document{}, fmt.Errorf("无法生成嵌入向量: %w", err) |
|||
} |
|||
if info.IsDir() { |
|||
result, _ := shouldProcess(path) |
|||
if result > 0 { |
|||
if err := watcher.Add(path); err != nil { |
|||
log.Printf("Error adding path %s to watcher: %v", path, err) |
|||
return err |
|||
} |
|||
log.Printf("Added path %s to watcher", path) |
|||
} |
|||
|
|||
} |
|||
return nil |
|||
}) |
|||
if err != nil { |
|||
log.Printf("Error adding recursive paths: %v", err) |
|||
} |
|||
} |
|||
|
|||
func handleGodoosFile(filePath string, knowledgeId uint) error { |
|||
log.Printf("========Handling .godoos file: %s", filePath) |
|||
baseName := filepath.Base(filePath) |
|||
if baseName[:8] != ".godoos." { |
|||
if baseName[:1] != "." { |
|||
office.ProcessFile(filePath, knowledgeId) |
|||
} |
|||
return nil |
|||
} |
|||
var doc office.Document |
|||
content, err := os.ReadFile(filePath) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
err = json.Unmarshal(content, &doc) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
if len(doc.Split) == 0 { |
|||
return fmt.Errorf("invalid .godoos file: %s", filePath) |
|||
} |
|||
knowData := GetVector(knowledgeId) |
|||
resList, err := server.GetEmbeddings(knowData.Engine, knowData.EmbeddingModel, doc.Split) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
if len(resList) != len(doc.Split) { |
|||
return fmt.Errorf("invalid file len: %s, expected %d embeddings, got %d", filePath, len(doc.Split), len(resList)) |
|||
} |
|||
// var vectordocs []model.Vectordoc
|
|||
// for i, res := range resList {
|
|||
// //log.Printf("res: %v", res)
|
|||
// vectordoc := model.Vectordoc{
|
|||
// Content: doc.Split[i],
|
|||
// Embed: res,
|
|||
// FilePath: filePath,
|
|||
// KnowledgeID: knowledgeId,
|
|||
// Pos: fmt.Sprintf("%d", i),
|
|||
// }
|
|||
// vectordocs = append(vectordocs, vectordoc)
|
|||
// }
|
|||
// result := vectorListDb.Create(&vectordocs)
|
|||
// if result.Error != nil {
|
|||
// return result.Error
|
|||
// }
|
|||
return nil |
|||
return Document{ |
|||
ID: id, |
|||
Metadata: metadata, |
|||
Embedding: embedding, |
|||
Content: content, |
|||
}, nil |
|||
} |
|||
|
@ -0,0 +1,181 @@ |
|||
package vector |
|||
|
|||
import ( |
|||
"encoding/json" |
|||
"fmt" |
|||
"godo/ai/server" |
|||
"godo/libs" |
|||
"godo/office" |
|||
"log" |
|||
"os" |
|||
"path/filepath" |
|||
"strings" |
|||
|
|||
"github.com/fsnotify/fsnotify" |
|||
) |
|||
|
|||
// MapFilePathMonitors maps a monitored file/directory path to a knowledge
// base ID. Entries with an ID below 1 are treated as disabled by
// shouldProcess.
var MapFilePathMonitors = map[string]uint{}
|||
|
|||
func FolderMonitor() { |
|||
basePath, err := libs.GetOsDir() |
|||
if err != nil { |
|||
log.Printf("Error getting base path: %s", err.Error()) |
|||
return |
|||
} |
|||
watcher, err := fsnotify.NewWatcher() |
|||
if err != nil { |
|||
log.Printf("Error creating watcher: %s", err.Error()) |
|||
return |
|||
} |
|||
defer watcher.Close() |
|||
|
|||
// 递归添加所有子目录
|
|||
addRecursive(basePath, watcher) |
|||
|
|||
// Start listening for events.
|
|||
go func() { |
|||
for { |
|||
select { |
|||
case event, ok := <-watcher.Events: |
|||
if !ok { |
|||
log.Println("error:", err) |
|||
return |
|||
} |
|||
//log.Println("event:", event)
|
|||
filePath := event.Name |
|||
result, knowledgeId := shouldProcess(filePath) |
|||
//log.Printf("result:%d,knowledgeId:%d", result, knowledgeId)
|
|||
if result > 0 { |
|||
info, err := os.Stat(filePath) |
|||
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) { |
|||
log.Println("modified file:", filePath) |
|||
if !info.IsDir() { |
|||
handleGodoosFile(filePath, knowledgeId) |
|||
} |
|||
} |
|||
if event.Has(fsnotify.Create) || event.Has(fsnotify.Rename) { |
|||
// 处理创建或重命名事件,添加新目录
|
|||
if err == nil && info.IsDir() { |
|||
addRecursive(filePath, watcher) |
|||
} |
|||
} |
|||
if event.Has(fsnotify.Remove) { |
|||
// 处理删除事件,移除目录
|
|||
if err == nil && info.IsDir() { |
|||
watcher.Remove(filePath) |
|||
} |
|||
} |
|||
} |
|||
case err, ok := <-watcher.Errors: |
|||
if !ok { |
|||
return |
|||
} |
|||
log.Println("error:", err) |
|||
} |
|||
} |
|||
}() |
|||
|
|||
// Add a path.
|
|||
err = watcher.Add(basePath) |
|||
if err != nil { |
|||
log.Fatal(err) |
|||
} |
|||
|
|||
// Block main goroutine forever.
|
|||
<-make(chan struct{}) |
|||
} |
|||
|
|||
func shouldProcess(filePath string) (int, uint) { |
|||
// 规范化路径
|
|||
filePath = filepath.Clean(filePath) |
|||
|
|||
// 检查文件路径是否在 MapFilePathMonitors 中
|
|||
for path, id := range MapFilePathMonitors { |
|||
if id < 1 { |
|||
return 0, 0 |
|||
} |
|||
path = filepath.Clean(path) |
|||
if filePath == path { |
|||
return 1, id // 完全相等
|
|||
} |
|||
if strings.HasPrefix(filePath, path+string(filepath.Separator)) { |
|||
return 2, id // 包含
|
|||
} |
|||
} |
|||
return 0, 0 // 不存在
|
|||
} |
|||
|
|||
func addRecursive(path string, watcher *fsnotify.Watcher) { |
|||
err := filepath.Walk(path, func(path string, info os.FileInfo, err error) error { |
|||
if err != nil { |
|||
log.Printf("Error walking path %s: %v", path, err) |
|||
return err |
|||
} |
|||
if info.IsDir() { |
|||
result, _ := shouldProcess(path) |
|||
if result > 0 { |
|||
if err := watcher.Add(path); err != nil { |
|||
log.Printf("Error adding path %s to watcher: %v", path, err) |
|||
return err |
|||
} |
|||
log.Printf("Added path %s to watcher", path) |
|||
} |
|||
|
|||
} |
|||
return nil |
|||
}) |
|||
if err != nil { |
|||
log.Printf("Error adding recursive paths: %v", err) |
|||
} |
|||
} |
|||
|
|||
func handleGodoosFile(filePath string, knowledgeId uint) error { |
|||
log.Printf("========Handling .godoos file: %s", filePath) |
|||
baseName := filepath.Base(filePath) |
|||
if baseName[:8] != ".godoos." { |
|||
if baseName[:1] != "." { |
|||
office.ProcessFile(filePath, knowledgeId) |
|||
} |
|||
return nil |
|||
} |
|||
var doc office.Document |
|||
content, err := os.ReadFile(filePath) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
err = json.Unmarshal(content, &doc) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
if len(doc.Split) == 0 { |
|||
return fmt.Errorf("invalid .godoos file: %s", filePath) |
|||
} |
|||
knowData, err := GetVector(knowledgeId) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
resList, err := server.GetEmbeddings(knowData.Engine, knowData.EmbeddingModel, doc.Split) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
if len(resList) != len(doc.Split) { |
|||
return fmt.Errorf("invalid file len: %s, expected %d embeddings, got %d", filePath, len(doc.Split), len(resList)) |
|||
} |
|||
// var vectordocs []model.Vectordoc
|
|||
// for i, res := range resList {
|
|||
// //log.Printf("res: %v", res)
|
|||
// vectordoc := model.Vectordoc{
|
|||
// Content: doc.Split[i],
|
|||
// Embed: res,
|
|||
// FilePath: filePath,
|
|||
// KnowledgeID: knowledgeId,
|
|||
// Pos: fmt.Sprintf("%d", i),
|
|||
// }
|
|||
// vectordocs = append(vectordocs, vectordoc)
|
|||
// }
|
|||
// result := vectorListDb.Create(&vectordocs)
|
|||
// if result.Error != nil {
|
|||
// return result.Error
|
|||
// }
|
|||
return nil |
|||
} |
@ -0,0 +1,125 @@ |
|||
package vector |
|||
|
|||
import ( |
|||
"bytes" |
|||
"context" |
|||
"encoding/json" |
|||
"errors" |
|||
"fmt" |
|||
"io" |
|||
"net/http" |
|||
"os" |
|||
"sync" |
|||
) |
|||
|
|||
// BaseURLOpenAI is the default endpoint of OpenAI's REST API.
const BaseURLOpenAI = "https://api.openai.com/v1"

// EmbeddingModelOpenAI names an OpenAI embedding model.
type EmbeddingModelOpenAI string

// Supported OpenAI embedding models.
const (
	EmbeddingModelOpenAI2Ada   EmbeddingModelOpenAI = "text-embedding-ada-002"
	EmbeddingModelOpenAI3Small EmbeddingModelOpenAI = "text-embedding-3-small"
	EmbeddingModelOpenAI3Large EmbeddingModelOpenAI = "text-embedding-3-large"
)

// openAIResponse mirrors the subset of the /embeddings response body that
// this package reads: only the embedding vectors themselves.
type openAIResponse struct {
	Data []struct {
		Embedding []float32 `json:"embedding"`
	} `json:"data"`
}
|||
|
|||
// NewEmbeddingFuncDefault 返回一个函数,使用 OpenAI 的 "text-embedding-3-small" 模型通过 API 创建文本嵌入向量。
|
|||
// 该模型支持的最大文本长度为 8191 个标记。
|
|||
// API 密钥从环境变量 "OPENAI_API_KEY" 中读取。
|
|||
func NewEmbeddingFuncDefault() EmbeddingFunc { |
|||
apiKey := os.Getenv("OPENAI_API_KEY") |
|||
return NewEmbeddingFuncOpenAI(apiKey, EmbeddingModelOpenAI3Small) |
|||
} |
|||
|
|||
// NewEmbeddingFuncOpenAI 返回一个函数,使用 OpenAI API 创建文本嵌入向量。
|
|||
func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) EmbeddingFunc { |
|||
// OpenAI 嵌入向量已归一化
|
|||
normalized := true |
|||
return NewEmbeddingFuncOpenAICompat(BaseURLOpenAI, apiKey, string(model), &normalized) |
|||
} |
|||
|
|||
// NewEmbeddingFuncOpenAICompat 返回一个函数,使用兼容 OpenAI 的 API 创建文本嵌入向量。
|
|||
// 例如:
|
|||
// - Azure OpenAI: https://azure.microsoft.com/en-us/products/ai-services/openai-service
|
|||
// - LitLLM: https://github.com/BerriAI/litellm
|
|||
// - Ollama: https://github.com/ollama/ollama/blob/main/docs/openai.md
|
|||
//
|
|||
// `normalized` 参数表示嵌入模型返回的向量是否已经归一化。如果为 nil,则会在首次请求时自动检测(有小概率向量恰好长度为 1)。
|
|||
func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *bool) EmbeddingFunc { |
|||
client := &http.Client{} |
|||
|
|||
var checkedNormalized bool |
|||
checkNormalized := sync.Once{} |
|||
|
|||
return func(ctx context.Context, text string) ([]float32, error) { |
|||
// 准备请求体
|
|||
reqBody, err := json.Marshal(map[string]string{ |
|||
"input": text, |
|||
"model": model, |
|||
}) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("无法序列化请求体: %w", err) |
|||
} |
|||
|
|||
// 创建带有上下文的请求以支持超时
|
|||
req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/embeddings", bytes.NewBuffer(reqBody)) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("无法创建请求: %w", err) |
|||
} |
|||
req.Header.Set("Content-Type", "application/json") |
|||
req.Header.Set("Authorization", "Bearer "+apiKey) |
|||
|
|||
// 发送请求
|
|||
resp, err := client.Do(req) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("无法发送请求: %w", err) |
|||
} |
|||
defer resp.Body.Close() |
|||
|
|||
// 检查响应状态
|
|||
if resp.StatusCode != http.StatusOK { |
|||
return nil, errors.New("嵌入 API 返回错误响应: " + resp.Status) |
|||
} |
|||
|
|||
// 读取并解码响应体
|
|||
body, err := io.ReadAll(resp.Body) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("无法读取响应体: %w", err) |
|||
} |
|||
var embeddingResponse openAIResponse |
|||
err = json.Unmarshal(body, &embeddingResponse) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("无法反序列化响应体: %w", err) |
|||
} |
|||
|
|||
// 检查响应中是否包含嵌入向量
|
|||
if len(embeddingResponse.Data) == 0 || len(embeddingResponse.Data[0].Embedding) == 0 { |
|||
return nil, errors.New("响应中未找到嵌入向量") |
|||
} |
|||
|
|||
v := embeddingResponse.Data[0].Embedding |
|||
if normalized != nil { |
|||
if *normalized { |
|||
return v, nil |
|||
} |
|||
return normalizeVector(v), nil |
|||
} |
|||
checkNormalized.Do(func() { |
|||
if isNormalized(v) { |
|||
checkedNormalized = true |
|||
} else { |
|||
checkedNormalized = false |
|||
} |
|||
}) |
|||
if !checkedNormalized { |
|||
v = normalizeVector(v) |
|||
} |
|||
|
|||
return v, nil |
|||
} |
|||
} |
@ -0,0 +1,208 @@ |
|||
package vector |
|||
|
|||
import ( |
|||
"bytes" |
|||
"compress/gzip" |
|||
"crypto/aes" |
|||
"crypto/cipher" |
|||
"crypto/rand" |
|||
"crypto/sha256" |
|||
"encoding/gob" |
|||
"encoding/hex" |
|||
"errors" |
|||
"fmt" |
|||
"io" |
|||
"io/fs" |
|||
"os" |
|||
"path/filepath" |
|||
) |
|||
|
|||
const metadataFileName = "00000000" |
|||
|
|||
// hash2hex returns the first eight hex characters (four bytes) of the
// SHA-256 digest of name, used as a short filesystem-safe identifier.
func hash2hex(name string) string {
	digest := sha256.Sum256([]byte(name))
	return hex.EncodeToString(digest[:4])
}
|||
|
|||
// persistToFile 将对象持久化到文件。支持 Gob 序列化、Gzip 压缩和 AES-GCM 加密。
|
|||
func persistToFile(filePath string, obj any, compress bool, encryptionKey string) error { |
|||
if filePath == "" { |
|||
return fmt.Errorf("文件路径为空") |
|||
} |
|||
if encryptionKey != "" && len(encryptionKey) != 32 { |
|||
return errors.New("加密密钥必须是 32 字节长") |
|||
} |
|||
|
|||
// 确保父目录存在
|
|||
if err := os.MkdirAll(filepath.Dir(filePath), 0o700); err != nil { |
|||
return fmt.Errorf("无法创建父目录: %w", err) |
|||
} |
|||
|
|||
// 打开或创建文件
|
|||
f, err := os.Create(filePath) |
|||
if err != nil { |
|||
return fmt.Errorf("无法创建文件: %w", err) |
|||
} |
|||
defer f.Close() |
|||
|
|||
return persistToWriter(f, obj, compress, encryptionKey) |
|||
} |
|||
|
|||
// persistToWriter 将对象持久化到 io.Writer。支持 Gob 序列化、Gzip 压缩和 AES-GCM 加密。
|
|||
func persistToWriter(w io.Writer, obj any, compress bool, encryptionKey string) error { |
|||
if encryptionKey != "" && len(encryptionKey) != 32 { |
|||
return errors.New("加密密钥必须是 32 字节长") |
|||
} |
|||
|
|||
var chainedWriter io.Writer |
|||
if encryptionKey == "" { |
|||
chainedWriter = w |
|||
} else { |
|||
chainedWriter = &bytes.Buffer{} |
|||
} |
|||
|
|||
var gzw *gzip.Writer |
|||
var enc *gob.Encoder |
|||
if compress { |
|||
gzw = gzip.NewWriter(chainedWriter) |
|||
enc = gob.NewEncoder(gzw) |
|||
} else { |
|||
enc = gob.NewEncoder(chainedWriter) |
|||
} |
|||
|
|||
if err := enc.Encode(obj); err != nil { |
|||
return fmt.Errorf("无法编码或写入对象: %w", err) |
|||
} |
|||
|
|||
if compress { |
|||
if err := gzw.Close(); err != nil { |
|||
return fmt.Errorf("无法关闭 Gzip 写入器: %w", err) |
|||
} |
|||
} |
|||
|
|||
if encryptionKey == "" { |
|||
return nil |
|||
} |
|||
|
|||
block, err := aes.NewCipher([]byte(encryptionKey)) |
|||
if err != nil { |
|||
return fmt.Errorf("无法创建 AES 密码: %w", err) |
|||
} |
|||
gcm, err := cipher.NewGCM(block) |
|||
if err != nil { |
|||
return fmt.Errorf("无法创建 GCM 包装器: %w", err) |
|||
} |
|||
nonce := make([]byte, gcm.NonceSize()) |
|||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil { |
|||
return fmt.Errorf("无法读取随机字节作为 nonce: %w", err) |
|||
} |
|||
|
|||
buf := chainedWriter.(*bytes.Buffer) |
|||
encrypted := gcm.Seal(nonce, nonce, buf.Bytes(), nil) |
|||
if _, err := w.Write(encrypted); err != nil { |
|||
return fmt.Errorf("无法写入加密数据: %w", err) |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// readFromFile 从文件中读取对象。支持 Gob 反序列化、Gzip 解压和 AES-GCM 解密。
|
|||
func readFromFile(filePath string, obj any, encryptionKey string) error { |
|||
if filePath == "" { |
|||
return fmt.Errorf("文件路径为空") |
|||
} |
|||
if encryptionKey != "" && len(encryptionKey) != 32 { |
|||
return errors.New("加密密钥必须是 32 字节长") |
|||
} |
|||
|
|||
r, err := os.Open(filePath) |
|||
if err != nil { |
|||
return fmt.Errorf("无法打开文件: %w", err) |
|||
} |
|||
defer r.Close() |
|||
|
|||
return readFromReader(r, obj, encryptionKey) |
|||
} |
|||
|
|||
// readFromReader 从 io.Reader 中读取对象。支持 Gob 反序列化、Gzip 解压和 AES-GCM 解密。
|
|||
func readFromReader(r io.ReadSeeker, obj any, encryptionKey string) error { |
|||
if encryptionKey != "" && len(encryptionKey) != 32 { |
|||
return errors.New("加密密钥必须是 32 字节长") |
|||
} |
|||
|
|||
var chainedReader io.Reader |
|||
if encryptionKey != "" { |
|||
encrypted, err := io.ReadAll(r) |
|||
if err != nil { |
|||
return fmt.Errorf("无法读取数据: %w", err) |
|||
} |
|||
block, err := aes.NewCipher([]byte(encryptionKey)) |
|||
if err != nil { |
|||
return fmt.Errorf("无法创建 AES 密码: %w", err) |
|||
} |
|||
gcm, err := cipher.NewGCM(block) |
|||
if err != nil { |
|||
return fmt.Errorf("无法创建 GCM 包装器: %w", err) |
|||
} |
|||
nonceSize := gcm.NonceSize() |
|||
if len(encrypted) < nonceSize { |
|||
return fmt.Errorf("加密数据太短") |
|||
} |
|||
nonce, ciphertext := encrypted[:nonceSize], encrypted[nonceSize:] |
|||
data, err := gcm.Open(nil, nonce, ciphertext, nil) |
|||
if err != nil { |
|||
return fmt.Errorf("无法解密数据: %w", err) |
|||
} |
|||
chainedReader = bytes.NewReader(data) |
|||
} else { |
|||
chainedReader = r |
|||
} |
|||
|
|||
magicNumber := make([]byte, 2) |
|||
_, err := chainedReader.Read(magicNumber) |
|||
if err != nil { |
|||
return fmt.Errorf("无法读取魔数以确定是否压缩: %w", err) |
|||
} |
|||
compressed := magicNumber[0] == 0x1f && magicNumber[1] == 0x8b |
|||
|
|||
// 重置读取器位置
|
|||
if s, ok := chainedReader.(io.Seeker); !ok { |
|||
return fmt.Errorf("读取器不支持寻址") |
|||
} else { |
|||
_, err := s.Seek(0, 0) |
|||
if err != nil { |
|||
return fmt.Errorf("无法重置读取器: %w", err) |
|||
} |
|||
} |
|||
|
|||
if compressed { |
|||
gzr, err := gzip.NewReader(chainedReader) |
|||
if err != nil { |
|||
return fmt.Errorf("无法创建 Gzip 读取器: %w", err) |
|||
} |
|||
defer gzr.Close() |
|||
chainedReader = gzr |
|||
} |
|||
|
|||
dec := gob.NewDecoder(chainedReader) |
|||
if err := dec.Decode(obj); err != nil { |
|||
return fmt.Errorf("无法解码对象: %w", err) |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// removeFile 删除指定路径的文件。如果文件不存在,则无操作。
|
|||
func removeFile(filePath string) error { |
|||
if filePath == "" { |
|||
return fmt.Errorf("文件路径为空") |
|||
} |
|||
|
|||
err := os.Remove(filePath) |
|||
if err != nil && !errors.Is(err, fs.ErrNotExist) { |
|||
return fmt.Errorf("无法删除文件 %q: %w", filePath, err) |
|||
} |
|||
|
|||
return nil |
|||
} |
@ -0,0 +1,207 @@ |
|||
package vector |
|||
|
|||
import ( |
|||
"cmp" |
|||
"container/heap" |
|||
"context" |
|||
"fmt" |
|||
"runtime" |
|||
"slices" |
|||
"strings" |
|||
"sync" |
|||
) |
|||
|
|||
// supportedFilters lists the operators accepted in whereDocument filters.
var supportedFilters = []string{"$contains", "$not_contains"}
|||
|
|||
// docSim pairs a document ID with its similarity score for a query.
type docSim struct {
	docID      string
	similarity float32
}

// docMaxHeap is a heap of docSim ordered by ascending similarity, so the
// root is always the lowest-scoring document currently retained. That makes
// it cheap to evict the worst entry while keeping the top n.
type docMaxHeap []docSim

func (h docMaxHeap) Len() int           { return len(h) }
func (h docMaxHeap) Less(i, j int) bool { return h[i].similarity < h[j].similarity }
func (h docMaxHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }

func (h *docMaxHeap) Push(x any) {
	*h = append(*h, x.(docSim))
}

func (h *docMaxHeap) Pop() any {
	old := *h
	last := len(old) - 1
	item := old[last]
	*h = old[:last]
	return item
}

// maxDocSims keeps the n highest-similarity documents seen so far.
// add is safe for concurrent use; the slice returned by values is not.
type maxDocSims struct {
	h    docMaxHeap
	lock sync.RWMutex
	size int
}

// newMaxDocSims creates a fixed-capacity collector for the top `size` docs.
func newMaxDocSims(size int) *maxDocSims {
	return &maxDocSims{
		h:    make(docMaxHeap, 0, size),
		size: size,
	}
}

// add offers a candidate; it is retained only while it ranks in the top n.
func (mds *maxDocSims) add(doc docSim) {
	mds.lock.Lock()
	defer mds.lock.Unlock()
	switch {
	case mds.h.Len() < mds.size:
		heap.Push(&mds.h, doc)
	case mds.h.Len() > 0 && mds.h[0].similarity < doc.similarity:
		// The new candidate beats the current worst entry: replace it.
		heap.Pop(&mds.h)
		heap.Push(&mds.h, doc)
	}
}

// values returns the retained documents sorted by descending similarity.
// The call is concurrency-safe, but the returned slice aliases internal state.
func (mds *maxDocSims) values() []docSim {
	mds.lock.RLock()
	defer mds.lock.RUnlock()
	slices.SortFunc(mds.h, func(a, b docSim) int {
		return cmp.Compare(b.similarity, a.similarity)
	})
	return mds.h
}
|||
|
|||
// filterDocs 并发过滤文档,根据元数据和内容进行筛选。
|
|||
func filterDocs(docs map[string]*Document, where, whereDocument map[string]string) []*Document { |
|||
filteredDocs := make([]*Document, 0, len(docs)) |
|||
var filteredDocsLock sync.Mutex |
|||
|
|||
numCPUs := runtime.NumCPU() |
|||
numDocs := len(docs) |
|||
concurrency := min(numCPUs, numDocs) |
|||
|
|||
docChan := make(chan *Document, concurrency*2) |
|||
|
|||
var wg sync.WaitGroup |
|||
for i := 0; i < concurrency; i++ { |
|||
wg.Add(1) |
|||
go func() { |
|||
defer wg.Done() |
|||
for doc := range docChan { |
|||
if documentMatchesFilters(doc, where, whereDocument) { |
|||
filteredDocsLock.Lock() |
|||
filteredDocs = append(filteredDocs, doc) |
|||
filteredDocsLock.Unlock() |
|||
} |
|||
} |
|||
}() |
|||
} |
|||
|
|||
for _, doc := range docs { |
|||
docChan <- doc |
|||
} |
|||
close(docChan) |
|||
|
|||
wg.Wait() |
|||
|
|||
if len(filteredDocs) == 0 { |
|||
return nil |
|||
} |
|||
return filteredDocs |
|||
} |
|||
|
|||
// documentMatchesFilters 检查文档是否匹配给定的过滤条件。
|
|||
func documentMatchesFilters(document *Document, where, whereDocument map[string]string) bool { |
|||
for k, v := range where { |
|||
if document.Metadata[k] != v { |
|||
return false |
|||
} |
|||
} |
|||
|
|||
for k, v := range whereDocument { |
|||
switch k { |
|||
case "$contains": |
|||
if !strings.Contains(document.Content, v) { |
|||
return false |
|||
} |
|||
case "$not_contains": |
|||
if strings.Contains(document.Content, v) { |
|||
return false |
|||
} |
|||
} |
|||
} |
|||
|
|||
return true |
|||
} |
|||
|
|||
// getMostSimilarDocs 获取与查询向量最相似的前 n 个文档。
|
|||
func getMostSimilarDocs(ctx context.Context, queryVectors []float32, docs []*Document, n int) ([]docSim, error) { |
|||
nMaxDocs := newMaxDocSims(n) |
|||
|
|||
numCPUs := runtime.NumCPU() |
|||
numDocs := len(docs) |
|||
concurrency := min(numCPUs, numDocs) |
|||
|
|||
var sharedErr error |
|||
var sharedErrLock sync.Mutex |
|||
ctx, cancel := context.WithCancelCause(ctx) |
|||
defer cancel(nil) |
|||
|
|||
setSharedErr := func(err error) { |
|||
sharedErrLock.Lock() |
|||
defer sharedErrLock.Unlock() |
|||
if sharedErr == nil { |
|||
sharedErr = err |
|||
cancel(sharedErr) |
|||
} |
|||
} |
|||
|
|||
var wg sync.WaitGroup |
|||
subSliceSize := len(docs) / concurrency |
|||
rem := len(docs) % concurrency |
|||
|
|||
for i := 0; i < concurrency; i++ { |
|||
start := i * subSliceSize |
|||
end := start + subSliceSize |
|||
if i == concurrency-1 { |
|||
end += rem |
|||
} |
|||
|
|||
wg.Add(1) |
|||
go func(subSlice []*Document) { |
|||
defer wg.Done() |
|||
for _, doc := range subSlice { |
|||
if ctx.Err() != nil { |
|||
return |
|||
} |
|||
|
|||
sim, err := dotProduct(queryVectors, doc.Embedding) |
|||
if err != nil { |
|||
setSharedErr(fmt.Errorf("无法计算文档 '%s' 的相似度: %w", doc.ID, err)) |
|||
return |
|||
} |
|||
|
|||
nMaxDocs.add(docSim{docID: doc.ID, similarity: sim}) |
|||
} |
|||
}(docs[start:end]) |
|||
} |
|||
|
|||
wg.Wait() |
|||
|
|||
if sharedErr != nil { |
|||
return nil, sharedErr |
|||
} |
|||
|
|||
return nMaxDocs.values(), nil |
|||
} |
|||
|
|||
// min returns the smaller of a and b.
//
// NOTE(review): the module targets Go 1.23, so the builtin min could be used
// instead; this package-level definition shadows it within the package.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
@ -0,0 +1,79 @@ |
|||
package vector |
|||
|
|||
import ( |
|||
"errors" |
|||
"fmt" |
|||
"math" |
|||
) |
|||
|
|||
// isNormalizedPrecisionTolerance is the allowed deviation of a vector's
// magnitude from 1 for it to still count as normalized.
const isNormalizedPrecisionTolerance = 1e-6
|||
|
|||
// cosineSimilarity 计算两个向量的余弦相似度。
|
|||
// 向量在计算前会被归一化。
|
|||
// 结果值表示相似度,值越高表示向量越相似。
|
|||
func cosineSimilarity(a, b []float32) (float32, error) { |
|||
// 向量必须具有相同的长度
|
|||
if len(a) != len(b) { |
|||
return 0, errors.New("向量必须具有相同的长度") |
|||
} |
|||
|
|||
// 归一化向量
|
|||
aNorm := normalizeVector(a) |
|||
bNorm := normalizeVector(b) |
|||
|
|||
// 计算点积
|
|||
dotProduct, err := dotProduct(aNorm, bNorm) |
|||
if err != nil { |
|||
return 0, fmt.Errorf("无法计算点积: %w", err) |
|||
} |
|||
|
|||
return dotProduct, nil |
|||
} |
|||
|
|||
// dotProduct computes the dot product of two equal-length vectors. For
// normalized vectors it equals their cosine similarity; higher values mean
// more similar vectors.
func dotProduct(a, b []float32) (float32, error) {
	if len(a) != len(b) {
		return 0, errors.New("向量必须具有相同的长度")
	}

	var sum float32
	for i, ai := range a {
		sum += ai * b[i]
	}

	return sum, nil
}
|||
|
|||
// normalizeVector 归一化一个浮点数向量。
|
|||
// 归一化是指将向量的每个分量除以向量的模(长度),使得归一化后的向量长度为 1。
|
|||
func normalizeVector(v []float32) []float32 { |
|||
var norm float64 |
|||
for _, val := range v { |
|||
norm += float64(val * val) |
|||
} |
|||
if norm == 0 { |
|||
return v // 避免除以零的情况
|
|||
} |
|||
norm = math.Sqrt(norm) |
|||
|
|||
res := make([]float32, len(v)) |
|||
for i, val := range v { |
|||
res[i] = float32(float64(val) / norm) |
|||
} |
|||
|
|||
return res |
|||
} |
|||
|
|||
// isNormalized 检查向量是否已经归一化。
|
|||
// 如果向量的模接近 1,则认为它是归一化的。
|
|||
func isNormalized(v []float32) bool { |
|||
var sqSum float64 |
|||
for _, val := range v { |
|||
sqSum += float64(val) * float64(val) |
|||
} |
|||
magnitude := math.Sqrt(sqSum) |
|||
return math.Abs(magnitude-1) < isNormalizedPrecisionTolerance |
|||
} |
@ -0,0 +1,14 @@ |
|||
package model |
|||
|
|||
import "gorm.io/gorm" |
|||
|
|||
// VecDoc is one stored document chunk: its text content and source file,
// linked to its owning VecList through ListID.
type VecDoc struct {
	gorm.Model
	// Content is the chunk's raw text.
	Content string `json:"content"`
	// FilePath is the originating file on disk.
	FilePath string `json:"file_path" gorm:"not null"`
	// ListID references the owning VecList; no DB-level foreign key is
	// declared here, so cleanup happens in VecList.AfterDelete.
	ListID int `json:"list_id"`
}

// TableName fixes the table name to "vec_doc" instead of GORM's default
// pluralized name.
func (VecDoc) TableName() string {
	return "vec_doc"
}
@ -0,0 +1,65 @@ |
|||
package model |
|||
|
|||
import ( |
|||
"fmt" |
|||
"log" |
|||
|
|||
"gorm.io/gorm" |
|||
) |
|||
|
|||
// VecList is one registered knowledge base: the watched file path plus the
// embedding engine/model used to vectorize its documents.
type VecList struct {
	gorm.Model
	// FilePath is the monitored path; uniqueness is enforced in
	// BeforeCreate rather than by a database constraint.
	FilePath string `json:"file_path" gorm:"not null"`
	// Engine names the embedding provider.
	Engine string `json:"engine" gorm:"not null"`
	// EmbedSize is the embedding dimension, used to size the vec0 table.
	EmbedSize int `json:"embed_size"`
	// EmbeddingModel is the model identifier within the engine
	// (serialized as "model" in JSON).
	EmbeddingModel string `json:"model" gorm:"not null"`
}

// TableName fixes the table name to "vec_list".
func (*VecList) TableName() string {
	return "vec_list"
}
|||
|
|||
// BeforeCreate is a GORM hook that rejects inserts whose FilePath already
// exists, emulating a unique constraint in application code.
func (v *VecList) BeforeCreate(tx *gorm.DB) error {
	var count int64
	if err := tx.Model(&VecList{}).Where("file_path = ?", v.FilePath).Count(&count).Error; err != nil {
		return err
	}
	if count > 0 {
		return fmt.Errorf("file path already exists: %s", v.FilePath)
	}
	return nil
}
|||
|
|||
// AfterCreate is a GORM hook that provisions this list's per-record
// sqlite-vec virtual table, sized to its embedding dimension.
func (v *VecList) AfterCreate(tx *gorm.DB) error {
	return CreateVirtualTable(tx, v.ID, v.EmbedSize)
}
|||
|
|||
// AfterDelete is a GORM hook that cascades a VecList deletion in application
// code: it removes all VecDoc rows belonging to the list, then drops the
// list's virtual table.
func (v *VecList) AfterDelete(tx *gorm.DB) error {
	// Delete every VecDoc row that references this list.
	if err := tx.Where("list_id = ?", v.ID).Delete(&VecDoc{}).Error; err != nil {
		return err
	}
	return DropVirtualTable(tx, v.ID)
}
|||
|
|||
// CreateVirtualTable creates the sqlite-vec virtual table "[<id>_vec]" with
// a float embedding column of the given dimension using cosine distance.
// Both interpolated values are numeric (%d), so no SQL injection is possible.
func CreateVirtualTable(db *gorm.DB, vectorID uint, embeddingSize int) error {
	sql := fmt.Sprintf(`
	CREATE VIRTUAL TABLE IF NOT EXISTS [%d_vec] USING
	vec0(
	document_id TEXT PRIMARY KEY,
	embedding float[%d] distance_metric=cosine
	)
	`, vectorID, embeddingSize)
	// NOTE(review): logging every DDL statement looks like debug leftover.
	log.Printf("sql: %s", sql)
	return db.Exec(sql).Error
}
|||
|
|||
// DropVirtualTable drops the per-list virtual table "[<id>_vec]" if present.
func DropVirtualTable(db *gorm.DB, vectorID uint) error {
	sql := fmt.Sprintf(`DROP TABLE IF EXISTS [%d_vec]`, vectorID)
	return db.Exec(sql).Error
}
@ -1,14 +1,19 @@ |
|||
module vector |
|||
module godovec |
|||
|
|||
go 1.22.5 |
|||
go 1.23.3 |
|||
|
|||
require ( |
|||
github.com/asg017/sqlite-vec-go-bindings v0.0.1-alpha.37 |
|||
github.com/ncruces/go-sqlite3 v0.17.2-0.20240711235451-21de85e849b7 |
|||
github.com/asg017/sqlite-vec-go-bindings v0.1.6 |
|||
github.com/ncruces/go-sqlite3/gormlite v0.21.0 |
|||
gorm.io/gorm v1.25.12 |
|||
) |
|||
|
|||
require ( |
|||
github.com/jinzhu/inflection v1.0.0 // indirect |
|||
github.com/jinzhu/now v1.1.5 // indirect |
|||
github.com/ncruces/go-sqlite3 v0.21.0 // indirect |
|||
github.com/ncruces/julianday v1.0.0 // indirect |
|||
github.com/tetratelabs/wazero v1.7.3 // indirect |
|||
golang.org/x/sys v0.22.0 // indirect |
|||
github.com/tetratelabs/wazero v1.8.2 // indirect |
|||
golang.org/x/sys v0.28.0 // indirect |
|||
golang.org/x/text v0.21.0 // indirect |
|||
) |
|||
|
@ -1,12 +1,20 @@ |
|||
github.com/asg017/sqlite-vec-go-bindings v0.0.1-alpha.37 h1:Gz6YkDCs60k5VwbBPKDfAPPeIBcuaN3qriAozAaIIZI= |
|||
github.com/asg017/sqlite-vec-go-bindings v0.0.1-alpha.37/go.mod h1:A8+cTt/nKFsYCQF6OgzSNpKZrzNo5gQsXBTfsXHXY0Q= |
|||
github.com/ncruces/go-sqlite3 v0.17.2-0.20240711235451-21de85e849b7 h1:ssM02uUFDfz0V2TMg2du2BjbW9cpOhFJK0kpDN+X768= |
|||
github.com/ncruces/go-sqlite3 v0.17.2-0.20240711235451-21de85e849b7/go.mod h1:FnCyui8SlDoL0mQZ5dTouNo7s7jXS0kJv9lBt1GlM9w= |
|||
github.com/asg017/sqlite-vec-go-bindings v0.1.6 h1:Nx0jAzyS38XpkKznJ9xQjFXz2X9tI7KqjwVxV8RNoww= |
|||
github.com/asg017/sqlite-vec-go-bindings v0.1.6/go.mod h1:A8+cTt/nKFsYCQF6OgzSNpKZrzNo5gQsXBTfsXHXY0Q= |
|||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= |
|||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= |
|||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= |
|||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= |
|||
github.com/ncruces/go-sqlite3 v0.21.0 h1:EwKFoy1hHEopN4sFZarmi+McXdbCcbTuLixhEayXVbQ= |
|||
github.com/ncruces/go-sqlite3 v0.21.0/go.mod h1:zxMOaSG5kFYVFK4xQa0pdwIszqxqJ0W0BxBgwdrNjuA= |
|||
github.com/ncruces/go-sqlite3/gormlite v0.21.0 h1:9DsbvW9dS6uxXNFmbrNZixqAXKnIFnLM8oZmKqp8vcI= |
|||
github.com/ncruces/go-sqlite3/gormlite v0.21.0/go.mod h1:rP4JXD6jlpOSsg2Ed++kzJIAZZCIBirVYqIpwaLW88E= |
|||
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M= |
|||
github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g= |
|||
github.com/tetratelabs/wazero v1.7.3 h1:PBH5KVahrt3S2AHgEjKu4u+LlDbbk+nsGE3KLucy6Rw= |
|||
github.com/tetratelabs/wazero v1.7.3/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y= |
|||
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= |
|||
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= |
|||
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= |
|||
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= |
|||
github.com/tetratelabs/wazero v1.8.2 h1:yIgLR/b2bN31bjxwXHD8a3d+BogigR952csSDdLYEv4= |
|||
github.com/tetratelabs/wazero v1.8.2/go.mod h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs= |
|||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= |
|||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= |
|||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= |
|||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= |
|||
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= |
|||
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= |
|||
|
@ -1,5 +0,0 @@ |
|||
package main |
|||
|
|||
func Init() { |
|||
|
|||
} |
@ -1,100 +0,0 @@ |
|||
package main |
|||
|
|||
import ( |
|||
_ "embed" |
|||
"log" |
|||
|
|||
sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/ncruces" |
|||
"github.com/ncruces/go-sqlite3" |
|||
) |
|||
|
|||
var Db *sqlite3.Conn |
|||
|
|||
func main() { |
|||
db, err := sqlite3.Open(":memory:") |
|||
if err != nil { |
|||
log.Fatal(err) |
|||
} |
|||
|
|||
stmt, _, err := db.Prepare(`SELECT sqlite_version(), vec_version()`) |
|||
if err != nil { |
|||
log.Fatal(err) |
|||
} |
|||
|
|||
stmt.Step() |
|||
|
|||
log.Printf("sqlite_version=%s, vec_version=%s\n", stmt.ColumnText(0), stmt.ColumnText(1)) |
|||
stmt.Close() |
|||
|
|||
err = db.Exec("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[4])") |
|||
if err != nil { |
|||
log.Fatal(err) |
|||
} |
|||
items := map[int][]float32{ |
|||
1: {0.1, 0.1, 0.1, 0.1}, |
|||
2: {0.2, 0.2, 0.2, 0.2}, |
|||
3: {0.3, 0.3, 0.3, 0.3}, |
|||
4: {0.4, 0.4, 0.4, 0.4}, |
|||
5: {0.5, 0.5, 0.5, 0.5}, |
|||
} |
|||
q := []float32{0.3, 0.3, 0.3, 0.3} |
|||
|
|||
stmt, _, err = db.Prepare("INSERT INTO vec_items(rowid, embedding) VALUES (?, ?)") |
|||
if err != nil { |
|||
log.Fatal(err) |
|||
} |
|||
|
|||
for id, values := range items { |
|||
v, err := sqlite_vec.SerializeFloat32(values) |
|||
if err != nil { |
|||
log.Fatal(err) |
|||
} |
|||
stmt.BindInt(1, id) |
|||
stmt.BindBlob(2, v) |
|||
err = stmt.Exec() |
|||
if err != nil { |
|||
log.Fatal(err) |
|||
} |
|||
stmt.Reset() |
|||
} |
|||
stmt.Close() |
|||
|
|||
stmt, _, err = db.Prepare(` |
|||
SELECT |
|||
rowid, |
|||
distance |
|||
FROM vec_items |
|||
WHERE embedding MATCH ? |
|||
ORDER BY distance |
|||
LIMIT 3 |
|||
`) |
|||
|
|||
if err != nil { |
|||
log.Fatal(err) |
|||
} |
|||
|
|||
query, err := sqlite_vec.SerializeFloat32(q) |
|||
if err != nil { |
|||
log.Fatal(err) |
|||
} |
|||
stmt.BindBlob(1, query) |
|||
|
|||
for stmt.Step() { |
|||
rowid := stmt.ColumnInt64(0) |
|||
distance := stmt.ColumnFloat(1) |
|||
log.Printf("rowid=%d, distance=%f\n", rowid, distance) |
|||
} |
|||
if err := stmt.Err(); err != nil { |
|||
log.Fatal(err) |
|||
} |
|||
|
|||
err = stmt.Close() |
|||
if err != nil { |
|||
log.Fatal(err) |
|||
} |
|||
|
|||
err = db.Close() |
|||
if err != nil { |
|||
log.Fatal(err) |
|||
} |
|||
} |
@ -0,0 +1,54 @@ |
|||
package godovec |
|||
|
|||
import ( |
|||
"fmt" |
|||
|
|||
_ "embed" |
|||
|
|||
_ "github.com/asg017/sqlite-vec-go-bindings/ncruces" |
|||
|
|||
"github.com/ncruces/go-sqlite3/gormlite" |
|||
"gorm.io/gorm" |
|||
) |
|||
|
|||
var VecDb *gorm.DB |
|||
|
|||
// VectorList describes one knowledge-base registration: the watched path
// plus the embedding engine/model used for it (compare model.VecList, which
// is the gorm.Model-based variant of the same record).
type VectorList struct {
	ID       int    `json:"id" gorm:"primaryKey"`
	FilePath string `json:"file_path" gorm:"not null"`
	Engine   string `json:"engine" gorm:"not null"`
	// EmbeddingModel is serialized as "model" in JSON.
	EmbeddingModel string `json:"model" gorm:"not null"`
}

// VectorDoc is one stored document chunk, linked to its VectorList through
// ListID.
type VectorDoc struct {
	ID       int    `json:"id" gorm:"primaryKey"`
	Content  string `json:"content"`
	FilePath string `json:"file_path" gorm:"not null"`
	ListID   int    `json:"list_id"`
}
|||
|
|||
// NOTE(review): this file declares package godovec, so this main is not a
// program entry point; it reads as a leftover manual test driver. The error
// returned by InitVector is discarded here — confirm whether this function
// should be removed or should handle the error.
func main() {
	InitVector()
}
|||
func InitVector() error { |
|||
|
|||
db, err := gorm.Open(gormlite.Open("./data.db"), &gorm.Config{}) |
|||
if err != nil { |
|||
return fmt.Errorf("failed to open vector db: %w", err) |
|||
} |
|||
|
|||
// Enable PRAGMAs
|
|||
// - busy_timeout (ms) to prevent db lockups as we're accessing the DB from multiple separate processes in otto8
|
|||
tx := db.Exec(` |
|||
PRAGMA busy_timeout = 5000; |
|||
`) |
|||
if tx.Error != nil { |
|||
return fmt.Errorf("failed to execute pragma busy_timeout: %w", tx.Error) |
|||
} |
|||
err = db.AutoMigrate(&VectorList{}, &VectorDoc{}) |
|||
if err != nil { |
|||
return fmt.Errorf("failed to auto migrate tables: %w", err) |
|||
} |
|||
VecDb = db |
|||
return nil |
|||
} |
Loading…
Reference in new issue