You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

163 lines
4.3 KiB

// adapter.go
package mongodm
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"time"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// mongoConnPool exposes a SQL-style connection-pool method set (ExecContext,
// QueryContext, PrepareContext, BeginTx, Ping, Stats, Close, ...) on top of a
// MongoDB database handle. Only the methods that delegate to db do real work;
// the SQL-shaped methods are stubs because MongoDB cannot execute SQL.
// NOTE(review): the method set presumably mirrors a GORM ConnPool-style
// interface declared elsewhere — confirm against the consumer.
type mongoConnPool struct {
	db *mongo.Database // database that Collection, GetDB, BeginTx and Ping operate on
}
// ConnPoolStats is the value returned by (*mongoConnPool).Stats.
//
// Field meanings are inferred from their names (TODO confirm with the consumer
// of Stats): TotalConn, InUse and Idle are presumably connection counts;
// WaitCount and WaitTime presumably describe waits for a free connection;
// MaxIdleTime presumably is the idle-connection lifetime limit.
// NOTE(review): nothing visible in this file populates these fields.
type ConnPoolStats struct {
	TotalConn   int
	InUse       int
	Idle        int
	WaitCount   int64
	WaitTime    time.Duration
	MaxIdleTime time.Duration
}
// Counter describes the shape of a document in the "counters" collection that
// InitCounter creates and GetNextID increments: one document per target
// collection, keyed by that collection's name.
type Counter struct {
	ID  string `bson:"_id"` // collection name
	Seq int64  `bson:"seq"` // current sequence value
}
// GetDB returns the *mongo.Database this pool wraps.
func (p *mongoConnPool) GetDB() *mongo.Database {
	db := p.db
	return db
}
// Collection returns a handle to the collection called name in the wrapped
// database.
func (p *mongoConnPool) Collection(name string) *mongo.Collection {
	coll := p.db.Collection(name)
	return coll
}
// Conn returns a new placeholder driver connection. It never fails, and ctx is
// not consulted. The connection carries no state and cannot prepare statements
// or begin transactions.
func (p *mongoConnPool) Conn(ctx context.Context) (driver.Conn, error) {
	conn := new(mongoConn)
	return conn, nil
}
// PrepareContext is part of the SQL-style connection-pool surface. MongoDB has
// no prepared statements, so it always returns a nil *sql.Stmt together with a
// non-nil error. ctx and query are not used.
//
// An error is returned rather than (nil, nil): a caller that sees a nil error
// goes on to use the statement, and methods on a nil *sql.Stmt dereference it
// and panic. The parameter is named query (not sql) so it does not shadow the
// database/sql package.
func (p *mongoConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
	return nil, fmt.Errorf("mongodm: PrepareContext is not supported by MongoDB")
}
// Prepare implements the Prepare method of driver.Conn for mongoConn. MongoDB
// has no prepared statements, so it always returns a nil statement and an error.
func (c *mongoConn) Prepare(query string) (driver.Stmt, error) {
	err := fmt.Errorf("Prepare not implemented")
	return nil, err
}
// ExecContext is part of the SQL-style connection-pool surface. MongoDB cannot
// run SQL, so the statement is deliberately ignored: nothing is executed, and a
// nil error is returned together with a result reporting 0 rows affected.
// ctx, query and args are not used.
//
// The result is driver.RowsAffected(0) rather than a nil sql.Result. A caller
// that sees a nil error typically calls RowsAffected on the result, and that
// call panics on a nil interface. LastInsertId on this result returns an error.
func (p *mongoConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	return driver.RowsAffected(0), nil
}
// QueryContext is part of the SQL-style connection-pool surface. MongoDB cannot
// run SQL queries, so it always returns nil rows and a non-nil error. ctx,
// query and args are not used.
//
// A zero-value &sql.Rows{} cannot be returned instead. Its Next, Scan and Close
// methods dereference internal driver state that only database/sql can set, so
// a caller would panic instead of seeing an error.
func (p *mongoConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
	return nil, fmt.Errorf("mongodm: QueryContext is not supported by MongoDB")
}
// QueryRowContext is part of the SQL-style connection-pool surface. MongoDB
// cannot run SQL, so no row is produced and a nil *sql.Row is returned.
//
// NOTE(review): calling Scan on the nil result panics. *sql.Row has no exported
// constructor, so an error-carrying row cannot be built here. Callers must not
// route SQL queries to this pool.
func (p *mongoConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
	var noRow *sql.Row
	return noRow
}
// mongoTx is the transaction handle returned (as interface{}) by
// (*mongoConnPool).BeginTx. It wraps the session on which the transaction was
// started.
type mongoTx struct {
	session mongo.Session // session holding the open transaction (mongo.Session is an interface value, not a *mongo.Session)
}
// BeginTx starts a new client session on the pool's client, opens a
// multi-document transaction on it, and returns the pair as a *mongoTx (typed
// as interface{}). The caller finishes it with Commit or Rollback.
//
// opts is ignored, so the transaction uses the session's default options. ctx is
// used only to end the session when the transaction cannot be started. In that
// case the session is released before the error is returned, because the caller
// never receives it and could not end it.
func (p *mongoConnPool) BeginTx(ctx context.Context, opts interface{}) (interface{}, error) {
	session, err := p.db.Client().StartSession()
	if err != nil {
		return nil, err
	}
	if err = session.StartTransaction(); err != nil {
		session.EndSession(ctx)
		return nil, err
	}
	return &mongoTx{session: session}, nil
}
// Commit commits the transaction and then ends the session, so t must not be
// used afterwards. It returns the error from CommitTransaction. The method takes
// no context, so context.Background() is used.
func (t *mongoTx) Commit() error {
	ctx := context.Background()
	// Nothing else ends the session started by BeginTx, so it would otherwise stay
	// open. The deferred call runs after CommitTransaction, so the commit result is
	// still what is returned.
	defer t.session.EndSession(ctx)
	return t.session.CommitTransaction(ctx)
}
// Rollback aborts the transaction and then ends the session, so t must not be
// used afterwards. It returns the error from AbortTransaction. The method takes
// no context, so context.Background() is used.
func (t *mongoTx) Rollback() error {
	ctx := context.Background()
	// Nothing else ends the session started by BeginTx, so it would otherwise stay
	// open. The deferred call runs after AbortTransaction, so the abort result is
	// still what is returned.
	defer t.session.EndSession(ctx)
	return t.session.AbortTransaction(ctx)
}
// Ping checks server connectivity through the client behind the wrapped
// database. It passes a nil read preference, leaving the choice to the client.
func (p *mongoConnPool) Ping(ctx context.Context) error {
	client := p.db.Client()
	return client.Ping(ctx, nil)
}
// Stats reports connection-pool statistics. The adapter collects none, so
// every field of the returned value is zero.
func (p *mongoConnPool) Stats() ConnPoolStats {
	var stats ConnPoolStats
	return stats
}
// Close does nothing. In particular it does not disconnect the underlying
// client. NOTE(review): presumably the client's creator owns its lifetime —
// confirm.
func (p *mongoConnPool) Close() {
	// Intentionally a no-op.
}
// mongoConn is a stateless placeholder database/sql/driver connection returned
// by (*mongoConnPool).Conn. Its Prepare, PrepareContext and Begin methods always
// return errors, and Close always succeeds.
type mongoConn struct{}
// PrepareContext is the context-aware counterpart of Prepare. MongoDB has no
// prepared statements, so it always returns a nil statement and an error.
func (c *mongoConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
	var stmt driver.Stmt
	return stmt, fmt.Errorf("PrepareContext not implemented")
}
// Close implements driver.Conn. The connection holds nothing, so there is
// nothing to release and it always succeeds.
func (c *mongoConn) Close() error {
	var err error
	return err
}
// Begin implements driver.Conn. Transactions are not available through this
// placeholder connection, so it always returns a nil transaction and an error.
func (c *mongoConn) Begin() (driver.Tx, error) {
	beginErr := fmt.Errorf("Begin not implemented")
	var tx driver.Tx
	return tx, beginErr
}
// CommitTx commits the transaction on the session stored in ctx (looked up with
// contextFromDB under the "mongo:session" key). The session is ended afterwards
// whether or not the commit succeeded. If ctx holds no session, CommitTx does
// nothing and returns nil.
func (p *mongoConnPool) CommitTx(ctx context.Context) error {
	session, found := contextFromDB(ctx)
	if found {
		defer session.EndSession(ctx)
		return session.CommitTransaction(ctx)
	}
	return nil
}
// RollbackTx aborts the transaction on the session stored in ctx (looked up with
// contextFromDB under the "mongo:session" key). The session is ended afterwards
// whether or not the abort succeeded. If ctx holds no session, RollbackTx does
// nothing and returns nil.
func (p *mongoConnPool) RollbackTx(ctx context.Context) error {
	session, found := contextFromDB(ctx)
	if found {
		defer session.EndSession(ctx)
		return session.AbortTransaction(ctx)
	}
	return nil
}
// contextFromDB returns the mongo.Session stored in ctx under the string key
// "mongo:session" and reports whether one was found. Despite its name, it reads
// from a context, not from a database. The key must match whatever code stores
// the session. NOTE(review): presumably GORM-side context management (not
// visible here) — confirm. staticcheck SA1029 flags the built-in string key type.
func contextFromDB(ctx context.Context) (mongo.Session, bool) {
	// A comma-ok assertion on a missing (nil) value yields ok == false and a nil
	// session, so no separate nil check is needed.
	session, ok := ctx.Value("mongo:session").(mongo.Session)
	return session, ok
}
// InitCounter ensures the sequence counter for collectionName exists in the
// "counters" collection of database dbName. When the counter is missing, it is
// created with seq = 0. It uses context.Background(); InitCounterContext accepts
// a caller-supplied context.
func InitCounter(client *mongo.Client, dbName, collectionName string) error {
	return InitCounterContext(context.Background(), client, dbName, collectionName)
}

// InitCounterContext is InitCounter with a caller-supplied context.
//
// It is idempotent. If the counter document already exists, it is left
// untouched (its seq is NOT reset to 0) and nil is returned, so it is safe to
// call on every application start. A plain insert would instead fail with a
// duplicate-key error on the second call.
func InitCounterContext(ctx context.Context, client *mongo.Client, dbName, collectionName string) error {
	counters := client.Database(dbName).Collection("counters")
	_, err := counters.UpdateOne(
		ctx,
		bson.M{"_id": collectionName},
		// $setOnInsert applies only when the upsert inserts a new document, so an
		// existing counter keeps its current value.
		bson.M{"$setOnInsert": bson.M{"seq": int64(0)}},
		options.Update().SetUpsert(true),
	)
	return err
}
// GetNextID increments the sequence counter for collectionName and returns the
// new value. The counter is the document with _id == collectionName in the
// "counters" collection of database dbName. It uses context.Background();
// GetNextIDContext accepts a caller-supplied context (deadline, cancellation).
func GetNextID(client *mongo.Client, dbName, collectionName string) (int64, error) {
	return GetNextIDContext(context.Background(), client, dbName, collectionName)
}

// GetNextIDContext is GetNextID with a caller-supplied context.
//
// The counter is not created on demand. If it has not been initialized (see
// InitCounter), FindOneAndUpdate matches nothing and the driver's error is
// returned unwrapped (mongo.ErrNoDocuments per the driver documentation).
func GetNextIDContext(ctx context.Context, client *mongo.Client, dbName, collectionName string) (int64, error) {
	counters := client.Database(dbName).Collection("counters")
	var result struct {
		Seq int64 `bson:"seq"`
	}
	err := counters.FindOneAndUpdate(
		ctx,
		bson.M{"_id": collectionName},
		// Increment by an int64 so the stored type matches InitCounter's int64(0).
		bson.M{"$inc": bson.M{"seq": int64(1)}},
		// Return the post-increment document so the caller receives the new value.
		options.FindOneAndUpdate().SetReturnDocument(options.After),
	).Decode(&result)
	if err != nil {
		return 0, err
	}
	return result.Seq, nil
}