package mySql import ( "errors" "fmt" "gorm.io/gorm" ) type SqlMapper struct { Sql string Args []interface{} db *gorm.DB } func (t *SqlMapper) setDB(db *gorm.DB) *SqlMapper { t.db = db return t } // 查询 func (t *SqlMapper) Query() *gorm.DB { return t.db.Raw(t.Sql, t.Args...) } func (t *SqlMapper) Exec() *gorm.DB { return t.db.Exec(t.Sql, t.Args...) } func NewSqlMapper(sql string, args []interface{}) *SqlMapper { return &SqlMapper{Sql: sql, Args: args} } func Mapper(sql string, args []interface{}, err error) *SqlMapper { if err != nil { panic(err.Error()) } return NewSqlMapper(sql, args) } type SqlMappers []*SqlMapper func Mappers(sqlMappers ...*SqlMapper) (list SqlMappers) { list = sqlMappers return } func (t SqlMappers) apply(tx *gorm.DB) { for _, sql := range t { sql.setDB(tx) } } func (t SqlMappers) Exec(f func() error) error { if len(t) == 0 { return errors.New("无Mapper") } //其实是以第一个为准 return t[0].db.Transaction(func(tx *gorm.DB) error { t.apply(tx) return f() }) } func (t SqlMappers) ExecTransaction() error { //其实是以第一个为准 return t[0].db.Transaction(func(tx *gorm.DB) error { fmt.Println("事务开始") for _, sql := range t { sql.setDB(tx) err := sql.Exec().Error if err != nil { fmt.Println("事务结束(失败)") return err } } fmt.Println("事务成功") return nil }) }