diff --git a/godo/localchat/file.go b/godo/localchat/file.go index a959d64..9bfb755 100644 --- a/godo/localchat/file.go +++ b/godo/localchat/file.go @@ -26,7 +26,6 @@ package localchat import ( "encoding/base64" "encoding/json" - "errors" "fmt" "godo/libs" "io" @@ -45,7 +44,7 @@ const ( type FileChunk struct { ChunkIndex int `json:"chunk_index"` - Data string `json:"data"` + Data []byte `json:"data"` Checksum uint32 `json:"checksum"` Timestamp time.Time `json:"timestamp"` Filename string `json:"filename"` @@ -181,36 +180,18 @@ func SendFile(file *os.File, numChunks int, toIp string, fSize int64, message Ud if err != nil && err != io.EOF { log.Fatalf("Failed to read file chunk: %v", err) } - encodedData := base64.StdEncoding.EncodeToString(chunkData[:n]) - // 创建文件块 chunk := FileChunk{ ChunkIndex: index, - Data: encodedData, + Data: chunkData[:n], Checksum: calculateChecksum(chunkData[:n]), Timestamp: time.Now(), Filename: filepath.Base(file.Name()), Filesize: fSize, } - chunkJson, err := json.Marshal(chunk) - if err != nil { - log.Fatalf("Failed to marshal chunk: %v", err) - } - - // 确保每个数据包的大小不超过限制 - maxPacketSize := 65000 - if len(chunkJson) > maxPacketSize { - // 分割数据包 - chunks := splitChunkJson(chunkJson, maxPacketSize) - for _, subChunkJson := range chunks { - message.Message = base64.StdEncoding.EncodeToString(subChunkJson) - sendData(message, toIp) - } - } else { - message.Message = base64.StdEncoding.EncodeToString(chunkJson) - sendData(message, toIp) - } + message.Message = chunk + sendData(message, toIp) fmt.Printf("发送文件块 %d 到 %s 成功\n", index, toIp) }(i) @@ -219,18 +200,6 @@ func SendFile(file *os.File, numChunks int, toIp string, fSize int64, message Ud wg.Wait() } -func splitChunkJson(jsonData []byte, maxSize int) [][]byte { - var chunks [][]byte - for start := 0; start < len(jsonData); start += maxSize { - end := start + maxSize - if end > len(jsonData) { - end = len(jsonData) - } - chunks = append(chunks, jsonData[start:end]) - } - return chunks -} - func sendData(message UdpMessage, toIp string) { port := "56780" addr, err := net.ResolveUDPAddr("udp4", fmt.Sprintf("%s:%s", toIp, port)) @@ -251,34 +220,23 @@ func sendData(message UdpMessage, toIp string) { } } func ReceiveFile(msg UdpMessage) (string, error) { - chunkStr, ok := msg.Message.(string) + messageMap, ok := msg.Message.(map[string]interface{}) if !ok { - return "", errors.New("invalid message type") - } - - // Base64解码 - chunkJson, err := base64.StdEncoding.DecodeString(chunkStr) - if err != nil { - return "", fmt.Errorf("failed to decode base64 message: %v", err) + return "", fmt.Errorf("invalid message type: expected map[string]interface{}, got %T", msg.Message) } // 从 map 中提取 FileChunk 字段 - var chunk FileChunk - if err := json.Unmarshal(chunkJson, &chunk); err != nil { - return "", fmt.Errorf("failed to unmarshal FileChunk: %v", err) - } - - // 验证校验和 - chunkData, err := base64.StdEncoding.DecodeString(chunk.Data) + chunk, err := extractFileChunkFromMap(messageMap) if err != nil { - return "", fmt.Errorf("failed to decode base64 chunk data: %v", err) - } - calculatedChecksum := calculateChecksum(chunkData) - if calculatedChecksum != chunk.Checksum { - fmt.Printf("Checksum mismatch for chunk %d from %s\n", chunk.ChunkIndex, msg.IP) - return "", fmt.Errorf("checksum mismatch") + return "", err } + // calculatedChecksum := calculateChecksum(chunk.Data) + // if calculatedChecksum != chunk.Checksum { + // fmt.Printf("Checksum mismatch for chunk %d from %s\n", chunk.ChunkIndex, msg.IP) + // return "", fmt.Errorf("checksum mismatch") + // } + // 创建接收文件的目录 baseDir, err := libs.GetOsDir() if err != nil { @@ -318,13 +276,13 @@ func ReceiveFile(msg UdpMessage) (string, error) { defer file.Close() // 写入数据 - n, err := file.Write(chunkData) + n, err := file.Write(chunk.Data) if err != nil { log.Printf("Failed to write data to file: %v", err) return "", fmt.Errorf("failed to write data to file") } - if n != len(chunkData) { - log.Printf("Incomplete write: wrote %d bytes, expected %d bytes", n, len(chunkData)) + if n != len(chunk.Data) { + log.Printf("Incomplete write: wrote %d bytes, expected %d bytes", n, len(chunk.Data)) return "", fmt.Errorf("incomplete write") } @@ -348,3 +306,24 @@ func calculateChecksum(data []byte) uint32 { } return checksum } + +// 从 map 中提取 FileChunk 结构体 +func extractFileChunkFromMap(m map[string]interface{}) (FileChunk, error) { + chunk := FileChunk{} + + // 从 map 中提取字段 + chunk.ChunkIndex, _ = m["chunk_index"].(int) + dataStr, _ := m["data"].(string) + dataBytes, err := base64.StdEncoding.DecodeString(dataStr) + if err != nil { + return chunk, fmt.Errorf("failed to decode data: %v", err) + } + chunk.Data = dataBytes + chunk.Checksum, _ = m["checksum"].(uint32) + timestamp, _ := m["timestamp"].(string) + chunk.Timestamp, _ = time.Parse(time.RFC3339, timestamp) + chunk.Filename, _ = m["filename"].(string) + chunk.Filesize, _ = m["filesize"].(int64) + + return chunk, nil +}