mirror of https://gitee.com/godoos/godoos.git
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
256 lines
6.2 KiB
256 lines
6.2 KiB
package model
|
|
|
|
import (
|
|
"crypto/md5"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"godo/libs"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/cavaliergopher/grab/v3"
|
|
)
|
|
|
|
// concurrency is the number of concurrent downloads; Download passes it to
// the HTTP transport as MaxIdleConnsPerHost.
const concurrency = 6
|
|
|
|
var downloads = make(map[string]*grab.Response)
|
|
var downloadsMutex sync.Mutex
|
|
|
|
//var cancelDownloads = make(map[string]context.CancelFunc)
|
|
|
|
// noticeSuccess answers the request with HTTP 200 and the JSON document
// {"status":"success"}.
func noticeSuccess(w http.ResponseWriter) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(map[string]string{"status": "success"}); err != nil {
		// The status line has already been sent, so the failure can only be
		// recorded, not reported to the client.
		log.Printf("noticeSuccess: writing response: %v", err)
	}
}
|
|
|
|
func Download(w http.ResponseWriter, r *http.Request) {
|
|
reqBody := ReqBody{}
|
|
err := json.NewDecoder(r.Body).Decode(&reqBody)
|
|
if err != nil {
|
|
libs.ErrorMsg(w, "first Decode request body error:"+err.Error())
|
|
return
|
|
}
|
|
err = LoadConfig()
|
|
if err != nil {
|
|
libs.ErrorMsg(w, "Load config error")
|
|
return
|
|
}
|
|
_, exitsModel := GetModel(reqBody.Model)
|
|
if exitsModel {
|
|
noticeSuccess(w)
|
|
return
|
|
|
|
}
|
|
if reqBody.Info.From == "ollama" && reqBody.Info.Engine == "ollama" {
|
|
setOllamaInfo(w, r, reqBody)
|
|
return
|
|
}
|
|
|
|
var paths []string
|
|
var tsize int64
|
|
for _, urls := range reqBody.Info.URL {
|
|
urls = replaceUrl(urls)
|
|
if !strings.HasPrefix(strings.ToLower(urls), "http://") && !strings.HasPrefix(strings.ToLower(urls), "https://") {
|
|
fileInfo, err := os.Stat(urls)
|
|
if err != nil {
|
|
libs.ErrorMsg(w, "Get model path error")
|
|
return
|
|
}
|
|
tsize += fileInfo.Size()
|
|
paths = append(paths, urls)
|
|
continue
|
|
}
|
|
filePath, err := GetModelPath(urls, reqBody.Model, reqBody.Type)
|
|
//log.Printf("filePath is %s", filePath)
|
|
if err != nil {
|
|
libs.ErrorMsg(w, "Get model path error")
|
|
return
|
|
}
|
|
paths = append(paths, filePath)
|
|
md5url := md5Url(urls)
|
|
if rsp, ok := downloads[md5url]; ok {
|
|
// 如果URL正在下载,跳过创建新的下载器实例
|
|
go trackProgress(w, rsp, md5url)
|
|
return
|
|
}
|
|
// 创建新的下载器实例
|
|
client := grab.NewClient()
|
|
client.HTTPClient = &http.Client{
|
|
Transport: &http.Transport{
|
|
MaxIdleConnsPerHost: concurrency, // 可选,设置并发连接数
|
|
},
|
|
}
|
|
log.Printf("filePath is %s", filePath)
|
|
// 创建下载请求
|
|
req, err := grab.NewRequest(filePath, urls)
|
|
if err != nil {
|
|
libs.ErrorMsg(w, "Invalid download URL")
|
|
return
|
|
}
|
|
resp := client.Do(req)
|
|
downloads[md5url] = resp
|
|
//log.Printf("Download urls: %v\n", reqBody.DownloadUrl)
|
|
|
|
// // 跟踪进度
|
|
go trackProgress(w, resp, md5url)
|
|
tsize += resp.Size()
|
|
// 等待下载完成并检查错误
|
|
if err := resp.Err(); err != nil {
|
|
libs.ErrorMsg(w, "Download failed")
|
|
return
|
|
}
|
|
}
|
|
delUrls(reqBody.Info.URL)
|
|
if tsize <= 0 {
|
|
libs.ErrorMsg(w, "download size is zero")
|
|
return
|
|
}
|
|
reqBody.Info.Path = paths
|
|
reqBody.Status = "success"
|
|
reqBody.CreatedAt = time.Now()
|
|
// reqBody.Info["tsize"] = tsize
|
|
// reqBody.Info["size"] = humanReadableSize(tsize)
|
|
if reqBody.Info.From == "network" && reqBody.Info.Engine == "ollama" {
|
|
ConvertOllama(w, r, reqBody)
|
|
// reqBody.From = "ollama"
|
|
// reqBody.Paths = []string{}
|
|
}
|
|
|
|
if err := SetModel(reqBody); err != nil {
|
|
libs.ErrorMsg(w, "Set model error")
|
|
return
|
|
}
|
|
|
|
noticeSuccess(w)
|
|
}
|
|
func trackProgress(w http.ResponseWriter, resp *grab.Response, md5url string) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
log.Printf("Recovered panic in trackProgress: %v", r)
|
|
}
|
|
downloadsMutex.Lock()
|
|
defer downloadsMutex.Unlock()
|
|
delete(downloads, md5url)
|
|
}()
|
|
ticker := time.NewTicker(100 * time.Millisecond)
|
|
defer ticker.Stop()
|
|
flusher, ok := w.(http.Flusher)
|
|
if !ok {
|
|
log.Printf("Streaming unsupported")
|
|
http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
fp := FileProgress{
|
|
Progress: resp.Progress(),
|
|
IsFinished: resp.IsComplete(),
|
|
Total: resp.Size(),
|
|
Current: resp.BytesComplete(),
|
|
Status: "loading",
|
|
}
|
|
//log.Printf("Progress: %v", fp)
|
|
if resp.IsComplete() && fp.Current == fp.Total {
|
|
fp.Status = "loaded"
|
|
}
|
|
jsonBytes, err := json.Marshal(fp)
|
|
if err != nil {
|
|
log.Printf("Error marshaling FileProgress to JSON: %v", err)
|
|
continue
|
|
}
|
|
if w != nil {
|
|
io.WriteString(w, string(jsonBytes))
|
|
w.Write([]byte("\n"))
|
|
flusher.Flush()
|
|
} else {
|
|
log.Println("ResponseWriter is nil, cannot send progress")
|
|
}
|
|
if fp.Status == "loaded" {
|
|
return
|
|
}
|
|
|
|
}
|
|
}
|
|
}
|
|
// md5Url returns the MD5 digest of url as a lowercase hexadecimal string.
func md5Url(url string) string {
	digest := md5.Sum([]byte(url))
	return hex.EncodeToString(digest[:])
}
|
|
func delUrls(reqUrl []string) {
|
|
if len(reqUrl) > 0 {
|
|
downloadsMutex.Lock()
|
|
defer downloadsMutex.Unlock()
|
|
for _, urls := range reqUrl {
|
|
urls = replaceUrl(urls)
|
|
md5url := md5Url(urls)
|
|
delete(downloads, md5url)
|
|
|
|
}
|
|
}
|
|
}
|
|
func DeleteFileHandle(w http.ResponseWriter, r *http.Request) {
|
|
var reqBody ReqBody
|
|
err := json.NewDecoder(r.Body).Decode(&reqBody)
|
|
if err != nil {
|
|
libs.ErrorMsg(w, "Decode request body error: ")
|
|
return
|
|
}
|
|
err = LoadConfig()
|
|
if err != nil {
|
|
libs.ErrorMsg(w, "Load config error: ")
|
|
return
|
|
}
|
|
if err := DeleteModel(reqBody.Model); err != nil {
|
|
libs.ErrorMsg(w, "Error deleting model")
|
|
return
|
|
}
|
|
if reqBody.Info.Engine == "ollama" {
|
|
postQuery := map[string]interface{}{"name": reqBody.Model}
|
|
url := GetOllamaUrl() + "/api/delete"
|
|
|
|
ForwardHandler(w, r, postQuery, url, "DELETE")
|
|
return
|
|
}
|
|
delUrls(reqBody.Info.URL)
|
|
|
|
// 尝试删除目录,注意这会递归删除目录下的所有内容
|
|
//dirPath := filepath.Dir(filePath)
|
|
dirPath, err := GetModelDir(reqBody.Model)
|
|
if err != nil {
|
|
libs.ErrorMsg(w, "GetModelDir error")
|
|
return
|
|
}
|
|
//log.Printf("delete dirpath %v", dirPath)
|
|
err = os.RemoveAll(dirPath)
|
|
if err != nil && !os.IsNotExist(err) {
|
|
libs.ErrorMsg(w, "Error removing directory")
|
|
return
|
|
} else if err == nil {
|
|
log.Printf("Deleted directory: %s", dirPath)
|
|
} else {
|
|
// 如果目录不存在,这通常是可以接受的,不需要错误消息
|
|
log.Printf("Directory does not exist: %s", dirPath)
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(map[string]int{"code": 0})
|
|
|
|
}
|
|
|
|
// replaceUrl rewrites every "/blob/main/" segment of url to "/resolve/main/"
// and returns the result; other URLs are returned unchanged.
// NOTE(review): presumably this turns a repository page link into a direct
// file link — confirm against the model sources in use.
func replaceUrl(url string) string {
	const (
		pageSegment = "/blob/main/"
		fileSegment = "/resolve/main/"
	)
	return strings.Replace(url, pageSegment, fileSegment, -1)
}
|
|
|