mirror of https://gitee.com/godoos/godoos.git
19 changed files with 861 additions and 283 deletions
@ -0,0 +1,330 @@ |
|||
package mongodm |
|||
|
|||
import ( |
|||
"context" |
|||
"fmt" |
|||
"log" |
|||
"reflect" |
|||
"strings" |
|||
|
|||
"go.mongodb.org/mongo-driver/bson" |
|||
"go.mongodb.org/mongo-driver/mongo" |
|||
"go.mongodb.org/mongo-driver/mongo/options" |
|||
"gorm.io/gorm" |
|||
"gorm.io/gorm/clause" |
|||
) |
|||
|
|||
func (d MongoDBDialector) queryCallback(db *gorm.DB) { |
|||
collectionName := db.Statement.Table |
|||
coll := db.ConnPool.(*mongoConnPool).Collection(collectionName) |
|||
dest := db.Statement.Dest |
|||
stmt := db.Statement |
|||
// log.Printf("stmt.Clauses: %+v\n", stmt.Clauses)
|
|||
//log.Printf("stmt.Clauses: %+v\n", stmt.Clauses["SELECT"].Expression)
|
|||
// 检查是否有 GROUP BY 或 JOIN
|
|||
hasGroupBy := false |
|||
if _, ok := stmt.Clauses["GROUP BY"]; ok { |
|||
hasGroupBy = true |
|||
} |
|||
if _, ok := stmt.Clauses["JOIN"].Expression.(clause.Join); ok { |
|||
// 所有类型的 JOIN(包括 LEFT JOIN)都触发聚合查询
|
|||
hasGroupBy = true |
|||
} |
|||
|
|||
var limit int |
|||
if limitClause, ok := stmt.Clauses["LIMIT"].Expression.(clause.Limit); ok { |
|||
if limitClause.Limit != nil { |
|||
limit = *limitClause.Limit |
|||
} else { |
|||
// First() 等隐式调用 LIMIT 1
|
|||
limit = 1 |
|||
} |
|||
} |
|||
|
|||
// 构建查询条件
|
|||
whereClause, hasWhere := stmt.Clauses["WHERE"].Expression.(clause.Where) |
|||
// log.Printf("Where: %+v\n", whereClause)
|
|||
filter := bson.M{} |
|||
if hasWhere { |
|||
filter = convertWhereToBSON(whereClause) |
|||
|
|||
} |
|||
// 构建投影字段(SELECT)
|
|||
var projection bson.M |
|||
if isCountQuery(stmt.Clauses["SELECT"].Expression) { |
|||
count, err := coll.CountDocuments(context.Background(), filter) |
|||
if err != nil { |
|||
db.AddError(fmt.Errorf("failed to count documents: %w", err)) |
|||
return |
|||
} |
|||
err = setValue(db.Statement.Dest, count) |
|||
if err != nil { |
|||
db.AddError(err) |
|||
return |
|||
} |
|||
db.RowsAffected = count |
|||
return |
|||
} |
|||
//log.Printf("stmt.Clauses: %+v\n", stmt.Clauses["SELECT"].Expression)
|
|||
if selectClause, ok := stmt.Clauses["SELECT"].Expression.(clause.Select); ok { |
|||
// 判断是否是 count 查询
|
|||
exprStr := fmt.Sprintf("%v", selectClause) |
|||
log.Printf("exprStr: %s\n", exprStr) |
|||
|
|||
projection = convertSelectToBSON(selectClause) |
|||
} |
|||
|
|||
if hasGroupBy { |
|||
pipeline := buildMongoPipeline(stmt.Clauses, dest) |
|||
|
|||
cursor, err := coll.Aggregate(context.Background(), pipeline) |
|||
if err != nil { |
|||
db.AddError(err) |
|||
return |
|||
} |
|||
defer cursor.Close(context.Background()) |
|||
if limit == 1 { |
|||
// 聚合结果中取第一条
|
|||
if cursor.Next(context.Background()) { |
|||
if err := cursor.Decode(dest); err != nil { |
|||
db.AddError(err) |
|||
} |
|||
} else { |
|||
db.AddError(gorm.ErrRecordNotFound) |
|||
} |
|||
} else { |
|||
// 获取所有聚合结果
|
|||
if err := cursor.All(context.Background(), dest); err != nil { |
|||
db.AddError(err) |
|||
} |
|||
} |
|||
return |
|||
} |
|||
|
|||
if limit == 1 { |
|||
// 处理 .First()
|
|||
opts := options.FindOne().SetProjection(projection) |
|||
var result bson.M |
|||
err := coll.FindOne(context.Background(), filter, opts).Decode(&result) |
|||
if err != nil { |
|||
if err == mongo.ErrNoDocuments { |
|||
db.AddError(gorm.ErrRecordNotFound) |
|||
} else { |
|||
db.AddError(err) |
|||
} |
|||
return |
|||
} |
|||
|
|||
// 将结果映射到原始 dest(例如 *User)
|
|||
//log.Printf("first dest: %+v\n", dest)
|
|||
if err := mapBSONToStruct(result, dest); err != nil { |
|||
db.AddError(err) |
|||
} |
|||
} else { |
|||
// 处理 .Find()
|
|||
opts := options.Find().SetProjection(projection) |
|||
cursor, err := coll.Find(context.Background(), filter, opts) |
|||
if err != nil { |
|||
db.AddError(err) |
|||
return |
|||
} |
|||
defer cursor.Close(context.Background()) |
|||
|
|||
err = cursor.All(context.Background(), dest) |
|||
if err != nil { |
|||
db.AddError(err) |
|||
} |
|||
} |
|||
} |
|||
|
|||
// =======================
|
|||
// 工具函数:将 Where/Join/GroupBy 等转换为聚合 Pipeline
|
|||
// =======================
|
|||
func buildMongoPipeline(clauses map[string]clause.Clause, dest interface{}) []bson.M { |
|||
var pipeline []bson.M |
|||
|
|||
// 处理 WHERE 条件
|
|||
if whereClause, ok := clauses["WHERE"].Expression.(clause.Where); ok { |
|||
filter := convertWhereToBSON(whereClause) |
|||
if len(filter) > 0 { |
|||
pipeline = append(pipeline, bson.M{"$match": filter}) |
|||
} |
|||
} |
|||
// 处理 JOIN
|
|||
if joinClause, ok := clauses["JOIN"].Expression.(clause.Join); ok { |
|||
lookupStage := buildJoinPipeline(joinClause, dest) |
|||
if lookupStage != nil { |
|||
pipeline = append(pipeline, lookupStage) |
|||
|
|||
// INNER JOIN 处理
|
|||
if joinClause.Type == clause.InnerJoin { |
|||
as := getNestedFieldName(dest, joinClause.Table.Name) |
|||
if as == "" { |
|||
as = strings.ToLower(joinClause.Table.Name) + "s" |
|||
} |
|||
unwindStage := bson.M{"$unwind": "$" + as} |
|||
pipeline = append(pipeline, unwindStage) |
|||
} |
|||
|
|||
// RIGHT JOIN 处理
|
|||
if joinClause.Type == clause.RightJoin { |
|||
as := getNestedFieldName(dest, joinClause.Table.Name) |
|||
if as == "" { |
|||
as = strings.ToLower(joinClause.Table.Name) + "s" |
|||
} |
|||
|
|||
// 展开嵌套字段
|
|||
unwindStage := bson.M{"$unwind": "$" + as} |
|||
|
|||
// 处理 NULL 值,确保右表记录始终存在
|
|||
addFieldsStage := bson.M{ |
|||
"$addFields": bson.M{ |
|||
as: bson.M{ |
|||
"$ifNull": []interface{}{ |
|||
"$" + as, |
|||
bson.M{"$literal": []interface{}{}}, |
|||
}, |
|||
}, |
|||
}, |
|||
} |
|||
|
|||
pipeline = append(pipeline, unwindStage) |
|||
pipeline = append(pipeline, addFieldsStage) |
|||
} |
|||
} |
|||
} |
|||
// 可选:处理 GroupBy、Select、Sort 等聚合操作
|
|||
// 示例:GroupBy + Count
|
|||
if groupByClause, ok := clauses["GROUP BY"].Expression.(clause.GroupBy); ok { |
|||
groupFields := bson.M{"_id": nil} |
|||
for _, item := range groupByClause.Columns { |
|||
groupFields["_id"] = "$" + getColumnName(item) |
|||
} |
|||
groupFields["count"] = bson.M{"$sum": 1} |
|||
pipeline = append(pipeline, bson.M{"$group": groupFields}) |
|||
} |
|||
|
|||
return pipeline |
|||
} |
|||
func buildJoinPipeline(join clause.Join, dest interface{}) bson.M { |
|||
// 默认值
|
|||
localField := "" |
|||
foreignField := "" |
|||
from := join.Table.Name |
|||
|
|||
for _, expr := range join.ON.Exprs { |
|||
if eq, ok := expr.(clause.Eq); ok { |
|||
if lc, lok := eq.Column.(clause.Column); lok { |
|||
localField = lc.Name |
|||
} |
|||
if strVal, ok := eq.Value.(string); ok { |
|||
foreignField = strVal |
|||
} |
|||
} |
|||
} |
|||
|
|||
if localField == "" || foreignField == "" { |
|||
return nil |
|||
} |
|||
|
|||
as := getNestedFieldName(dest, from) |
|||
if as == "" { |
|||
as = strings.ToLower(from) + "s" |
|||
} |
|||
|
|||
lookupStage := bson.M{ |
|||
"$lookup": bson.M{ |
|||
"from": from, |
|||
"localField": localField, |
|||
"foreignField": foreignField, |
|||
"as": as, |
|||
}, |
|||
} |
|||
|
|||
// 区分不同 JOIN 类型
|
|||
switch join.Type { |
|||
case clause.InnerJoin: |
|||
unwindStage := bson.M{"$unwind": "$" + as} |
|||
return bson.M{ |
|||
"$and": []bson.M{lookupStage, unwindStage}, |
|||
} |
|||
case clause.RightJoin: |
|||
addFieldsStage := bson.M{ |
|||
"$addFields": bson.M{ |
|||
as: bson.M{ |
|||
"$ifNull": []interface{}{ |
|||
"$" + as, |
|||
bson.M{"$literal": []interface{}{}}, |
|||
}, |
|||
}, |
|||
}, |
|||
} |
|||
unwindStage := bson.M{"$unwind": "$" + as} |
|||
return bson.M{ |
|||
"$and": []bson.M{lookupStage, addFieldsStage, unwindStage}, |
|||
} |
|||
case clause.LeftJoin: |
|||
// 左连接只需保留空数组作为默认值,不强制展开
|
|||
return lookupStage |
|||
default: |
|||
return lookupStage |
|||
} |
|||
} |
|||
|
|||
func getNestedFieldName(dest interface{}, collectionName string) string { |
|||
destType := reflect.TypeOf(dest).Elem() |
|||
|
|||
for i := 0; i < destType.NumField(); i++ { |
|||
field := destType.Field(i) |
|||
tag := field.Tag.Get("gorm") |
|||
|
|||
if tag != "" { |
|||
gormTag := parseGormTag(tag) |
|||
if gormTag["foreignKey"] == collectionName { |
|||
return field.Name |
|||
} |
|||
} |
|||
} |
|||
|
|||
return "" |
|||
} |
|||
func isCountQuery(expr clause.Expression) bool { |
|||
if exprStmt, ok := expr.(clause.Expr); ok { |
|||
exprStr := strings.ToLower(exprStmt.SQL) |
|||
//log.Printf("exprStr: %s\n", exprStr)
|
|||
|
|||
if strings.Contains(exprStr, "count(") { |
|||
//log.Println("Intercepted COUNT(*) query")
|
|||
return true |
|||
} |
|||
} |
|||
return false |
|||
} |
|||
|
|||
// 支持 *int64、*int、**int64、***int64 等各种嵌套指针
|
|||
func setValue(dest interface{}, count int64) error { |
|||
destVal := reflect.ValueOf(dest) |
|||
if destVal.Kind() != reflect.Ptr { |
|||
return fmt.Errorf("destination must be a pointer, got %s", destVal.Kind()) |
|||
} |
|||
|
|||
elem := destVal.Elem() |
|||
for elem.Kind() == reflect.Ptr { |
|||
if elem.IsNil() { |
|||
elem.Set(reflect.New(elem.Type().Elem())) |
|||
} |
|||
elem = elem.Elem() |
|||
} |
|||
|
|||
//log.Printf("Before SetInt: %v (type: %s)", elem.Interface(), elem.Kind())
|
|||
|
|||
switch elem.Kind() { |
|||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: |
|||
elem.SetInt(count) |
|||
default: |
|||
return fmt.Errorf("unsupported destination type: %s", elem.Kind()) |
|||
} |
|||
|
|||
//log.Printf("After SetInt: %v", elem.Interface())
|
|||
return nil |
|||
} |
@ -0,0 +1,221 @@ |
|||
package mongodm |
|||
|
|||
import ( |
|||
"reflect" |
|||
"strings" |
|||
|
|||
"go.mongodb.org/mongo-driver/bson" |
|||
"gorm.io/gorm/clause" |
|||
) |
|||
|
|||
// =======================
|
|||
// 工具函数:将 Where 转换为 BSON
|
|||
// =======================
|
|||
func convertWhereToBSON(where clause.Where) bson.M { |
|||
filter := bson.M{} |
|||
|
|||
for _, cond := range where.Exprs { |
|||
switch expr := cond.(type) { |
|||
case clause.Eq: |
|||
if col, ok := expr.Column.(clause.Column); ok { |
|||
key := getColumnName(col) |
|||
filter[key] = expr.Value |
|||
} |
|||
case clause.Neq: |
|||
if col, ok := expr.Column.(clause.Column); ok { |
|||
key := getColumnName(col) |
|||
filter[key] = bson.M{"$ne": expr.Value} |
|||
} |
|||
case clause.Gt: |
|||
if col, ok := expr.Column.(clause.Column); ok { |
|||
key := getColumnName(col) |
|||
filter[key] = bson.M{"$gt": expr.Value} |
|||
} |
|||
case clause.Gte: |
|||
if col, ok := expr.Column.(clause.Column); ok { |
|||
key := getColumnName(col) |
|||
filter[key] = bson.M{"$gte": expr.Value} |
|||
} |
|||
case clause.Lt: |
|||
if col, ok := expr.Column.(clause.Column); ok { |
|||
key := getColumnName(col) |
|||
filter[key] = bson.M{"$lt": expr.Value} |
|||
} |
|||
case clause.Lte: |
|||
if col, ok := expr.Column.(clause.Column); ok { |
|||
key := getColumnName(col) |
|||
filter[key] = bson.M{"$lte": expr.Value} |
|||
} |
|||
case clause.IN: |
|||
if col, ok := expr.Column.(clause.Column); ok { |
|||
key := getColumnName(col) |
|||
filter[key] = bson.M{"$in": expr.Values} |
|||
} |
|||
case clause.Like: |
|||
if col, ok := expr.Column.(clause.Column); ok { |
|||
key := getColumnName(col) |
|||
filter[key] = bson.M{"$regex": expr.Value} |
|||
} |
|||
default: |
|||
if nativeCond, ok := cond.(clause.Expr); ok { |
|||
sqlStr := nativeCond.SQL |
|||
vals := nativeCond.Vars |
|||
|
|||
// 简单处理:支持形如 "column = ?" 或带 OR 的多个条件
|
|||
orClauses := parseNativeSQLConditions(sqlStr, vals) |
|||
|
|||
if len(orClauses) > 0 { |
|||
if _, exists := filter["$or"]; exists { |
|||
// 如果已有 $or,合并进去(用于 JOIN、GROUP BY 等复杂场景)
|
|||
existingOr := filter["$or"].([]bson.M) |
|||
filter["$or"] = append(existingOr, orClauses...) |
|||
} else { |
|||
filter["$or"] = orClauses |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
|||
return filter |
|||
} |
|||
|
|||
// convertSelectToBSON 将 Select clause 转换为 MongoDB 的 projection
|
|||
func convertSelectToBSON(selectClause clause.Select) bson.M { |
|||
projection := bson.M{} |
|||
for _, col := range selectClause.Columns { |
|||
fieldName := getColumnName(col) |
|||
projection[fieldName] = 1 // 1 表示包含该字段
|
|||
} |
|||
return projection |
|||
} |
|||
|
|||
// 解析 SQL 片段,如 "username = ? OR email = ?",返回对应的 bson.M 列表
|
|||
// 支持更复杂的 SQL 表达式解析,包括:=, !=, >, <, >=, <=, IN, LIKE, IS NULL 等
|
|||
func parseNativeSQLConditions(sql string, values []interface{}) []bson.M { |
|||
var conditions []bson.M |
|||
sql = strings.TrimSpace(sql) |
|||
|
|||
// 分割 OR 条件
|
|||
orParts := splitSQLCondition(sql, "OR") |
|||
for _, orPart := range orParts { |
|||
orPart = strings.TrimSpace(orPart) |
|||
|
|||
// 处理单个 OR 子句中的 AND 条件(可嵌套)
|
|||
andParts := splitSQLCondition(orPart, "AND") |
|||
var andConditions []bson.M |
|||
|
|||
for _, andPart := range andParts { |
|||
andPart = strings.TrimSpace(andPart) |
|||
if andPart == "" { |
|||
continue |
|||
} |
|||
|
|||
cond := parseSingleCondition(andPart, values, len(conditions)+len(andConditions)) |
|||
if cond != nil { |
|||
andConditions = append(andConditions, cond) |
|||
} |
|||
} |
|||
|
|||
if len(andConditions) > 1 { |
|||
// 如果有多个 AND 条件,合并为一个 bson.M
|
|||
andFilter := bson.M{} |
|||
for _, c := range andConditions { |
|||
for k, v := range c { |
|||
andFilter[k] = v |
|||
} |
|||
} |
|||
conditions = append(conditions, andFilter) |
|||
} else if len(andConditions) == 1 { |
|||
conditions = append(conditions, andConditions[0]) |
|||
} |
|||
} |
|||
|
|||
return conditions |
|||
} |
|||
|
|||
// 分割 SQL 条件(支持 AND / OR)
|
|||
func splitSQLCondition(sql string, separator string) []string { |
|||
// 简化处理,避免破坏带引号的字符串中可能含有的 AND/OR
|
|||
// 实际生产环境可用 SQL 解析器或正则表达式加强匹配
|
|||
return strings.Split(strings.ToLower(sql), strings.ToLower(separator)) |
|||
} |
|||
|
|||
// 解析单个 SQL 条件,返回对应的 bson.M
|
|||
func parseSingleCondition(part string, values []interface{}, index int) bson.M { |
|||
part = strings.TrimSpace(part) |
|||
|
|||
// 跳过空条件
|
|||
if part == "" { |
|||
return nil |
|||
} |
|||
|
|||
// 匹配 IS NULL / IS NOT NULL
|
|||
if strings.Contains(strings.ToUpper(part), "IS NULL") { |
|||
fieldName := strings.TrimSpace(strings.Split(strings.ToUpper(part), "IS NULL")[0]) |
|||
return bson.M{fieldName: nil} |
|||
} |
|||
if strings.Contains(strings.ToUpper(part), "IS NOT NULL") { |
|||
fieldName := strings.TrimSpace(strings.Split(strings.ToUpper(part), "IS NOT NULL")[0]) |
|||
return bson.M{fieldName: bson.M{"$ne": nil}} |
|||
} |
|||
|
|||
// 匹配 LIKE
|
|||
if strings.Contains(strings.ToUpper(part), "LIKE") { |
|||
parts := strings.SplitN(part, "LIKE", 2) |
|||
if len(parts) == 2 && index < len(values) { |
|||
fieldName := strings.TrimSpace(parts[0]) |
|||
value := values[index] |
|||
if strVal, ok := value.(string); ok { |
|||
// 支持 %xxx% 的模糊匹配
|
|||
regexStr := strings.ReplaceAll(strVal, "%", ".*") |
|||
return bson.M{fieldName: bson.M{"$regex": regexStr, "$options": "i"}} |
|||
} |
|||
} |
|||
} |
|||
|
|||
// 匹配 IN
|
|||
if strings.Contains(strings.ToUpper(part), "IN") { |
|||
parts := strings.SplitN(part, "IN", 2) |
|||
if len(parts) == 2 && index < len(values) { |
|||
fieldName := strings.TrimSpace(parts[0]) |
|||
if reflect.TypeOf(values[index]).Kind() == reflect.Slice { |
|||
slice := reflect.ValueOf(values[index]) |
|||
inValues := make([]interface{}, slice.Len()) |
|||
for i := 0; i < slice.Len(); i++ { |
|||
inValues[i] = slice.Index(i).Interface() |
|||
} |
|||
return bson.M{fieldName: bson.M{"$in": inValues}} |
|||
} |
|||
} |
|||
} |
|||
|
|||
// 匹配比较运算符:>, >=, <, <=, !=
|
|||
comparisonOp := map[string]string{ |
|||
"!=": "$ne", |
|||
">": "$gt", |
|||
">=": "$gte", |
|||
"<": "$lt", |
|||
"<=": "$lte", |
|||
} |
|||
|
|||
for opStr, mongoOp := range comparisonOp { |
|||
if strings.Contains(part, opStr) { |
|||
parts := strings.Split(part, opStr) |
|||
if len(parts) >= 2 && index < len(values) { |
|||
fieldName := strings.TrimSpace(parts[0]) |
|||
return bson.M{fieldName: bson.M{mongoOp: values[index]}} |
|||
} |
|||
} |
|||
} |
|||
|
|||
// 默认:支持等值查询 =
|
|||
if strings.Contains(part, "=") { |
|||
parts := strings.Split(part, "=") |
|||
if len(parts) >= 2 && index < len(values) { |
|||
fieldName := strings.TrimSpace(parts[0]) |
|||
return bson.M{fieldName: values[index]} |
|||
} |
|||
} |
|||
|
|||
return nil |
|||
} |
@ -1 +1 @@ |
|||
exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1 |
|||
exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1exit status 1 |
Binary file not shown.
Loading…
Reference in new issue