128 lines
2.8 KiB
Go
128 lines
2.8 KiB
Go
package mySql
|
|
|
|
import (
	"errors"
	"fmt"
	"reflect"
	"sort"
	"strings"
	"time"

	"github.com/Masterminds/squirrel"
	"gorm.io/gorm"
)
|
|
|
|
type Update struct {
|
|
tableName string
|
|
columns []*Column
|
|
wheres []*Where
|
|
db *gorm.DB
|
|
}
|
|
|
|
func NewUpdate(tableName string, db *gorm.DB) *Update {
|
|
return &Update{tableName: tableName, db: db}
|
|
}
|
|
|
|
func (t *Update) MountColumn(data interface{}) *Update {
|
|
v := reflect.ValueOf(data)
|
|
v = v.Elem()
|
|
for i := 0; i < v.NumField(); i++ {
|
|
column := v.Type().Field(i).Tag.Get("column")
|
|
if column != "" {
|
|
countSplit := strings.SplitN(column, ":", 2)
|
|
if len(countSplit) == 2 {
|
|
if countSplit[1] != "key" {
|
|
t.AddColumn(countSplit[0], v.Field(i).Interface())
|
|
} else {
|
|
t.Where(countSplit[0], v.Field(i).Interface())
|
|
}
|
|
} else {
|
|
t.AddColumn(column, v.Field(i).Interface())
|
|
}
|
|
}
|
|
}
|
|
return t
|
|
}
|
|
|
|
func (t *Update) AddColumn(column string, value interface{}) *Update {
|
|
t.columns = append(t.columns, NewColumn(fmt.Sprintf("`%s`", column), value))
|
|
return t
|
|
}
|
|
|
|
func (t *Update) AddColumns(m map[string]interface{}) *Update {
|
|
for s, i := range m {
|
|
t.AddColumn(s, i)
|
|
}
|
|
return t
|
|
}
|
|
|
|
func (t *Update) AddColumnInc(column string, inc uint) *Update {
|
|
t.columns = append(t.columns, NewColumn(
|
|
fmt.Sprintf("`%s`", column),
|
|
squirrel.Expr(fmt.Sprintf("`%s` + %d", column, inc)),
|
|
))
|
|
return t
|
|
}
|
|
|
|
func (t *Update) AddColumnDec(column string, dec uint) *Update {
|
|
t.columns = append(t.columns, NewColumn(
|
|
fmt.Sprintf("`%s`", column),
|
|
squirrel.Expr(fmt.Sprintf("`%s` - %d", column, dec)),
|
|
))
|
|
return t
|
|
}
|
|
|
|
func (t *Update) AddCreatedColumn() *Update {
|
|
t.AddColumn("created_at", time.Now())
|
|
return t
|
|
}
|
|
|
|
func (t *Update) AddUpdatedColumn() *Update {
|
|
t.AddColumn("updated_at", time.Now())
|
|
return t
|
|
}
|
|
func (t *Update) AddDeletedColumn() *Update {
|
|
t.AddColumn("deleted_at", time.Now())
|
|
return t
|
|
}
|
|
func (t *Update) WhereRaw(formula string, values ...interface{}) *Update {
|
|
t.wheres = append(t.wheres, NewWhere(formula, values...))
|
|
return t
|
|
}
|
|
|
|
func (t *Update) Where(formula string, values ...interface{}) *Update {
|
|
t.WhereRaw(fmt.Sprintf("`%s` = ?", formula), values...)
|
|
return t
|
|
}
|
|
|
|
func (t *Update) Update() error {
|
|
mapper, err := t.Mapper()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return mapper.Exec().Error
|
|
}
|
|
|
|
func (t *Update) toSqu() (squirrel.UpdateBuilder, error) {
|
|
if len(t.wheres) == 0 || t.wheres[0].Formula == "" {
|
|
return squirrel.UpdateBuilder{}, errors.New("没有where条件不被允许")
|
|
}
|
|
|
|
squ := squirrel.Update(fmt.Sprintf("`%s`", t.tableName))
|
|
for _, v := range t.columns {
|
|
squ = squ.Set(v.name, v.value)
|
|
}
|
|
|
|
//拼接where
|
|
for _, where := range t.wheres {
|
|
squ = squ.Where(where.Formula, where.Values...)
|
|
}
|
|
return squ, nil
|
|
}
|
|
|
|
func (t *Update) Mapper() (*SqlMapper, error) {
|
|
squ, err := t.toSqu()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return Mapper(squ.ToSql()).setDB(t.db), nil
|
|
}
|