You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

409 lines
12 KiB

package vector
import (
"context"
"errors"
"fmt"
"path/filepath"
"slices"
"sync"
)
// Collection 表示一个文档集合。
// 它还包含一个配置好的嵌入函数,当添加没有嵌入的文档时会使用该函数。
type Collection struct {
Name string
metadata map[string]string
documents map[string]*Document
documentsLock sync.RWMutex
embed EmbeddingFunc
persistDirectory string
compress bool
// ⚠️ 当添加字段时,请考虑在 [DB.Export] 和 [DB.Import] 的持久化结构中添加相应的字段
}
// 我们不导出这个函数,以保持 API 表面最小。
// 用户通过 [Client.CreateCollection] 创建集合。
func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, dbDir string, compress bool) (*Collection, error) {
// 复制元数据以避免在创建集合后调用者修改元数据时发生数据竞争。
m := make(map[string]string, len(metadata))
for k, v := range metadata {
m[k] = v
}
c := &Collection{
Name: name,
metadata: m,
documents: make(map[string]*Document),
embed: embed,
}
// 持久化
if dbDir != "" {
safeName := hash2hex(name)
c.persistDirectory = filepath.Join(dbDir, safeName)
c.compress = compress
// 持久化名称和元数据
metadataPath := filepath.Join(c.persistDirectory, metadataFileName)
metadataPath += ".gob"
if c.compress {
metadataPath += ".gz"
}
pc := struct {
Name string
Metadata map[string]string
}{
Name: name,
Metadata: m,
}
err := persistToFile(metadataPath, pc, compress, "")
if err != nil {
return nil, fmt.Errorf("无法持久化集合元数据: %w", err)
}
}
return c, nil
}
// 添加嵌入到数据存储中。
//
// - ids: 要添加的嵌入的 ID
// - embeddings: 要添加的嵌入。如果为 nil,则基于内容使用集合的嵌入函数计算嵌入。可选。
// - metadatas: 与嵌入关联的元数据。查询时可以过滤这些元数据。可选。
// - contents: 与嵌入关联的内容。
//
// 这是一个类似于 Chroma 的方法。对于更符合 Go 风格的方法,请参见 [AddDocuments]。
func (c *Collection) Add(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string) error {
return c.AddConcurrently(ctx, ids, embeddings, metadatas, contents, 1)
}
// AddConcurrently 类似于 Add,但并发地添加嵌入。
// 这在没有传递任何嵌入时特别有用,因为需要创建嵌入。
// 出现错误时,取消所有并发操作并返回错误。
//
// 这是一个类似于 Chroma 的方法。对于更符合 Go 风格的方法,请参见 [AddDocuments]。
func (c *Collection) AddConcurrently(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string, concurrency int) error {
if len(ids) == 0 {
return errors.New("ids 为空")
}
if len(embeddings) == 0 && len(contents) == 0 {
return errors.New("必须填写 embeddings 或 contents")
}
if len(embeddings) != 0 {
if len(embeddings) != len(ids) {
return errors.New("ids 和 embeddings 的长度必须相同")
}
} else {
// 分配空切片以便稍后通过索引访问
embeddings = make([][]float32, len(ids))
}
if len(metadatas) != 0 {
if len(ids) != len(metadatas) {
return errors.New("当 metadatas 不为空时,其长度必须与 ids 相同")
}
} else {
// 分配空切片以便稍后通过索引访问
metadatas = make([]map[string]string, len(ids))
}
if len(contents) != 0 {
if len(contents) != len(ids) {
return errors.New("ids 和 contents 的长度必须相同")
}
} else {
// 分配空切片以便稍后通过索引访问
contents = make([]string, len(ids))
}
if concurrency < 1 {
return errors.New("并发数必须至少为 1")
}
// 将 Chroma 风格的参数转换为文档切片
docs := make([]Document, 0, len(ids))
for i, id := range ids {
docs = append(docs, Document{
ID: id,
Metadata: metadatas[i],
Embedding: embeddings[i],
Content: contents[i],
})
}
return c.AddDocuments(ctx, docs, concurrency)
}
// AddDocuments 使用指定的并发数将文档添加到集合中。
// 如果文档没有嵌入,则使用集合的嵌入函数创建嵌入。
// 出现错误时,取消所有并发操作并返回错误。
func (c *Collection) AddDocuments(ctx context.Context, documents []Document, concurrency int) error {
if len(documents) == 0 {
// TODO: 这是否应为无操作(no-op)?
return errors.New("documents 切片为空")
}
if concurrency < 1 {
return errors.New("并发数必须至少为 1")
}
// 对于其他验证,我们依赖于 AddDocument。
var sharedErr error
sharedErrLock := sync.Mutex{}
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
setSharedErr := func(err error) {
sharedErrLock.Lock()
defer sharedErrLock.Unlock()
// 另一个 goroutine 可能已经设置了错误。
if sharedErr == nil {
sharedErr = err
// 取消所有其他 goroutine 的操作。
cancel(sharedErr)
}
}
var wg sync.WaitGroup
semaphore := make(chan struct{}, concurrency)
for _, doc := range documents {
wg.Add(1)
go func(doc Document) {
defer wg.Done()
// 如果另一个 goroutine 已经失败,则不开始。
if ctx.Err() != nil {
return
}
// 等待直到 $concurrency 个其他 goroutine 正在创建文档。
semaphore <- struct{}{}
defer func() { <-semaphore }()
err := c.AddDocument(ctx, doc)
if err != nil {
setSharedErr(fmt.Errorf("无法添加文档 '%s': %w", doc.ID, err))
return
}
}(doc)
}
wg.Wait()
return sharedErr
}
// AddDocument 将文档添加到集合中。
// 如果文档没有嵌入,则使用集合的嵌入函数创建嵌入。
func (c *Collection) AddDocument(ctx context.Context, doc Document) error {
if doc.ID == "" {
return errors.New("文档 ID 为空")
}
if len(doc.Embedding) == 0 && doc.Content == "" {
return errors.New("必须填写文档的 embedding 或 content")
}
// 复制元数据以避免在创建文档后调用者修改元数据时发生数据竞争。
m := make(map[string]string, len(doc.Metadata))
for k, v := range doc.Metadata {
m[k] = v
}
// 如果嵌入不存在,则创建嵌入,否则如果需要则规范化
if len(doc.Embedding) == 0 {
embedding, err := c.embed(ctx, doc.Content)
if err != nil {
return fmt.Errorf("无法创建文档的嵌入: %w", err)
}
doc.Embedding = embedding
} else {
if !isNormalized(doc.Embedding) {
doc.Embedding = normalizeVector(doc.Embedding)
}
}
c.documentsLock.Lock()
// 我们不使用 defer 解锁,因为我们希望尽早解锁。
c.documents[doc.ID] = &doc
c.documentsLock.Unlock()
// 持久化文档
if c.persistDirectory != "" {
docPath := c.getDocPath(doc.ID)
err := persistToFile(docPath, doc, c.compress, "")
if err != nil {
return fmt.Errorf("无法将文档持久化到 %q: %w", docPath, err)
}
}
return nil
}
// Delete 从集合中删除文档。
//
// - where: 元数据的条件过滤。可选。
// - whereDocument: 文档的条件过滤。可选。
// - ids: 要删除的文档的 ID。如果为空,则删除所有文档。
func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]string, ids ...string) error {
// 必须至少有一个 where、whereDocument 或 ids
if len(where) == 0 && len(whereDocument) == 0 && len(ids) == 0 {
return fmt.Errorf("必须至少有一个 where、whereDocument 或 ids")
}
if len(c.documents) == 0 {
return nil
}
for k := range whereDocument {
if !slices.Contains(supportedFilters, k) {
return errors.New("不支持的 whereDocument 操作符")
}
}
var docIDs []string
c.documentsLock.Lock()
defer c.documentsLock.Unlock()
if where != nil || whereDocument != nil {
// 元数据 + 内容过滤
filteredDocs := filterDocs(c.documents, where, whereDocument)
for _, doc := range filteredDocs {
docIDs = append(docIDs, doc.ID)
}
} else {
docIDs = ids
}
// 如果没有剩余的文档,则不执行操作
if len(docIDs) == 0 {
return nil
}
for _, docID := range docIDs {
delete(c.documents, docID)
// 从磁盘删除文档
if c.persistDirectory != "" {
docPath := c.getDocPath(docID)
err := removeFile(docPath)
if err != nil {
return fmt.Errorf("无法删除文档 %q: %w", docPath, err)
}
}
}
return nil
}
// Count 返回集合中的文档数量。
func (c *Collection) Count() int {
c.documentsLock.RLock()
defer c.documentsLock.RUnlock()
return len(c.documents)
}
// Result 表示查询结果中的单个结果。
type Result struct {
ID string
Metadata map[string]string
Embedding []float32
Content string
// 查询与文档之间的余弦相似度。
// 值越高,文档与查询越相似。
// 值的范围是 [-1, 1]。
Similarity float32
}
// 在集合上执行详尽的最近邻搜索。
//
// - queryText: 要搜索的文本。其嵌入将使用集合的嵌入函数创建。
// - nResults: 要返回的结果数量。必须大于 0。
// - where: 元数据的条件过滤。可选。
// - whereDocument: 文档的条件过滤。可选。
func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where, whereDocument map[string]string) ([]Result, error) {
if queryText == "" {
return nil, errors.New("queryText 为空")
}
queryVectors, err := c.embed(ctx, queryText)
if err != nil {
return nil, fmt.Errorf("无法创建查询的嵌入: %w", err)
}
return c.QueryEmbedding(ctx, queryVectors, nResults, where, whereDocument)
}
// 在集合上执行详尽的最近邻搜索。
//
// - queryEmbedding: 要搜索的查询的嵌入。必须使用与集合中文档嵌入相同的嵌入模型创建。
// - nResults: 要返回的结果数量。必须大于 0。
// - where: 元数据的条件过滤。可选。
// - whereDocument: 文档的条件过滤。可选。
func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float32, nResults int, where, whereDocument map[string]string) ([]Result, error) {
if len(queryEmbedding) == 0 {
return nil, errors.New("queryEmbedding 为空")
}
if nResults <= 0 {
return nil, errors.New("nResults 必须大于 0")
}
c.documentsLock.RLock()
defer c.documentsLock.RUnlock()
// if nResults > len(c.documents) {
// return nil, errors.New("nResults 必须小于或等于集合中的文档数量")
// }
if len(c.documents) == 0 {
return nil, nil
}
// 验证 whereDocument 操作符
for k := range whereDocument {
if !slices.Contains(supportedFilters, k) {
return nil, errors.New("不支持的操作符")
}
}
// 根据元数据和内容过滤文档
filteredDocs := filterDocs(c.documents, where, whereDocument)
// 如果过滤器删除了所有文档,则不继续
if len(filteredDocs) == 0 {
return nil, nil
}
// 对于剩余的文档,获取最相似的文档。
nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, filteredDocs, nResults)
if err != nil {
return nil, fmt.Errorf("无法获取最相似的文档: %w", err)
}
length := len(nMaxDocs)
if length > nResults {
length = nResults
}
res := make([]Result, 0, length)
for i := 0; i < length; i++ {
doc := c.documents[nMaxDocs[i].docID]
res = append(res, Result{
ID: nMaxDocs[i].docID,
Metadata: doc.Metadata,
Embedding: doc.Embedding,
Content: doc.Content,
Similarity: nMaxDocs[i].similarity,
})
}
// 返回前 nResults 个结果
return res, nil
}
// getDocPath 生成文档文件的路径。
func (c *Collection) getDocPath(docID string) string {
safeID := hash2hex(docID)
docPath := filepath.Join(c.persistDirectory, safeID)
docPath += ".gob"
if c.compress {
docPath += ".gz"
}
return docPath
}