package vector

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"
	"sync"
)

const BaseURLOpenAI = "https://api.openai.com/v1"

type EmbeddingModelOpenAI string

const (
	EmbeddingModelOpenAI2Ada   EmbeddingModelOpenAI = "text-embedding-ada-002"
	EmbeddingModelOpenAI3Small EmbeddingModelOpenAI = "text-embedding-3-small"
	EmbeddingModelOpenAI3Large EmbeddingModelOpenAI = "text-embedding-3-large"
)

type openAIResponse struct {
	Data []struct {
		Embedding []float32 `json:"embedding"`
	} `json:"data"`
}

// NewEmbeddingFuncDefault 返回一个函数,使用 OpenAI 的 "text-embedding-3-small" 模型通过 API 创建文本嵌入向量。
// 该模型支持的最大文本长度为 8191 个标记。
// API 密钥从环境变量 "OPENAI_API_KEY" 中读取。
func NewEmbeddingFuncDefault() EmbeddingFunc {
	apiKey := os.Getenv("OPENAI_API_KEY")
	return NewEmbeddingFuncOpenAI(apiKey, EmbeddingModelOpenAI3Small)
}

// NewEmbeddingFuncOpenAI 返回一个函数,使用 OpenAI API 创建文本嵌入向量。
func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) EmbeddingFunc {
	// OpenAI 嵌入向量已归一化
	normalized := true
	return NewEmbeddingFuncOpenAICompat(BaseURLOpenAI, apiKey, string(model), &normalized)
}

// NewEmbeddingFuncOpenAICompat 返回一个函数,使用兼容 OpenAI 的 API 创建文本嵌入向量。
// 例如:
//   - Azure OpenAI: https://azure.microsoft.com/en-us/products/ai-services/openai-service
//   - LitLLM: https://github.com/BerriAI/litellm
//   - Ollama: https://github.com/ollama/ollama/blob/main/docs/openai.md
//
// `normalized` 参数表示嵌入模型返回的向量是否已经归一化。如果为 nil,则会在首次请求时自动检测(有小概率向量恰好长度为 1)。
func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *bool) EmbeddingFunc {
	client := &http.Client{}

	var checkedNormalized bool
	checkNormalized := sync.Once{}

	return func(ctx context.Context, text string) ([]float32, error) {
		// 准备请求体
		reqBody, err := json.Marshal(map[string]string{
			"input": text,
			"model": model,
		})
		if err != nil {
			return nil, fmt.Errorf("无法序列化请求体: %w", err)
		}

		// 创建带有上下文的请求以支持超时
		req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/embeddings", bytes.NewBuffer(reqBody))
		if err != nil {
			return nil, fmt.Errorf("无法创建请求: %w", err)
		}
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Authorization", "Bearer "+apiKey)

		// 发送请求
		resp, err := client.Do(req)
		if err != nil {
			return nil, fmt.Errorf("无法发送请求: %w", err)
		}
		defer resp.Body.Close()

		// 检查响应状态
		if resp.StatusCode != http.StatusOK {
			return nil, errors.New("嵌入 API 返回错误响应: " + resp.Status)
		}

		// 读取并解码响应体
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, fmt.Errorf("无法读取响应体: %w", err)
		}
		var embeddingResponse openAIResponse
		err = json.Unmarshal(body, &embeddingResponse)
		if err != nil {
			return nil, fmt.Errorf("无法反序列化响应体: %w", err)
		}

		// 检查响应中是否包含嵌入向量
		if len(embeddingResponse.Data) == 0 || len(embeddingResponse.Data[0].Embedding) == 0 {
			return nil, errors.New("响应中未找到嵌入向量")
		}

		v := embeddingResponse.Data[0].Embedding
		if normalized != nil {
			if *normalized {
				return v, nil
			}
			return normalizeVector(v), nil
		}
		checkNormalized.Do(func() {
			if isNormalized(v) {
				checkedNormalized = true
			} else {
				checkedNormalized = false
			}
		})
		if !checkedNormalized {
			v = normalizeVector(v)
		}

		return v, nil
	}
}