package vector

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"
	"sync"
)

const BaseURLOpenAI = "https://api.openai.com/v1"

type EmbeddingModelOpenAI string

const (
	EmbeddingModelOpenAI2Ada   EmbeddingModelOpenAI = "text-embedding-ada-002"
	EmbeddingModelOpenAI3Small EmbeddingModelOpenAI = "text-embedding-3-small"
	EmbeddingModelOpenAI3Large EmbeddingModelOpenAI = "text-embedding-3-large"
)

type openAIResponse struct {
	Data []struct {
		Embedding []float32 `json:"embedding"`
	} `json:"data"`
}

// NewEmbeddingFuncDefault returns a function that creates text embeddings via the
// OpenAI API, using OpenAI's "text-embedding-3-small" model.
// The model supports a maximum input length of 8191 tokens.
// The API key is read from the "OPENAI_API_KEY" environment variable.
func NewEmbeddingFuncDefault() EmbeddingFunc {
	apiKey := os.Getenv("OPENAI_API_KEY")
	return NewEmbeddingFuncOpenAI(apiKey, EmbeddingModelOpenAI3Small)
}

// NewEmbeddingFuncOpenAI returns a function that creates text embeddings via the OpenAI API.
func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) EmbeddingFunc {
	// OpenAI embeddings are already normalized.
	normalized := true
	return NewEmbeddingFuncOpenAICompat(BaseURLOpenAI, apiKey, string(model), &normalized)
}

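// The following is a minimal usage sketch added for illustration; it is not part of
// the original file. It assumes the "OPENAI_API_KEY" environment variable is set and
// that the OpenAI API is reachable; the sample text is a placeholder.
func exampleDefaultEmbedding() {
	// Build the default EmbeddingFunc (OpenAI "text-embedding-3-small").
	embed := NewEmbeddingFuncDefault()

	// The returned EmbeddingFunc takes a context and the text to embed.
	vec, err := embed(context.Background(), "Hello, world!")
	if err != nil {
		fmt.Println("embedding failed:", err)
		return
	}
	fmt.Println("embedding dimensions:", len(vec))
}
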
// NewEmbeddingFuncOpenAICompat returns a function that creates text embeddings via
// an OpenAI-compatible API. For example:
//   - Azure OpenAI: https://azure.microsoft.com/en-us/products/ai-services/openai-service
//   - LiteLLM: https://github.com/BerriAI/litellm
//   - Ollama: https://github.com/ollama/ollama/blob/main/docs/openai.md
//
// The `normalized` parameter indicates whether the vectors returned by the embedding
// model are already normalized. If it is nil, this is auto-detected on the first
// request (with a small chance that the vector happens to have a length of 1 anyway).
func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *bool) EmbeddingFunc {
	client := &http.Client{}

	var checkedNormalized bool
	checkNormalized := sync.Once{}

	return func(ctx context.Context, text string) ([]float32, error) {
		// Prepare the request body.
		reqBody, err := json.Marshal(map[string]string{
			"input": text,
			"model": model,
		})
		if err != nil {
			return nil, fmt.Errorf("couldn't marshal request body: %w", err)
		}

		// Create the request with the context so timeouts and cancellation are honored.
		req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/embeddings", bytes.NewBuffer(reqBody))
		if err != nil {
			return nil, fmt.Errorf("couldn't create request: %w", err)
		}
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Authorization", "Bearer "+apiKey)

		// Send the request.
		resp, err := client.Do(req)
		if err != nil {
			return nil, fmt.Errorf("couldn't send request: %w", err)
		}
		defer resp.Body.Close()

		// Check the response status.
		if resp.StatusCode != http.StatusOK {
			return nil, errors.New("error response from the embedding API: " + resp.Status)
		}

		// Read and decode the response body.
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, fmt.Errorf("couldn't read response body: %w", err)
		}
		var embeddingResponse openAIResponse
		err = json.Unmarshal(body, &embeddingResponse)
		if err != nil {
			return nil, fmt.Errorf("couldn't unmarshal response body: %w", err)
		}

		// Check that the response actually contains an embedding.
		if len(embeddingResponse.Data) == 0 || len(embeddingResponse.Data[0].Embedding) == 0 {
			return nil, errors.New("no embedding found in the response")
		}

		v := embeddingResponse.Data[0].Embedding
		if normalized != nil {
			if *normalized {
				return v, nil
			}
			return normalizeVector(v), nil
		}
		// Auto-detect normalization once, on the first returned vector.
		checkNormalized.Do(func() {
			checkedNormalized = isNormalized(v)
		})
		if !checkedNormalized {
			v = normalizeVector(v)
		}

		return v, nil
	}
}
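
// Minimal usage sketch added for illustration; it is not part of the original file.
// It assumes an OpenAI-compatible server (for example a local Ollama instance) is
// reachable at the base URL below and serves the named embedding model; both the
// URL and the model name are placeholders, not values this package prescribes.
func exampleCompatEmbedding() {
	// normalized is nil, so whether the returned vectors are already normalized
	// is auto-detected on the first call.
	embed := NewEmbeddingFuncOpenAICompat("http://localhost:11434/v1", "", "nomic-embed-text", nil)

	vec, err := embed(context.Background(), "Hello, world!")
	if err != nil {
		fmt.Println("embedding failed:", err)
		return
	}
	fmt.Println("embedding dimensions:", len(vec))
}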