You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

412 lines
10 KiB

package vector
import (
"context"
"errors"
"fmt"
"godo/libs"
"io"
"io/fs"
"os"
"path/filepath"
"strings"
"sync"
)
// EmbeddingFunc 是一个为给定文本创建嵌入的函数。
// 默认使用 OpenAI 的 "text-embedding-3-small" 模型。
// 该函数必须返回一个已归一化的向量。
type EmbeddingFunc func(ctx context.Context, text string) ([]float32, error)
// DB 包含多个集合,每个集合包含多个文档。
type DB struct {
collections map[string]*Collection
collectionsLock sync.RWMutex
persistDirectory string
compress bool
}
// NewDB 创建一个新的内存中的数据库。
func NewDB() *DB {
return &DB{
collections: make(map[string]*Collection),
}
}
// NewPersistentDB 创建一个新的持久化的数据库。
// 如果路径为空,默认为 "./godoos/data/godoDB"。
// 如果 compress 为 true,则文件将使用 gzip 压缩。
func NewPersistentDB(path string, compress bool) (*DB, error) {
homeDir, err := libs.GetAppDir()
if err != nil {
return nil, err
}
if path == "" {
path = filepath.Join(homeDir, "data", "godoDB")
} else {
path = filepath.Clean(path)
}
ext := ".gob"
if compress {
ext += ".gz"
}
db := &DB{
collections: make(map[string]*Collection),
persistDirectory: path,
compress: compress,
}
fi, err := os.Stat(path)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
err := os.MkdirAll(path, 0o700)
if err != nil {
return nil, fmt.Errorf("无法创建持久化目录: %w", err)
}
return db, nil
}
return nil, fmt.Errorf("无法获取持久化目录信息: %w", err)
} else if !fi.IsDir() {
return nil, fmt.Errorf("路径不是目录: %s", path)
}
dirEntries, err := os.ReadDir(path)
if err != nil {
return nil, fmt.Errorf("无法读取持久化目录: %w", err)
}
for _, dirEntry := range dirEntries {
if !dirEntry.IsDir() {
continue
}
collectionPath := filepath.Join(path, dirEntry.Name())
collectionDirEntries, err := os.ReadDir(collectionPath)
if err != nil {
return nil, fmt.Errorf("无法读取集合目录: %w", err)
}
c := &Collection{
documents: make(map[string]*Document),
persistDirectory: collectionPath,
compress: compress,
}
for _, collectionDirEntry := range collectionDirEntries {
if collectionDirEntry.IsDir() {
continue
}
fPath := filepath.Join(collectionPath, collectionDirEntry.Name())
if collectionDirEntry.Name() == metadataFileName+ext {
pc := struct {
Name string
Metadata map[string]string
}{}
err := readFromFile(fPath, &pc, "")
if err != nil {
return nil, fmt.Errorf("无法读取集合元数据: %w", err)
}
c.Name = pc.Name
c.metadata = pc.Metadata
} else if strings.HasSuffix(collectionDirEntry.Name(), ext) {
d := &Document{}
err := readFromFile(fPath, d, "")
if err != nil {
return nil, fmt.Errorf("无法读取文档: %w", err)
}
c.documents[d.ID] = d
}
}
if c.Name == "" && len(c.documents) == 0 {
continue
}
if c.Name == "" {
return nil, fmt.Errorf("未找到集合元数据文件: %s", collectionPath)
}
db.collections[c.Name] = c
}
return db, nil
}
// ImportFromFile 从给定路径的文件导入数据库。
func (db *DB) ImportFromFile(filePath string, encryptionKey string) error {
if filePath == "" {
return fmt.Errorf("文件路径为空")
}
if encryptionKey != "" && len(encryptionKey) != 32 {
return errors.New("加密密钥必须为 32 字节长")
}
fi, err := os.Stat(filePath)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("文件不存在: %s", filePath)
}
return fmt.Errorf("无法获取文件信息: %w", err)
} else if fi.IsDir() {
return fmt.Errorf("路径是目录: %s", filePath)
}
type persistenceCollection struct {
Name string
Metadata map[string]string
Documents map[string]*Document
}
persistenceDB := struct {
Collections map[string]*persistenceCollection
}{
Collections: make(map[string]*persistenceCollection, len(db.collections)),
}
db.collectionsLock.Lock()
defer db.collectionsLock.Unlock()
err = readFromFile(filePath, &persistenceDB, encryptionKey)
if err != nil {
return fmt.Errorf("无法读取文件: %w", err)
}
for _, pc := range persistenceDB.Collections {
c := &Collection{
Name: pc.Name,
metadata: pc.Metadata,
documents: pc.Documents,
}
if db.persistDirectory != "" {
c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name))
c.compress = db.compress
}
db.collections[c.Name] = c
}
return nil
}
// ImportFromReader 从 reader 导入数据库。
func (db *DB) ImportFromReader(reader io.ReadSeeker, encryptionKey string) error {
if encryptionKey != "" && len(encryptionKey) != 32 {
return errors.New("加密密钥必须为 32 字节长")
}
type persistenceCollection struct {
Name string
Metadata map[string]string
Documents map[string]*Document
}
persistenceDB := struct {
Collections map[string]*persistenceCollection
}{
Collections: make(map[string]*persistenceCollection, len(db.collections)),
}
db.collectionsLock.Lock()
defer db.collectionsLock.Unlock()
err := readFromReader(reader, &persistenceDB, encryptionKey)
if err != nil {
return fmt.Errorf("无法读取流: %w", err)
}
for _, pc := range persistenceDB.Collections {
c := &Collection{
Name: pc.Name,
metadata: pc.Metadata,
documents: pc.Documents,
}
if db.persistDirectory != "" {
c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name))
c.compress = db.compress
}
db.collections[c.Name] = c
}
return nil
}
// ExportToFile 将数据库导出到给定路径的文件。
func (db *DB) ExportToFile(filePath string, compress bool, encryptionKey string) error {
if filePath == "" {
filePath = "./gododb.gob"
if compress {
filePath += ".gz"
}
if encryptionKey != "" {
filePath += ".enc"
}
}
if encryptionKey != "" && len(encryptionKey) != 32 {
return errors.New("加密密钥必须为 32 字节长")
}
type persistenceCollection struct {
Name string
Metadata map[string]string
Documents map[string]*Document
}
persistenceDB := struct {
Collections map[string]*persistenceCollection
}{
Collections: make(map[string]*persistenceCollection, len(db.collections)),
}
db.collectionsLock.RLock()
defer db.collectionsLock.RUnlock()
for k, v := range db.collections {
persistenceDB.Collections[k] = &persistenceCollection{
Name: v.Name,
Metadata: v.metadata,
Documents: v.documents,
}
}
err := persistToFile(filePath, persistenceDB, compress, encryptionKey)
if err != nil {
return fmt.Errorf("无法导出数据库: %w", err)
}
return nil
}
// ExportToWriter 将数据库导出到 writer。
func (db *DB) ExportToWriter(writer io.Writer, compress bool, encryptionKey string) error {
if encryptionKey != "" && len(encryptionKey) != 32 {
return errors.New("加密密钥必须为 32 字节长")
}
type persistenceCollection struct {
Name string
Metadata map[string]string
Documents map[string]*Document
}
persistenceDB := struct {
Collections map[string]*persistenceCollection
}{
Collections: make(map[string]*persistenceCollection, len(db.collections)),
}
db.collectionsLock.RLock()
defer db.collectionsLock.RUnlock()
for k, v := range db.collections {
persistenceDB.Collections[k] = &persistenceCollection{
Name: v.Name,
Metadata: v.metadata,
Documents: v.documents,
}
}
err := persistToWriter(writer, persistenceDB, compress, encryptionKey)
if err != nil {
return fmt.Errorf("无法导出数据库: %w", err)
}
return nil
}
// CreateCollection 创建具有给定名称和元数据的新集合。
func (db *DB) CreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) {
if name == "" {
return nil, errors.New("集合名称为空")
}
if embeddingFunc == nil {
embeddingFunc = NewEmbeddingFuncDefault()
}
collection, err := newCollection(name, metadata, embeddingFunc, db.persistDirectory, db.compress)
if err != nil {
return nil, fmt.Errorf("无法创建集合: %w", err)
}
db.collectionsLock.Lock()
defer db.collectionsLock.Unlock()
db.collections[name] = collection
return collection, nil
}
// ListCollections 返回数据库中的所有集合。
func (db *DB) ListCollections() map[string]*Collection {
db.collectionsLock.RLock()
defer db.collectionsLock.RUnlock()
res := make(map[string]*Collection, len(db.collections))
for k, v := range db.collections {
res[k] = v
}
return res
}
// GetCollection 返回具有给定名称的集合。
func (db *DB) GetCollection(name string, embeddingFunc EmbeddingFunc) *Collection {
db.collectionsLock.RLock()
defer db.collectionsLock.RUnlock()
c, ok := db.collections[name]
if !ok {
return nil
}
if c.embed == nil {
if embeddingFunc == nil {
c.embed = NewEmbeddingFuncDefault()
} else {
c.embed = embeddingFunc
}
}
return c
}
// GetOrCreateCollection 返回数据库中已有的集合,或创建一个新的集合。
func (db *DB) GetOrCreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) {
collection := db.GetCollection(name, embeddingFunc)
if collection == nil {
var err error
collection, err = db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
return nil, fmt.Errorf("无法创建集合: %w", err)
}
}
return collection, nil
}
// DeleteCollection 删除具有给定名称的集合。
func (db *DB) DeleteCollection(name string) error {
db.collectionsLock.Lock()
defer db.collectionsLock.Unlock()
col, ok := db.collections[name]
if !ok {
return nil
}
if db.persistDirectory != "" {
collectionPath := col.persistDirectory
err := os.RemoveAll(collectionPath)
if err != nil {
return fmt.Errorf("无法删除集合目录: %w", err)
}
}
delete(db.collections, name)
return nil
}
// Reset 从数据库中移除所有集合。
func (db *DB) Reset() error {
db.collectionsLock.Lock()
defer db.collectionsLock.Unlock()
if db.persistDirectory != "" {
err := os.RemoveAll(db.persistDirectory)
if err != nil {
return fmt.Errorf("无法删除持久化目录: %w", err)
}
err = os.MkdirAll(db.persistDirectory, 0o700)
if err != nil {
return fmt.Errorf("无法重新创建持久化目录: %w", err)
}
}
db.collections = make(map[string]*Collection)
return nil
}