diff --git a/godo/cmd/main.go b/godo/cmd/main.go index b0ea857..c10e000 100644 --- a/godo/cmd/main.go +++ b/godo/cmd/main.go @@ -100,6 +100,7 @@ func OsStart() { fileRouter.HandleFunc("/zip", files.HandleZip).Methods(http.MethodGet) fileRouter.HandleFunc("/unzip", files.HandleUnZip).Methods(http.MethodGet) fileRouter.HandleFunc("/watch", files.WatchHandler).Methods(http.MethodGet) + fileRouter.HandleFunc("/setfilepwd", files.HandleSetFilePwd).Methods(http.MethodGet) localchatRouter := router.PathPrefix("/localchat").Subrouter() localchatRouter.HandleFunc("/message", localchat.HandleMessage).Methods(http.MethodPost) diff --git a/godo/files/fs.go b/godo/files/fs.go index 4719135..d5408a5 100644 --- a/godo/files/fs.go +++ b/godo/files/fs.go @@ -24,9 +24,6 @@ package files import ( - "crypto/md5" - "encoding/base64" - "encoding/hex" "encoding/json" "fmt" "godo/libs" @@ -164,58 +161,6 @@ func HandleExists(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(res) } -// HandleReadFile reads a file's content -func HandleReadFile(w http.ResponseWriter, r *http.Request) { - path := r.URL.Query().Get("path") - fpwd := r.Header.Get("fpwd") - haspwd := IsHavePwd(fpwd) - // 校验文件路径 - if err := validateFilePath(path); err != nil { - libs.HTTPError(w, http.StatusBadRequest, err.Error()) - return - } - // 获取文件路径 - basePath, err := libs.GetOsDir() - if err != nil { - libs.HTTPError(w, http.StatusInternalServerError, err.Error()) - return - } - // 读取内容 - fileContent, err := ReadFile(basePath, path) - if err != nil { - libs.HTTPError(w, http.StatusNotFound, err.Error()) - return - } - content := string(fileContent) - // 检查文件内容是否以"link::"开头 - if !strings.HasPrefix(content, "link::") { - content = base64.StdEncoding.EncodeToString(fileContent) - } - - // 初始响应 - res := libs.APIResponse{Code: 0, Message: "success"} - switch haspwd { - case true: - // 有密码检验密码 - isreal := CheckFilePwd(fpwd) - // 密码正确返回原文,否则返回加密文本 - if isreal { - res.Data = content - } else { - data, err := libs.EncryptData(fileContent, libs.EncryptionKey) - if err != nil { - libs.HTTPError(w, http.StatusInternalServerError, err.Error()) - return - } - res.Data = base64.StdEncoding.EncodeToString(data) - } - case false: - res.Data = content - } - - json.NewEncoder(w).Encode(res) -} - // HandleUnlink removes a file func HandleUnlink(w http.ResponseWriter, r *http.Request) { path := r.URL.Query().Get("path") @@ -372,6 +317,7 @@ func HandleCopyFile(w http.ResponseWriter, r *http.Request) { // HandleWriteFile writes content to a file func HandleWriteFile(w http.ResponseWriter, r *http.Request) { + // basepath = "/Users/sujia/.godoos/os" filePath := r.URL.Query().Get("filePath") basePath, err := libs.GetOsDir() if err != nil { @@ -386,8 +332,13 @@ func HandleWriteFile(w http.ResponseWriter, r *http.Request) { return } defer fileContent.Close() - // 输出到控制台进行调试 - //fmt.Printf("Body content: %v\n", fileContent) + filedata, err := io.ReadAll(fileContent) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // 创建文件 file, err := os.Create(filepath.Join(basePath, filePath)) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -395,15 +346,25 @@ func HandleWriteFile(w http.ResponseWriter, r *http.Request) { } defer file.Close() - _, err = io.Copy(file, fileContent) + // 内容为空直接返回,不为空则加密 + if len(filedata) == 0 { + CheckAddDesktop(filePath) + libs.SuccessMsg(w, "", "success") + return + } + // 加密 + data, err := libs.EncryptData(filedata, libs.EncryptionKey) if err != nil { - http.Error(w, err.Error(), http.StatusConflict) + http.Error(w, err.Error(), http.StatusInternalServerError) return } - err = CheckAddDesktop(filePath) + _, err = file.Write(data) if err != nil { - log.Printf("Error adding file to desktop: %s", err.Error()) + http.Error(w, err.Error(), http.StatusInternalServerError) + return } + // 判断下是否添加到桌面上 + CheckAddDesktop(filePath) res := libs.APIResponse{Message: fmt.Sprintf("File '%s' successfully written.", filePath)} json.NewEncoder(w).Encode(res) } @@ -547,26 +508,3 @@ func HandleDesktop(w http.ResponseWriter, r *http.Request) { } libs.SuccessMsg(w, rootInfo, "success") } - -// 设置文件密码 -func HandleSetFilePwd(w http.ResponseWriter, r *http.Request) { - fpwd := r.Header.Get("filepwd") - // 密码最长16位 - if fpwd == "" || len(fpwd) > 16 { - libs.ErrorMsg(w, "密码长度为空或者过长,最长为16位") - return - } - // 服务端存储 - req := libs.ReqBody{ - Name: "filepwd", - Value: fpwd, - } - libs.SetConfig(req) - // 客户端加密 - mhash := md5.New() - mhash.Write([]byte(fpwd)) - v := mhash.Sum(nil) - pwdstr := hex.EncodeToString(v) - res := libs.APIResponse{Message: "success", Data: pwdstr} - json.NewEncoder(w).Encode(res) -} diff --git a/godo/files/os.go b/godo/files/os.go index 11efc8c..320a240 100644 --- a/godo/files/os.go +++ b/godo/files/os.go @@ -24,12 +24,11 @@ package files import ( - "crypto/md5" - "encoding/hex" "fmt" "godo/libs" "io" "io/fs" + "net/http" "os" "path/filepath" "strings" @@ -338,13 +337,13 @@ func CheckDeleteDesktop(filePath string) error { } // 校验文件密码 -func CheckFilePwd(fpwd string) bool { - mhash := md5.New() - mhash.Write([]byte(fpwd)) - v := mhash.Sum(nil) - pwdstr := hex.EncodeToString(v) - oldpwd, _ := libs.GetConfig("filepwd") - return oldpwd == pwdstr +func CheckFilePwd(fpwd, salt string) bool { + pwd := libs.HashPassword(fpwd, salt) + oldpwd, err := libs.GetConfig("filepwd") + if !err { + return false + } + return oldpwd == pwd } func IsHavePwd(pwd string) bool { @@ -354,3 +353,15 @@ func IsHavePwd(pwd string) bool { return false } } + +// salt值优先从server端获取,如果没有则从header获取 +func GetSalt(r *http.Request) string { + data, ishas := libs.GetConfig("salt") + salt := data.(string) + if ishas { + return salt + } else { + salt = r.Header.Get("salt") + return salt + } +} diff --git a/godo/files/pwdfile.go b/godo/files/pwdfile.go new file mode 100644 index 0000000..0d2c1b1 --- /dev/null +++ b/godo/files/pwdfile.go @@ -0,0 +1,102 @@ +package files + +import ( + "crypto/md5" + "encoding/base64" + "encoding/hex" + "encoding/json" + "godo/libs" + "net/http" + "strings" +) + +// 带加密读 +func HandleReadFile(w http.ResponseWriter, r *http.Request) { + + path := r.URL.Query().Get("path") + fpwd := r.Header.Get("fpwd") + haspwd := IsHavePwd(fpwd) + + // 获取salt值 + salt := GetSalt(r) + + // 校验文件路径 + if err := validateFilePath(path); err != nil { + libs.HTTPError(w, http.StatusBadRequest, err.Error()) + return + } + + // 有密码校验密码 + if haspwd { + if !CheckFilePwd(fpwd, salt) { + libs.HTTPError(w, http.StatusBadRequest, "密码错误") + return + } + } + + // 获取文件路径 + basePath, err := libs.GetOsDir() + if err != nil { + libs.HTTPError(w, http.StatusInternalServerError, err.Error()) + return + } + // 读取内容 + fileContent, err := ReadFile(basePath, path) + if err != nil { + libs.HTTPError(w, http.StatusNotFound, err.Error()) + return + } + + // 解密 + data, err := libs.DecryptData(fileContent, libs.EncryptionKey) + if err != nil { + libs.HTTPError(w, http.StatusInternalServerError, err.Error()) + return + } + + content := string(data) + // 检查文件内容是否以"link::"开头 + if !strings.HasPrefix(content, "link::") { + content = base64.StdEncoding.EncodeToString(data) + } + + // 初始响应 + res := libs.APIResponse{Code: 0, Message: "success", Data: content} + + json.NewEncoder(w).Encode(res) +} + +// 设置文件密码 +func HandleSetFilePwd(w http.ResponseWriter, r *http.Request) { + fpwd := r.Header.Get("filepwd") + salt := r.Header.Get("salt") + // 密码最长16位 + if fpwd == "" || len(fpwd) > 16 { + libs.ErrorMsg(w, "密码长度为空或者过长,最长为16位") + return + } + // md5加密 + mhash := md5.New() + mhash.Write([]byte(fpwd)) + v := mhash.Sum(nil) + pwdstr := hex.EncodeToString(v) + + // 服务端再hash加密 + hashpwd := libs.HashPassword(pwdstr, salt) + + // 服务端存储 + req := libs.ReqBody{ + Name: "filepwd", + Value: hashpwd, + } + libs.SetConfig(req) + + // salt值存储 + reqSalt := libs.ReqBody{ + Name: "salt", + Value: salt, + } + libs.SetConfig(reqSalt) + res := libs.APIResponse{Message: "success", Data: pwdstr} + json.NewEncoder(w).Encode(res) +} diff --git a/godo/libs/encode.go b/godo/libs/encode.go index 3ff3b03..007b98a 100644 --- a/godo/libs/encode.go +++ b/godo/libs/encode.go @@ -7,6 +7,8 @@ import ( "crypto/hmac" "crypto/rand" "crypto/sha256" + "encoding/base64" + "errors" "io" ) @@ -20,6 +22,19 @@ func pkcs7Pad(data []byte, blockSize int) []byte { return append(data, padtext...) } +// pkcs7Unpad 移除 PKCS#7 填充 +func pkcs7Unpad(data []byte) []byte { + length := len(data) + if length == 0 { + return data + } + padding := int(data[length-1]) // 将 padding 转换为 int 类型 + if padding > aes.BlockSize || padding < 1 { + return data + } + return data[:length-padding] +} + func EncryptData(data []byte, key []byte) ([]byte, error) { block, err := aes.NewCipher(key) if err != nil { @@ -50,3 +65,53 @@ func EncryptData(data []byte, key []byte) ([]byte, error) { return result, nil } + +// DecryptData 使用 AES 解密数据,并验证 HMAC-SHA256 签名 +func DecryptData(ciphertext []byte, key []byte) ([]byte, error) { + // 检查 HMAC-SHA256 签名 + expectedMacSize := sha256.Size + if len(ciphertext) < expectedMacSize { + return nil, errors.New("ciphertext too short") + } + + macSum := ciphertext[len(ciphertext)-expectedMacSize:] + ciphertext = ciphertext[:len(ciphertext)-expectedMacSize] + + // 验证 HMAC-SHA256 签名 + mac := hmac.New(sha256.New, key) + mac.Write(ciphertext) + calculatedMac := mac.Sum(nil) + + if !hmac.Equal(macSum, calculatedMac) { + return nil, errors.New("invalid MAC") + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + // 检查 IV 的长度 + if len(ciphertext) < aes.BlockSize { + return nil, errors.New("ciphertext too short") + } + + iv := ciphertext[:aes.BlockSize] + ciphertext = ciphertext[aes.BlockSize:] + + // 使用 CBC 模式解密数据 + mode := cipher.NewCBCDecrypter(block, iv) + mode.CryptBlocks(ciphertext, ciphertext) + + // 移除 PKCS#7 填充 + unpaddedData := pkcs7Unpad(ciphertext) + + return unpaddedData, nil +} + +// 哈希加密 +func HashPassword(password, salt string) string { + hash := sha256.New() + hash.Write([]byte(password + salt)) + return base64.URLEncoding.EncodeToString(hash.Sum(nil)) +}