You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

125 lines
3.7 KiB

package vector
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"sync"
)
const BaseURLOpenAI = "https://api.openai.com/v1"
type EmbeddingModelOpenAI string
const (
EmbeddingModelOpenAI2Ada EmbeddingModelOpenAI = "text-embedding-ada-002"
EmbeddingModelOpenAI3Small EmbeddingModelOpenAI = "text-embedding-3-small"
EmbeddingModelOpenAI3Large EmbeddingModelOpenAI = "text-embedding-3-large"
)
type openAIResponse struct {
Data []struct {
Embedding []float32 `json:"embedding"`
} `json:"data"`
}
// NewEmbeddingFuncDefault 返回一个函数,使用 OpenAI 的 "text-embedding-3-small" 模型通过 API 创建文本嵌入向量。
// 该模型支持的最大文本长度为 8191 个标记。
// API 密钥从环境变量 "OPENAI_API_KEY" 中读取。
func NewEmbeddingFuncDefault() EmbeddingFunc {
apiKey := os.Getenv("OPENAI_API_KEY")
return NewEmbeddingFuncOpenAI(apiKey, EmbeddingModelOpenAI3Small)
}
// NewEmbeddingFuncOpenAI 返回一个函数,使用 OpenAI API 创建文本嵌入向量。
func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) EmbeddingFunc {
// OpenAI 嵌入向量已归一化
normalized := true
return NewEmbeddingFuncOpenAICompat(BaseURLOpenAI, apiKey, string(model), &normalized)
}
// NewEmbeddingFuncOpenAICompat 返回一个函数,使用兼容 OpenAI 的 API 创建文本嵌入向量。
// 例如:
// - Azure OpenAI: https://azure.microsoft.com/en-us/products/ai-services/openai-service
// - LitLLM: https://github.com/BerriAI/litellm
// - Ollama: https://github.com/ollama/ollama/blob/main/docs/openai.md
//
// `normalized` 参数表示嵌入模型返回的向量是否已经归一化。如果为 nil,则会在首次请求时自动检测(有小概率向量恰好长度为 1)。
func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *bool) EmbeddingFunc {
client := &http.Client{}
var checkedNormalized bool
checkNormalized := sync.Once{}
return func(ctx context.Context, text string) ([]float32, error) {
// 准备请求体
reqBody, err := json.Marshal(map[string]string{
"input": text,
"model": model,
})
if err != nil {
return nil, fmt.Errorf("无法序列化请求体: %w", err)
}
// 创建带有上下文的请求以支持超时
req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/embeddings", bytes.NewBuffer(reqBody))
if err != nil {
return nil, fmt.Errorf("无法创建请求: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
// 发送请求
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("无法发送请求: %w", err)
}
defer resp.Body.Close()
// 检查响应状态
if resp.StatusCode != http.StatusOK {
return nil, errors.New("嵌入 API 返回错误响应: " + resp.Status)
}
// 读取并解码响应体
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("无法读取响应体: %w", err)
}
var embeddingResponse openAIResponse
err = json.Unmarshal(body, &embeddingResponse)
if err != nil {
return nil, fmt.Errorf("无法反序列化响应体: %w", err)
}
// 检查响应中是否包含嵌入向量
if len(embeddingResponse.Data) == 0 || len(embeddingResponse.Data[0].Embedding) == 0 {
return nil, errors.New("响应中未找到嵌入向量")
}
v := embeddingResponse.Data[0].Embedding
if normalized != nil {
if *normalized {
return v, nil
}
return normalizeVector(v), nil
}
checkNormalized.Do(func() {
if isNormalized(v) {
checkedNormalized = true
} else {
checkedNormalized = false
}
})
if !checkedNormalized {
v = normalizeVector(v)
}
return v, nil
}
}