You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

453 lines
12 KiB

package mongodm
import (
"context"
"fmt"
"reflect"
"strings"
"time"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
// MongoDBDialector 实现 GORM 的 Dialector 接口
type MongoDBDialector struct {
URI string // MongoDB 连接 URI
Database string // 默认数据库
Client *mongo.Client // MongoDB 客户端
}
// Name 返回该驱动的名称
func (d MongoDBDialector) Name() string {
return "mongodb"
}
// Initialize 实际连接 MongoDB
func (d *MongoDBDialector) Initialize(db *gorm.DB) error {
clientOptions := options.Client().ApplyURI(d.URI)
client, err := mongo.Connect(context.Background(), clientOptions)
if err != nil {
return err
}
// 设置连接健康检查
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err = client.Ping(ctx, nil); err != nil {
return err
}
// 保存client到结构体
d.Client = client
// 将 client 存入 GORM 的 ConnPool
db.ConnPool = &mongoConnPool{db: client.Database(d.Database)}
// 注册自定义 Clause Builders
db.ClauseBuilders["WHERE"] = d.buildWhereClause
db.ClauseBuilders["LIMIT"] = d.buildLimitClause
db.ClauseBuilders["ORDER BY"] = d.buildOrderByClause
// 注册 CURD 回调
db.Callback().Create().Replace("gorm:create", d.createCallback)
db.Callback().Query().Replace("gorm:query", d.queryCallback)
db.Callback().Update().Replace("gorm:update", d.updateCallback)
db.Callback().Delete().Replace("gorm:delete", d.deleteCallback)
db.Dialector = d
return nil
}
// Migrator 返回一个迁移器(MongoDB 不需要迁移)
func (d MongoDBDialector) Migrator(db *gorm.DB) gorm.Migrator {
return &MongoMigrator{db: db}
}
// DataTypeOf 返回字段在数据库中的数据类型(MongoDB 不适用,返回空字符串)
func (d MongoDBDialector) DataTypeOf(field *schema.Field) string {
return ""
}
// DefaultValueOf 返回字段的默认值表达式(MongoDB 不适用,返回 nil)
func (d MongoDBDialector) DefaultValueOf(field *schema.Field) clause.Expression {
return nil
}
// BindVarTo 处理变量绑定(模拟为 ? 占位符)
func (d MongoDBDialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
writer.WriteString("?")
}
// QuoteTo 引用标识符(如字段名),用于防止关键字冲突
func (d MongoDBDialector) QuoteTo(writer clause.Writer, str string) {
writer.WriteString("\"" + str + "\"")
}
// Explain 格式化输出 SQL 语句(此处直接返回原样)
func (d MongoDBDialector) Explain(sql string, vars ...interface{}) string {
return sql
}
// =======================
// 自定义 Clause Builders
// =======================
func (d MongoDBDialector) buildWhereClause(c clause.Clause, builder clause.Builder) {
if where, ok := c.Expression.(clause.Where); ok {
builder.WriteString(" WHERE ")
for i, cond := range where.Exprs {
if i > 0 {
builder.WriteString(" AND ")
}
cond.Build(builder)
}
}
}
func (d MongoDBDialector) buildLimitClause(c clause.Clause, builder clause.Builder) {
if limit, ok := c.Expression.(clause.Limit); ok && limit.Limit != nil {
builder.WriteString(fmt.Sprintf(" LIMIT %d", *limit.Limit))
if limit.Offset > 0 {
builder.WriteString(fmt.Sprintf(" OFFSET %d", limit.Offset))
}
}
}
func (d MongoDBDialector) buildOrderByClause(c clause.Clause, builder clause.Builder) {
if orderBy, ok := c.Expression.(clause.OrderBy); ok {
builder.WriteString(" ORDER BY ")
for i, item := range orderBy.Columns {
if i > 0 {
builder.WriteString(", ")
}
d.QuoteTo(builder, item.Column.Name)
if !item.Desc {
builder.WriteString(" ASC")
} else {
builder.WriteString(" DESC")
}
}
}
}
// =======================
// CURD Callbacks
// =======================
func (d MongoDBDialector) createCallback(db *gorm.DB) {
collectionName := db.Statement.Table
coll := db.ConnPool.(*mongoConnPool).Collection(collectionName)
// 获取模型字段信息
stmt := db.Statement
for _, field := range stmt.Schema.Fields {
if field.AutoIncrementIncrement == 0 && field.TagSettings["AUTOINCREMENT"] == "true" {
// 如果是 autoIncrement 主键且值为空,则生成新 ID
nextID, err := GetNextID(d.Client, d.Database, collectionName)
if err != nil {
db.AddError(err)
return
}
// 设置自增 ID 到模型中(使用新的 Set 接口)
val := reflect.ValueOf(nextID).Convert(field.FieldType).Interface()
if err := field.Set(context.Background(), stmt.ReflectValue, val); err != nil {
db.AddError(err)
return
}
}
}
_, err := coll.InsertOne(context.Background(), stmt.Model)
if err != nil {
db.AddError(err)
}
}
func (d MongoDBDialector) queryCallback(db *gorm.DB) {
collectionName := db.Statement.Table
coll := db.ConnPool.(*mongoConnPool).Collection(collectionName)
// 构建聚合管道(传入 dest)
dest := db.Statement.Dest
pipeline := buildMongoPipeline(db.Statement.Clauses, dest)
// 执行聚合查询
cursor, err := coll.Aggregate(context.Background(), pipeline)
if err != nil {
db.AddError(err)
return
}
// 将结果映射到目标结构体
err = cursor.All(context.Background(), db.Statement.Dest)
if err != nil {
db.AddError(err)
}
}
func (d MongoDBDialector) updateCallback(db *gorm.DB) {
collectionName := db.Statement.Table
coll := db.ConnPool.(*mongoConnPool).Collection(collectionName)
var filter bson.M
if db.Statement.Clauses != nil {
whereClause, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where)
if ok {
filter = convertWhereToBSON(whereClause)
}
}
updateData := bson.M{"$set": db.Statement.Model}
_, err := coll.UpdateMany(context.Background(), filter, updateData)
if err != nil {
db.AddError(err)
}
}
func (d MongoDBDialector) deleteCallback(db *gorm.DB) {
collectionName := db.Statement.Table
coll := db.ConnPool.(*mongoConnPool).Collection(collectionName)
var filter bson.M
if db.Statement.Clauses != nil {
whereClause, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where)
if ok {
filter = convertWhereToBSON(whereClause)
}
}
_, err := coll.DeleteMany(context.Background(), filter)
if err != nil {
db.AddError(err)
}
}
// =======================
// 工具函数:将 Where/Join/GroupBy 等转换为聚合 Pipeline
// =======================
func buildMongoPipeline(clauses map[string]clause.Clause, dest interface{}) []bson.M {
var pipeline []bson.M
// 处理 WHERE 条件
if whereClause, ok := clauses["WHERE"].Expression.(clause.Where); ok {
filter := convertWhereToBSON(whereClause)
if len(filter) > 0 {
pipeline = append(pipeline, bson.M{"$match": filter})
}
}
// 处理 JOIN
if joinClause, ok := clauses["JOIN"].Expression.(clause.Join); ok {
lookupStage := buildJoinPipeline(joinClause, dest)
if lookupStage != nil {
pipeline = append(pipeline, lookupStage)
// INNER JOIN 处理
if joinClause.Type == clause.InnerJoin {
as := getNestedFieldName(dest, joinClause.Table.Name)
if as == "" {
as = strings.ToLower(joinClause.Table.Name) + "s"
}
unwindStage := bson.M{"$unwind": "$" + as}
pipeline = append(pipeline, unwindStage)
}
// RIGHT JOIN 处理
if joinClause.Type == clause.RightJoin {
as := getNestedFieldName(dest, joinClause.Table.Name)
if as == "" {
as = strings.ToLower(joinClause.Table.Name) + "s"
}
// 展开嵌套字段
unwindStage := bson.M{"$unwind": "$" + as}
// 处理 NULL 值,确保右表记录始终存在
addFieldsStage := bson.M{
"$addFields": bson.M{
as: bson.M{
"$ifNull": []interface{}{
"$" + as,
bson.M{"$literal": []interface{}{}},
},
},
},
}
pipeline = append(pipeline, unwindStage)
pipeline = append(pipeline, addFieldsStage)
}
}
}
// 可选:处理 GroupBy、Select、Sort 等聚合操作
// 示例:GroupBy + Count
if groupByClause, ok := clauses["GROUP BY"].Expression.(clause.GroupBy); ok {
groupFields := bson.M{"_id": nil}
for _, item := range groupByClause.Columns {
groupFields["_id"] = "$" + getColumnName(item)
}
groupFields["count"] = bson.M{"$sum": 1}
pipeline = append(pipeline, bson.M{"$group": groupFields})
}
return pipeline
}
func buildJoinPipeline(join clause.Join, dest interface{}) bson.M {
// 默认值
localField := ""
foreignField := ""
//as := join.As
from := join.Table.Name
// 遍历 ON 条件
for _, expr := range join.ON.Exprs {
if eq, ok := expr.(clause.Eq); ok {
// 处理左侧 Column
if lc, lok := eq.Column.(clause.Column); lok {
localField = lc.Name
}
// 处理右侧 Value(假设是 string 类型的字段名)
if strVal, ok := eq.Value.(string); ok {
foreignField = strVal
}
}
}
// 如果无法提取字段,则返回空
if localField == "" || foreignField == "" {
return nil
}
// 从 dest 结构体中获取嵌套字段名(如 Posts)
as := getNestedFieldName(dest, from)
if as == "" {
as = strings.ToLower(from) + "s" // 默认命名规则
}
// 返回 $lookup stage
return bson.M{
"$lookup": bson.M{
"from": from,
"localField": localField,
"foreignField": foreignField,
"as": as,
},
}
}
func buildUnwindPipeline(as string) bson.M {
return bson.M{"$unwind": "$" + as}
}
func getNestedFieldName(dest interface{}, collectionName string) string {
destType := reflect.TypeOf(dest).Elem()
for i := 0; i < destType.NumField(); i++ {
field := destType.Field(i)
tag := field.Tag.Get("gorm")
if tag != "" {
gormTag := parseGormTag(tag)
if gormTag["foreignKey"] == collectionName {
return field.Name
}
}
}
return ""
}
// =======================
// 工具函数:将 Where 转换为 BSON
// =======================
func convertWhereToBSON(where clause.Where) bson.M {
filter := bson.M{}
for _, cond := range where.Exprs {
switch expr := cond.(type) {
case clause.Eq:
if col, ok := expr.Column.(clause.Column); ok {
key := getColumnName(col)
filter[key] = expr.Value
}
case clause.Neq:
if col, ok := expr.Column.(clause.Column); ok {
key := getColumnName(col)
filter[key] = bson.M{"$ne": expr.Value}
}
case clause.Gt:
if col, ok := expr.Column.(clause.Column); ok {
key := getColumnName(col)
filter[key] = bson.M{"$gt": expr.Value}
}
case clause.Gte:
if col, ok := expr.Column.(clause.Column); ok {
key := getColumnName(col)
filter[key] = bson.M{"$gte": expr.Value}
}
case clause.Lt:
if col, ok := expr.Column.(clause.Column); ok {
key := getColumnName(col)
filter[key] = bson.M{"$lt": expr.Value}
}
case clause.Lte:
if col, ok := expr.Column.(clause.Column); ok {
key := getColumnName(col)
filter[key] = bson.M{"$lte": expr.Value}
}
case clause.IN:
if col, ok := expr.Column.(clause.Column); ok {
key := getColumnName(col)
filter[key] = bson.M{"$in": expr.Values}
}
case clause.Like:
if col, ok := expr.Column.(clause.Column); ok {
key := getColumnName(col)
filter[key] = bson.M{"$regex": expr.Value}
}
default:
// 可选:处理未知条件或记录日志
}
}
return filter
}
// 辅助函数获取字段名(带表前缀)
func getColumnName(column clause.Column) string {
if column.Table == "" {
return column.Name
}
return fmt.Sprintf("%s.%s", column.Table, column.Name)
}
func (d *MongoDBDialector) Begin(db *gorm.DB) (err error) {
session, err := d.Client.StartSession()
if err != nil {
return err
}
db.InstanceSet("mongo:session", session)
session.StartTransaction()
return nil
}
func (d *MongoDBDialector) Commit(db *gorm.DB) (err error) {
session, ok := db.InstanceGet("mongo:session")
if !ok || session == nil {
return nil
}
sess := session.(mongo.Session)
err = sess.CommitTransaction(context.Background())
sess.EndSession(context.Background())
return err
}
func (d *MongoDBDialector) Rollback(db *gorm.DB) (err error) {
session, ok := db.InstanceGet("mongo:session")
if !ok || session == nil {
return nil
}
sess := session.(mongo.Session)
err = sess.AbortTransaction(context.Background())
sess.EndSession(context.Background())
return err
}