mirror of https://gitee.com/godoos/godoos.git
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
163 lines
4.3 KiB
163 lines
4.3 KiB
// adapter.go
|
|
package mongodm
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"fmt"
|
|
"time"
|
|
|
|
"go.mongodb.org/mongo-driver/bson"
|
|
"go.mongodb.org/mongo-driver/mongo"
|
|
"go.mongodb.org/mongo-driver/mongo/options"
|
|
)
|
|
|
|
type mongoConnPool struct {
|
|
db *mongo.Database
|
|
}
|
|
// ConnPoolStats is the value type returned by (*mongoConnPool).Stats.
// NOTE(review): nothing in this file populates the fields, so their exact
// meaning (connection counts, wait statistics, idle limits) is inferred from
// the names only — confirm before relying on them.
type ConnPoolStats struct {
	TotalConn, InUse, Idle int
	WaitCount              int64
	WaitTime, MaxIdleTime  time.Duration
}
|
|
// Counter has the same field layout (_id, seq) as the documents that
// InitCounter and GetNextID keep in the "counters" collection.
type Counter struct {
	// ID is stored as the document's _id and holds the collection name.
	ID string `bson:"_id"`
	// Seq is the current sequence value.
	Seq int64 `bson:"seq"`
}
|
|
|
|
func (p *mongoConnPool) GetDB() *mongo.Database {
|
|
return p.db
|
|
}
|
|
func (p *mongoConnPool) Collection(name string) *mongo.Collection {
|
|
return p.db.Collection(name)
|
|
}
|
|
func (p *mongoConnPool) Conn(ctx context.Context) (driver.Conn, error) {
|
|
return &mongoConn{}, nil
|
|
}
|
|
func (p *mongoConnPool) PrepareContext(ctx context.Context, sql string) (*sql.Stmt, error) {
|
|
return nil, nil // MongoDB 不使用 Prepare/Stmt,直接返回 nil
|
|
}
|
|
|
|
// 在 mongoConn 结构体后添加以下方法即可修复编译错误
|
|
|
|
func (c *mongoConn) Prepare(query string) (driver.Stmt, error) {
|
|
return nil, fmt.Errorf("Prepare not implemented")
|
|
}
|
|
|
|
func (p *mongoConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
|
// MongoDB 不支持原生 SQL,这里可以返回 nil 和 nil 表示忽略执行
|
|
// 或者根据业务需求做具体处理(如日志记录、错误提示等)
|
|
return nil, nil
|
|
}
|
|
|
|
func (p *mongoConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
|
return &sql.Rows{}, nil
|
|
}
|
|
|
|
func (p *mongoConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
|
return nil
|
|
}
|
|
|
|
type mongoTx struct {
|
|
session mongo.Session // 现在 session 是 *mongo.Session 类型
|
|
}
|
|
|
|
func (p *mongoConnPool) BeginTx(ctx context.Context, opts interface{}) (interface{}, error) {
|
|
session, err := p.db.Client().StartSession()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
session.StartTransaction()
|
|
// 修改此处:取 session 的地址,使其成为 *mongo.Session 类型
|
|
return &mongoTx{session: session}, nil
|
|
}
|
|
func (t *mongoTx) Commit() error {
|
|
return t.session.CommitTransaction(context.Background())
|
|
}
|
|
|
|
func (t *mongoTx) Rollback() error {
|
|
return t.session.AbortTransaction(context.Background())
|
|
}
|
|
func (p *mongoConnPool) Ping(ctx context.Context) error {
|
|
return p.db.Client().Ping(ctx, nil)
|
|
}
|
|
|
|
func (p *mongoConnPool) Stats() ConnPoolStats {
|
|
return ConnPoolStats{}
|
|
}
|
|
// Close is a no-op; it does not disconnect the underlying client.
func (p *mongoConnPool) Close() {
	// Intentionally empty.
}
|
|
|
|
// mongoConn is an empty placeholder connection. It carries no state; its
// methods exist so that a *mongoConn can be handed out as a driver.Conn.
type mongoConn struct{}

// PrepareContext always fails: prepared statements are not implemented. ctx
// and query are ignored.
func (c *mongoConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
	err := fmt.Errorf("PrepareContext not implemented")
	return nil, err
}

// Close does nothing and always reports success.
func (c *mongoConn) Close() error {
	var err error
	return err
}

// Begin always fails: transactions are not implemented on this connection.
func (c *mongoConn) Begin() (driver.Tx, error) {
	err := fmt.Errorf("Begin not implemented")
	return nil, err
}
|
|
|
|
// CommitTx 提交事务
|
|
func (p *mongoConnPool) CommitTx(ctx context.Context) error {
|
|
session, ok := contextFromDB(ctx)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
defer session.EndSession(ctx)
|
|
return session.CommitTransaction(ctx)
|
|
}
|
|
|
|
// RollbackTx 回滚事务
|
|
func (p *mongoConnPool) RollbackTx(ctx context.Context) error {
|
|
session, ok := contextFromDB(ctx)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
defer session.EndSession(ctx)
|
|
return session.AbortTransaction(ctx)
|
|
}
|
|
|
|
// 从上下文中获取 session(需配合 GORM 上下文管理)
|
|
func contextFromDB(ctx context.Context) (mongo.Session, bool) {
|
|
value := ctx.Value("mongo:session")
|
|
if value == nil {
|
|
return nil, false
|
|
}
|
|
session, ok := value.(mongo.Session)
|
|
return session, ok
|
|
}
|
|
func InitCounter(client *mongo.Client, dbName, collectionName string) error {
|
|
counterColl := client.Database(dbName).Collection("counters")
|
|
_, err := counterColl.InsertOne(context.Background(), bson.M{
|
|
"_id": collectionName,
|
|
"seq": int64(0),
|
|
})
|
|
return err
|
|
}
|
|
func GetNextID(client *mongo.Client, dbName, collectionName string) (int64, error) {
|
|
counterColl := client.Database(dbName).Collection("counters")
|
|
|
|
filter := bson.M{"_id": collectionName}
|
|
update := bson.M{"$inc": bson.M{"seq": 1}}
|
|
opts := options.FindOneAndUpdate().SetReturnDocument(options.After)
|
|
|
|
var result struct {
|
|
ID string `bson:"_id"`
|
|
Seq int64 `bson:"seq"`
|
|
}
|
|
|
|
err := counterColl.FindOneAndUpdate(context.Background(), filter, update, opts).Decode(&result)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return result.Seq, nil
|
|
}
|
|
|