package mySql import ( "errors" "fmt" "github.com/Masterminds/squirrel" "gorm.io/gorm" "reflect" "strings" "time" ) type Update struct { tableName string columns []*Column wheres []*Where db *gorm.DB } func NewUpdate(tableName string, db *gorm.DB) *Update { return &Update{tableName: tableName, db: db} } func (t *Update) MountColumn(data interface{}) *Update { v := reflect.ValueOf(data) v = v.Elem() for i := 0; i < v.NumField(); i++ { column := v.Type().Field(i).Tag.Get("column") if column != "" { countSplit := strings.SplitN(column, ":", 2) if len(countSplit) == 2 { if countSplit[1] != "key" { t.AddColumn(countSplit[0], v.Field(i).Interface()) } else { t.Where(countSplit[0], v.Field(i).Interface()) } } else { t.AddColumn(column, v.Field(i).Interface()) } } } return t } func (t *Update) AddColumn(column string, value interface{}) *Update { t.columns = append(t.columns, NewColumn(fmt.Sprintf("`%s`", column), value)) return t } func (t *Update) AddColumns(m map[string]interface{}) *Update { for s, i := range m { t.AddColumn(s, i) } return t } func (t *Update) AddColumnInc(column string, inc uint) *Update { t.columns = append(t.columns, NewColumn( fmt.Sprintf("`%s`", column), squirrel.Expr(fmt.Sprintf("`%s` + %d", column, inc)), )) return t } func (t *Update) AddColumnDec(column string, dec uint) *Update { t.columns = append(t.columns, NewColumn( fmt.Sprintf("`%s`", column), squirrel.Expr(fmt.Sprintf("`%s` - %d", column, dec)), )) return t } func (t *Update) AddCreatedColumn() *Update { t.AddColumn("created_at", time.Now()) return t } func (t *Update) AddUpdatedColumn() *Update { t.AddColumn("updated_at", time.Now()) return t } func (t *Update) AddDeletedColumn() *Update { t.AddColumn("deleted_at", time.Now()) return t } func (t *Update) WhereRaw(formula string, values ...interface{}) *Update { t.wheres = append(t.wheres, NewWhere(formula, values...)) return t } func (t *Update) Where(formula string, values ...interface{}) *Update { t.WhereRaw(fmt.Sprintf("`%s` = ?", formula), values...) return t } func (t *Update) Update() error { mapper, err := t.Mapper() if err != nil { return err } return mapper.Exec().Error } func (t *Update) toSqu() (squirrel.UpdateBuilder, error) { if len(t.wheres) == 0 || t.wheres[0].Formula == "" { return squirrel.UpdateBuilder{}, errors.New("没有where条件不被允许") } squ := squirrel.Update(fmt.Sprintf("`%s`", t.tableName)) for _, v := range t.columns { squ = squ.Set(v.name, v.value) } //拼接where for _, where := range t.wheres { squ = squ.Where(where.Formula, where.Values...) } return squ, nil } func (t *Update) Mapper() (*SqlMapper, error) { squ, err := t.toSqu() if err != nil { return nil, err } return Mapper(squ.ToSql()).setDB(t.db), nil }