light-pkg/mySql/Update.go

128 lines
2.8 KiB
Go
Raw Permalink Normal View History

2024-12-18 03:41:30 +08:00
package mySql
import (
	"errors"
	"fmt"
	"reflect"
	"sort"
	"strings"
	"time"

	"github.com/Masterminds/squirrel"
	"gorm.io/gorm"
)
// Update accumulates the pieces of a single-table SQL UPDATE statement: the
// target table, the SET assignments and the WHERE conditions. The statement is
// rendered with squirrel and run through the stored gorm handle (see the
// Update and Mapper methods).
type Update struct {
	// tableName is the target table; it is wrapped in backticks when the
	// statement is built.
	tableName string
	// columns holds the SET assignments in the order they were added.
	columns []*Column
	// wheres holds the WHERE conditions in the order they were added.
	wheres []*Where
	// db is the gorm handle attached to the generated SqlMapper.
	db *gorm.DB
}
// NewUpdate starts an UPDATE builder for tableName that will be executed on
// db. Assignments and conditions are added afterwards via the chaining methods.
func NewUpdate(tableName string, db *gorm.DB) *Update {
	u := new(Update)
	u.tableName = tableName
	u.db = db
	return u
}
// MountColumn populates the builder from the fields of a struct, or of a
// pointer to a struct, using each field's `column` struct tag:
//
//   - `column:"name"`     adds the SET assignment `name` = <field value>;
//   - `column:"name:key"` adds the WHERE condition `name` = <field value>;
//   - `column:"name:xxx"` (any suffix other than "key") is a plain SET
//     assignment on name.
//
// Fields without a `column` tag are ignored. Fields are visited in declaration
// order, so the SET clause follows the struct layout.
//
// data must be a struct or a non-nil pointer to a struct. Anything else is a
// programming error and panics with a descriptive message. Tagged fields must
// be exported, because their values are read with reflect.Value.Interface.
func (t *Update) MountColumn(data interface{}) *Update {
	// Indirect accepts both T and *T; a nil pointer or nil interface yields
	// an invalid Value, which the kind check below rejects.
	v := reflect.Indirect(reflect.ValueOf(data))
	if v.Kind() != reflect.Struct {
		panic(fmt.Sprintf("mySql.Update.MountColumn: expected a struct or pointer to struct, got %T", data))
	}
	typ := v.Type()
	for i := 0; i < v.NumField(); i++ {
		column := typ.Field(i).Tag.Get("column")
		if column == "" {
			continue
		}
		value := v.Field(i).Interface()
		parts := strings.SplitN(column, ":", 2)
		switch {
		case len(parts) == 2 && parts[1] == "key":
			// Key fields identify the row to update.
			t.Where(parts[0], value)
		case len(parts) == 2:
			t.AddColumn(parts[0], value)
		default:
			t.AddColumn(column, value)
		}
	}
	return t
}
// AddColumn appends the assignment `column` = value to the SET clause. The
// column name is wrapped in backticks, and value is forwarded unchanged to
// squirrel's Set when the statement is built.
func (t *Update) AddColumn(column string, value interface{}) *Update {
	quoted := "`" + column + "`"
	t.columns = append(t.columns, NewColumn(quoted, value))
	return t
}
// AddColumns appends one SET assignment per entry of m (see AddColumn).
//
// Go randomizes map iteration order, so the keys are sorted first. This makes
// the generated SQL text and the order of its bind arguments deterministic
// from one call to the next. A nil or empty map adds nothing.
func (t *Update) AddColumns(m map[string]interface{}) *Update {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	for _, k := range keys {
		t.AddColumn(k, m[k])
	}
	return t
}
// AddColumnInc appends `column` = `column` + inc to the SET clause, so the
// increment is computed by the database from the stored value.
func (t *Update) AddColumnInc(column string, inc uint) *Update {
	quoted := fmt.Sprintf("`%s`", column)
	expr := squirrel.Expr(fmt.Sprintf("%s + %d", quoted, inc))
	t.columns = append(t.columns, NewColumn(quoted, expr))
	return t
}
// AddColumnDec appends `column` = `column` - dec to the SET clause, so the
// decrement is computed by the database from the stored value.
func (t *Update) AddColumnDec(column string, dec uint) *Update {
	quoted := fmt.Sprintf("`%s`", column)
	expr := squirrel.Expr(fmt.Sprintf("%s - %d", quoted, dec))
	t.columns = append(t.columns, NewColumn(quoted, expr))
	return t
}
// AddCreatedColumn sets `created_at` to time.Now() (local time).
func (t *Update) AddCreatedColumn() *Update {
	return t.AddColumn("created_at", time.Now())
}
// AddUpdatedColumn sets `updated_at` to time.Now() (local time).
func (t *Update) AddUpdatedColumn() *Update {
	return t.AddColumn("updated_at", time.Now())
}
// AddDeletedColumn sets `deleted_at` to time.Now() (local time).
func (t *Update) AddDeletedColumn() *Update {
	return t.AddColumn("deleted_at", time.Now())
}
// WhereRaw adds a raw SQL condition together with its bind arguments, for
// example WhereRaw("`age` > ?", 18). Conditions are emitted in the order they
// were added.
func (t *Update) WhereRaw(formula string, values ...interface{}) *Update {
	cond := NewWhere(formula, values...)
	t.wheres = append(t.wheres, cond)
	return t
}
// Where adds the equality condition `formula` = ? bound to values. Despite
// its name, formula is a column name and is wrapped in backticks. Use
// WhereRaw for any other kind of condition.
func (t *Update) Where(formula string, values ...interface{}) *Update {
	return t.WhereRaw("`"+formula+"` = ?", values...)
}
// Update builds the statement via Mapper and executes it. It returns the
// build error, if any, otherwise the Error reported by the execution result.
func (t *Update) Update() error {
	mapper, err := t.Mapper()
	if err == nil {
		err = mapper.Exec().Error
	}
	return err
}
// toSqu converts the builder into a squirrel.UpdateBuilder.
//
// As a safety net against accidentally updating every row, it refuses to
// build a statement unless at least one WHERE condition has a non-blank
// formula. Blank conditions (e.g. WhereRaw("")) are skipped rather than handed
// to squirrel, where they could produce a malformed WHERE clause.
func (t *Update) toSqu() (squirrel.UpdateBuilder, error) {
	conds := make([]*Where, 0, len(t.wheres))
	for _, w := range t.wheres {
		if w != nil && strings.TrimSpace(w.Formula) != "" {
			conds = append(conds, w)
		}
	}
	if len(conds) == 0 {
		// Message: "updates without a where condition are not allowed".
		return squirrel.UpdateBuilder{}, errors.New("没有where条件不被允许")
	}
	squ := squirrel.Update(fmt.Sprintf("`%s`", t.tableName))
	// SET clause, in the order the columns were added.
	for _, c := range t.columns {
		squ = squ.Set(c.name, c.value)
	}
	// WHERE clause; squirrel combines successive Where calls with AND.
	for _, w := range conds {
		squ = squ.Where(w.Formula, w.Values...)
	}
	return squ, nil
}
// Mapper renders the statement to SQL plus bind arguments and wraps them in
// a SqlMapper bound to the builder's gorm handle. The SQL-rendering error from
// squirrel is passed to Mapper along with the SQL. Only the toSqu validation
// error is returned here.
func (t *Update) Mapper() (*SqlMapper, error) {
	builder, err := t.toSqu()
	if err != nil {
		return nil, err
	}
	mapper := Mapper(builder.ToSql())
	return mapper.setDB(t.db), nil
}