mirror of https://gitee.com/godoos/godoos.git
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
409 lines
12 KiB
409 lines
12 KiB
package vector
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"path/filepath"
|
|
"slices"
|
|
"sync"
|
|
)
|
|
|
|
// Collection 表示一个文档集合。
|
|
// 它还包含一个配置好的嵌入函数,当添加没有嵌入的文档时会使用该函数。
|
|
type Collection struct {
|
|
Name string
|
|
|
|
metadata map[string]string
|
|
documents map[string]*Document
|
|
documentsLock sync.RWMutex
|
|
embed EmbeddingFunc
|
|
|
|
persistDirectory string
|
|
compress bool
|
|
|
|
// ⚠️ 当添加字段时,请考虑在 [DB.Export] 和 [DB.Import] 的持久化结构中添加相应的字段
|
|
}
|
|
|
|
// 我们不导出这个函数,以保持 API 表面最小。
|
|
// 用户通过 [Client.CreateCollection] 创建集合。
|
|
func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, dbDir string, compress bool) (*Collection, error) {
|
|
// 复制元数据以避免在创建集合后调用者修改元数据时发生数据竞争。
|
|
m := make(map[string]string, len(metadata))
|
|
for k, v := range metadata {
|
|
m[k] = v
|
|
}
|
|
|
|
c := &Collection{
|
|
Name: name,
|
|
|
|
metadata: m,
|
|
documents: make(map[string]*Document),
|
|
embed: embed,
|
|
}
|
|
|
|
// 持久化
|
|
if dbDir != "" {
|
|
safeName := hash2hex(name)
|
|
c.persistDirectory = filepath.Join(dbDir, safeName)
|
|
c.compress = compress
|
|
// 持久化名称和元数据
|
|
metadataPath := filepath.Join(c.persistDirectory, metadataFileName)
|
|
metadataPath += ".gob"
|
|
if c.compress {
|
|
metadataPath += ".gz"
|
|
}
|
|
pc := struct {
|
|
Name string
|
|
Metadata map[string]string
|
|
}{
|
|
Name: name,
|
|
Metadata: m,
|
|
}
|
|
err := persistToFile(metadataPath, pc, compress, "")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("无法持久化集合元数据: %w", err)
|
|
}
|
|
}
|
|
|
|
return c, nil
|
|
}
|
|
|
|
// 添加嵌入到数据存储中。
|
|
//
|
|
// - ids: 要添加的嵌入的 ID
|
|
// - embeddings: 要添加的嵌入。如果为 nil,则基于内容使用集合的嵌入函数计算嵌入。可选。
|
|
// - metadatas: 与嵌入关联的元数据。查询时可以过滤这些元数据。可选。
|
|
// - contents: 与嵌入关联的内容。
|
|
//
|
|
// 这是一个类似于 Chroma 的方法。对于更符合 Go 风格的方法,请参见 [AddDocuments]。
|
|
func (c *Collection) Add(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string) error {
|
|
return c.AddConcurrently(ctx, ids, embeddings, metadatas, contents, 1)
|
|
}
|
|
|
|
// AddConcurrently 类似于 Add,但并发地添加嵌入。
|
|
// 这在没有传递任何嵌入时特别有用,因为需要创建嵌入。
|
|
// 出现错误时,取消所有并发操作并返回错误。
|
|
//
|
|
// 这是一个类似于 Chroma 的方法。对于更符合 Go 风格的方法,请参见 [AddDocuments]。
|
|
func (c *Collection) AddConcurrently(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string, concurrency int) error {
|
|
if len(ids) == 0 {
|
|
return errors.New("ids 为空")
|
|
}
|
|
if len(embeddings) == 0 && len(contents) == 0 {
|
|
return errors.New("必须填写 embeddings 或 contents")
|
|
}
|
|
if len(embeddings) != 0 {
|
|
if len(embeddings) != len(ids) {
|
|
return errors.New("ids 和 embeddings 的长度必须相同")
|
|
}
|
|
} else {
|
|
// 分配空切片以便稍后通过索引访问
|
|
embeddings = make([][]float32, len(ids))
|
|
}
|
|
if len(metadatas) != 0 {
|
|
if len(ids) != len(metadatas) {
|
|
return errors.New("当 metadatas 不为空时,其长度必须与 ids 相同")
|
|
}
|
|
} else {
|
|
// 分配空切片以便稍后通过索引访问
|
|
metadatas = make([]map[string]string, len(ids))
|
|
}
|
|
if len(contents) != 0 {
|
|
if len(contents) != len(ids) {
|
|
return errors.New("ids 和 contents 的长度必须相同")
|
|
}
|
|
} else {
|
|
// 分配空切片以便稍后通过索引访问
|
|
contents = make([]string, len(ids))
|
|
}
|
|
if concurrency < 1 {
|
|
return errors.New("并发数必须至少为 1")
|
|
}
|
|
|
|
// 将 Chroma 风格的参数转换为文档切片
|
|
docs := make([]Document, 0, len(ids))
|
|
for i, id := range ids {
|
|
docs = append(docs, Document{
|
|
ID: id,
|
|
Metadata: metadatas[i],
|
|
Embedding: embeddings[i],
|
|
Content: contents[i],
|
|
})
|
|
}
|
|
|
|
return c.AddDocuments(ctx, docs, concurrency)
|
|
}
|
|
|
|
// AddDocuments 使用指定的并发数将文档添加到集合中。
|
|
// 如果文档没有嵌入,则使用集合的嵌入函数创建嵌入。
|
|
// 出现错误时,取消所有并发操作并返回错误。
|
|
func (c *Collection) AddDocuments(ctx context.Context, documents []Document, concurrency int) error {
|
|
if len(documents) == 0 {
|
|
// TODO: 这是否应为无操作(no-op)?
|
|
return errors.New("documents 切片为空")
|
|
}
|
|
if concurrency < 1 {
|
|
return errors.New("并发数必须至少为 1")
|
|
}
|
|
// 对于其他验证,我们依赖于 AddDocument。
|
|
|
|
var sharedErr error
|
|
sharedErrLock := sync.Mutex{}
|
|
ctx, cancel := context.WithCancelCause(ctx)
|
|
defer cancel(nil)
|
|
setSharedErr := func(err error) {
|
|
sharedErrLock.Lock()
|
|
defer sharedErrLock.Unlock()
|
|
// 另一个 goroutine 可能已经设置了错误。
|
|
if sharedErr == nil {
|
|
sharedErr = err
|
|
// 取消所有其他 goroutine 的操作。
|
|
cancel(sharedErr)
|
|
}
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
semaphore := make(chan struct{}, concurrency)
|
|
for _, doc := range documents {
|
|
wg.Add(1)
|
|
go func(doc Document) {
|
|
defer wg.Done()
|
|
|
|
// 如果另一个 goroutine 已经失败,则不开始。
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
|
|
// 等待直到 $concurrency 个其他 goroutine 正在创建文档。
|
|
semaphore <- struct{}{}
|
|
defer func() { <-semaphore }()
|
|
|
|
err := c.AddDocument(ctx, doc)
|
|
if err != nil {
|
|
setSharedErr(fmt.Errorf("无法添加文档 '%s': %w", doc.ID, err))
|
|
return
|
|
}
|
|
}(doc)
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
return sharedErr
|
|
}
|
|
|
|
// AddDocument 将文档添加到集合中。
|
|
// 如果文档没有嵌入,则使用集合的嵌入函数创建嵌入。
|
|
func (c *Collection) AddDocument(ctx context.Context, doc Document) error {
|
|
if doc.ID == "" {
|
|
return errors.New("文档 ID 为空")
|
|
}
|
|
if len(doc.Embedding) == 0 && doc.Content == "" {
|
|
return errors.New("必须填写文档的 embedding 或 content")
|
|
}
|
|
|
|
// 复制元数据以避免在创建文档后调用者修改元数据时发生数据竞争。
|
|
m := make(map[string]string, len(doc.Metadata))
|
|
for k, v := range doc.Metadata {
|
|
m[k] = v
|
|
}
|
|
|
|
// 如果嵌入不存在,则创建嵌入,否则如果需要则规范化
|
|
if len(doc.Embedding) == 0 {
|
|
embedding, err := c.embed(ctx, doc.Content)
|
|
if err != nil {
|
|
return fmt.Errorf("无法创建文档的嵌入: %w", err)
|
|
}
|
|
doc.Embedding = embedding
|
|
} else {
|
|
if !isNormalized(doc.Embedding) {
|
|
doc.Embedding = normalizeVector(doc.Embedding)
|
|
}
|
|
}
|
|
|
|
c.documentsLock.Lock()
|
|
// 我们不使用 defer 解锁,因为我们希望尽早解锁。
|
|
c.documents[doc.ID] = &doc
|
|
c.documentsLock.Unlock()
|
|
|
|
// 持久化文档
|
|
if c.persistDirectory != "" {
|
|
docPath := c.getDocPath(doc.ID)
|
|
err := persistToFile(docPath, doc, c.compress, "")
|
|
if err != nil {
|
|
return fmt.Errorf("无法将文档持久化到 %q: %w", docPath, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Delete 从集合中删除文档。
|
|
//
|
|
// - where: 元数据的条件过滤。可选。
|
|
// - whereDocument: 文档的条件过滤。可选。
|
|
// - ids: 要删除的文档的 ID。如果为空,则删除所有文档。
|
|
func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]string, ids ...string) error {
|
|
// 必须至少有一个 where、whereDocument 或 ids
|
|
if len(where) == 0 && len(whereDocument) == 0 && len(ids) == 0 {
|
|
return fmt.Errorf("必须至少有一个 where、whereDocument 或 ids")
|
|
}
|
|
|
|
if len(c.documents) == 0 {
|
|
return nil
|
|
}
|
|
|
|
for k := range whereDocument {
|
|
if !slices.Contains(supportedFilters, k) {
|
|
return errors.New("不支持的 whereDocument 操作符")
|
|
}
|
|
}
|
|
|
|
var docIDs []string
|
|
|
|
c.documentsLock.Lock()
|
|
defer c.documentsLock.Unlock()
|
|
|
|
if where != nil || whereDocument != nil {
|
|
// 元数据 + 内容过滤
|
|
filteredDocs := filterDocs(c.documents, where, whereDocument)
|
|
for _, doc := range filteredDocs {
|
|
docIDs = append(docIDs, doc.ID)
|
|
}
|
|
} else {
|
|
docIDs = ids
|
|
}
|
|
|
|
// 如果没有剩余的文档,则不执行操作
|
|
if len(docIDs) == 0 {
|
|
return nil
|
|
}
|
|
|
|
for _, docID := range docIDs {
|
|
delete(c.documents, docID)
|
|
|
|
// 从磁盘删除文档
|
|
if c.persistDirectory != "" {
|
|
docPath := c.getDocPath(docID)
|
|
err := removeFile(docPath)
|
|
if err != nil {
|
|
return fmt.Errorf("无法删除文档 %q: %w", docPath, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Count 返回集合中的文档数量。
|
|
func (c *Collection) Count() int {
|
|
c.documentsLock.RLock()
|
|
defer c.documentsLock.RUnlock()
|
|
return len(c.documents)
|
|
}
|
|
|
|
// Result 表示查询结果中的单个结果。
|
|
type Result struct {
|
|
ID string
|
|
Metadata map[string]string
|
|
Embedding []float32
|
|
Content string
|
|
|
|
// 查询与文档之间的余弦相似度。
|
|
// 值越高,文档与查询越相似。
|
|
// 值的范围是 [-1, 1]。
|
|
Similarity float32
|
|
}
|
|
|
|
// 在集合上执行详尽的最近邻搜索。
|
|
//
|
|
// - queryText: 要搜索的文本。其嵌入将使用集合的嵌入函数创建。
|
|
// - nResults: 要返回的结果数量。必须大于 0。
|
|
// - where: 元数据的条件过滤。可选。
|
|
// - whereDocument: 文档的条件过滤。可选。
|
|
func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where, whereDocument map[string]string) ([]Result, error) {
|
|
if queryText == "" {
|
|
return nil, errors.New("queryText 为空")
|
|
}
|
|
|
|
queryVectors, err := c.embed(ctx, queryText)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("无法创建查询的嵌入: %w", err)
|
|
}
|
|
|
|
return c.QueryEmbedding(ctx, queryVectors, nResults, where, whereDocument)
|
|
}
|
|
|
|
// 在集合上执行详尽的最近邻搜索。
|
|
//
|
|
// - queryEmbedding: 要搜索的查询的嵌入。必须使用与集合中文档嵌入相同的嵌入模型创建。
|
|
// - nResults: 要返回的结果数量。必须大于 0。
|
|
// - where: 元数据的条件过滤。可选。
|
|
// - whereDocument: 文档的条件过滤。可选。
|
|
func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float32, nResults int, where, whereDocument map[string]string) ([]Result, error) {
|
|
if len(queryEmbedding) == 0 {
|
|
return nil, errors.New("queryEmbedding 为空")
|
|
}
|
|
if nResults <= 0 {
|
|
return nil, errors.New("nResults 必须大于 0")
|
|
}
|
|
c.documentsLock.RLock()
|
|
defer c.documentsLock.RUnlock()
|
|
// if nResults > len(c.documents) {
|
|
// return nil, errors.New("nResults 必须小于或等于集合中的文档数量")
|
|
// }
|
|
|
|
if len(c.documents) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
// 验证 whereDocument 操作符
|
|
for k := range whereDocument {
|
|
if !slices.Contains(supportedFilters, k) {
|
|
return nil, errors.New("不支持的操作符")
|
|
}
|
|
}
|
|
|
|
// 根据元数据和内容过滤文档
|
|
filteredDocs := filterDocs(c.documents, where, whereDocument)
|
|
|
|
// 如果过滤器删除了所有文档,则不继续
|
|
if len(filteredDocs) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
// 对于剩余的文档,获取最相似的文档。
|
|
nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, filteredDocs, nResults)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("无法获取最相似的文档: %w", err)
|
|
}
|
|
length := len(nMaxDocs)
|
|
if length > nResults {
|
|
length = nResults
|
|
}
|
|
res := make([]Result, 0, length)
|
|
for i := 0; i < length; i++ {
|
|
doc := c.documents[nMaxDocs[i].docID]
|
|
res = append(res, Result{
|
|
ID: nMaxDocs[i].docID,
|
|
Metadata: doc.Metadata,
|
|
Embedding: doc.Embedding,
|
|
Content: doc.Content,
|
|
Similarity: nMaxDocs[i].similarity,
|
|
})
|
|
}
|
|
|
|
// 返回前 nResults 个结果
|
|
return res, nil
|
|
}
|
|
|
|
// getDocPath 生成文档文件的路径。
|
|
func (c *Collection) getDocPath(docID string) string {
|
|
safeID := hash2hex(docID)
|
|
docPath := filepath.Join(c.persistDirectory, safeID)
|
|
docPath += ".gob"
|
|
if c.compress {
|
|
docPath += ".gz"
|
|
}
|
|
return docPath
|
|
}
|
|
|