package mySql

import (
	"fmt"
	"reflect"
	"strings"

	"github.com/Masterminds/squirrel"
	"gorm.io/gorm"
)

// Query accumulates the parts of a SELECT statement against one table —
// columns, WHERE conditions, joins, GROUP BY, ORDER BY and either
// limit/offset or page-based paging — and renders them with squirrel. The
// rendered SQL is executed through gorm via Mapper.
//
// Column lists come in two flavours: autoColumns hold names that were
// backtick-quoted by Select/AddSelect (and are qualified with the alias when
// one is set), rawColumns hold expressions used verbatim (SelectRow /
// AddSelectRow). When isPageMode is set, page/pageSize take precedence over
// limit/offset.
//
// Builder methods mutate the receiver and return it for chaining.
type Query struct {
	tableName   string
	rawColumns  []string
	autoColumns []string
	rawWheres   []*Where
	autoWheres  []*Where
	limit       int64
	offset      int64
	groupBy     string
	orderBy     string
	join        []*Join
	isPageMode  bool
	page        int64
	pageSize    int64
	alias       string
	db          *gorm.DB
}

// NewQuery returns an empty Query for tableName bound to db. Paging defaults
// to page 1 with 30 rows per page, but only takes effect once Page or Pages
// is called.
func NewQuery(tableName string, db *gorm.DB) *Query {
	return &Query{
		tableName:   tableName,
		rawColumns:  []string{},
		autoColumns: []string{},
		rawWheres:   make([]*Where, 0),
		autoWheres:  make([]*Where, 0),
		limit:       0,
		offset:      0,
		join:        make([]*Join, 0),
		isPageMode:  false,
		page:        1,
		pageSize:    30,
		db:          db,
	}
}

// SetDB replaces the gorm handle the query is executed with.
func (t *Query) SetDB(db *gorm.DB) *Query {
	t.db = db
	return t
}

// As sets a table alias, rendered as "FROM `table` AS `alias`" and used to
// qualify unqualified auto-quoted columns.
func (t *Query) As(alias string) *Query {
	t.alias = alias
	return t
}

// Select replaces the auto-quoted column list. Every dot-separated segment
// of each name is backtick-quoted ("u.name" -> `u`.`name`); a "*" segment
// is left bare so wildcards keep working.
func (t *Query) Select(columns ...string) *Query {
	t.autoColumns = make([]string, 0, len(columns))
	return t.AddSelect(columns...)
}

// SelectRow replaces the raw column list. The expressions are used verbatim,
// without quoting or alias qualification.
func (t *Query) SelectRow(columns ...string) *Query {
	// Copy so later appends never write into the caller's backing array.
	t.rawColumns = append([]string(nil), columns...)
	return t
}

// AddSelect appends auto-quoted columns; quoting follows the same rules as
// Select.
func (t *Query) AddSelect(columns ...string) *Query {
	for _, column := range columns {
		t.autoColumns = append(t.autoColumns, quoteQueryIdent(column))
	}
	return t
}

// AddSelectRow appends raw column expressions, used verbatim.
func (t *Query) AddSelectRow(columns ...string) *Query {
	t.rawColumns = append(t.rawColumns, columns...)
	return t
}

// GroupBy sets the GROUP BY column, backtick-quoting each dot-separated
// segment ("u.id" -> `u`.`id`).
func (t *Query) GroupBy(s string) *Query {
	t.groupBy = quoteQueryIdent(s)
	return t
}

// GroupByRaw sets the GROUP BY clause verbatim.
func (t *Query) GroupByRaw(s string) *Query {
	t.groupBy = s
	return t
}

// Limit sets the LIMIT used outside page mode; values <= 0 mean no limit.
func (t *Query) Limit(s int64) *Query {
	t.limit = s
	return t
}

// Offset sets the OFFSET used outside page mode; values <= 0 mean no offset.
func (t *Query) Offset(s int64) *Query {
	t.offset = s
	return t
}

// Page switches the query into page mode and sets the 1-based page number.
// Values below 1 are treated as 1 (a negative page would otherwise become a
// huge unsigned OFFSET).
func (t *Query) Page(p int64) *Query {
	t.isPageMode = true
	if p < 1 {
		p = 1
	}
	t.page = p
	return t
}

// PageSize sets the number of rows per page; values <= 0 fall back to 30.
// It does not enable page mode by itself.
func (t *Query) PageSize(p int64) *Query {
	if p <= 0 {
		p = 30
	}
	t.pageSize = p
	return t
}

// Pages is shorthand for Page(p).PageSize(s).
func (t *Query) Pages(p int64, s int64) *Query {
	return t.Page(p).PageSize(s)
}

// MountQuery applies f to the query, letting callers package reusable
// groups of builder calls.
func (t *Query) MountQuery(f func(squ *Query)) *Query {
	f(t)
	return t
}

// NotDelete restricts the query to rows whose deleted_at column is NULL
// (via WhereIsNull).
func (t *Query) NotDelete() *Query {
	return t.WhereIsNull("deleted_at")
}

// quoteQueryIdent backtick-quotes every dot-separated segment of ident, e.g.
// "u.name" -> "`u`.`name`". A "*" segment is left unquoted so that "*" and
// "u.*" remain valid wildcards.
func quoteQueryIdent(ident string) string {
	parts := strings.Split(ident, ".")
	for i, part := range parts {
		if part != "*" {
			parts[i] = "`" + part + "`"
		}
	}
	return strings.Join(parts, ".")
}

// autoColumnsToRaw returns the auto-quoted columns, qualifying the
// unqualified ones with the table alias when one is set.
func (t *Query) autoColumnsToRaw() []string {
	columns := make([]string, 0, len(t.autoColumns))
	for _, column := range t.autoColumns {
		// A dot means the column is already table-qualified.
		if t.alias != "" && !strings.Contains(column, ".") {
			column = fmt.Sprintf("`%s`.%s", t.alias, column)
		}
		columns = append(columns, column)
	}
	return columns
}

// getColumns returns the auto columns followed by the raw columns, or "*"
// when none were selected.
func (t *Query) getColumns() []string {
	all := append(t.autoColumnsToRaw(), t.rawColumns...)
	if len(all) == 0 {
		return []string{"*"}
	}
	return all
}

// getFrom renders the FROM target, including the alias when set.
func (t *Query) getFrom() string {
	if t.alias != "" {
		return fmt.Sprintf("`%s` AS `%s`", t.tableName, t.alias)
	}
	return fmt.Sprintf("`%s`", t.tableName)
}

// getColName quotes a single column name, qualified with the alias when set.
func (t *Query) getColName(colName string) string {
	if t.alias == "" {
		return fmt.Sprintf("`%s`", colName)
	}
	return fmt.Sprintf("`%s`.`%s`", t.alias, colName)
}

// mountWhere adds every stored condition to squ: raw conditions first, then
// the processed (auto) ones.
func (t *Query) mountWhere(squ squirrel.SelectBuilder) squirrel.SelectBuilder {
	for _, where := range t.rawWheres {
		squ = squ.Where(where.Formula, where.Values...)
	}
	for _, where := range t.autoWheres {
		squ = squ.Where(where.Formula, where.Values...)
	}
	return squ
}

// mountOther adds paging (page mode, or limit/offset), GROUP BY and
// ORDER BY to squ.
func (t *Query) mountOther(squ squirrel.SelectBuilder) squirrel.SelectBuilder {
	if t.isPageMode {
		squ = squ.Offset(uint64(t.pageSize * (t.page - 1))).Limit(uint64(t.pageSize))
	} else {
		if t.limit > 0 {
			squ = squ.Limit(uint64(t.limit))
		}
		if t.offset > 0 {
			squ = squ.Offset(uint64(t.offset))
		}
	}
	if t.groupBy != "" {
		squ = squ.GroupBy(t.groupBy)
	}
	if t.orderBy != "" {
		squ = squ.OrderBy(t.orderBy)
	}
	return squ
}

// mountJoin adds every stored join to squ. "left" and "right" map to
// LEFT/RIGHT JOIN; "inner" and any other option render as a plain JOIN,
// which MySQL treats as an INNER JOIN.
func (t *Query) mountJoin(squ squirrel.SelectBuilder) squirrel.SelectBuilder {
	for _, join := range t.join {
		var joinFn func(string, ...interface{}) squirrel.SelectBuilder
		switch join.option {
		case "left":
			joinFn = squ.LeftJoin
		case "right":
			joinFn = squ.RightJoin
		default:
			joinFn = squ.Join
		}
		// NOTE(review): join.rest is forwarded as ONE argument, as before. If
		// it is a []interface{} of placeholder values it should probably be
		// spread (join.rest...) — confirm against the Join type definition.
		if len(join.rest) == 0 {
			squ = joinFn(join.join)
		} else {
			squ = joinFn(join.join, join.rest)
		}
	}
	return squ
}

// ToSquirrel renders the full query (columns, FROM, WHERE, paging,
// GROUP BY, ORDER BY and joins) as a squirrel builder.
func (t *Query) ToSquirrel() squirrel.SelectBuilder {
	squ := squirrel.Select(t.getColumns()...).From(t.getFrom())
	squ = t.mountWhere(squ)
	squ = t.mountOther(squ)
	return t.mountJoin(squ)
}

// ToGormQuery renders the full query and returns it as a gorm query bound
// to the query's DB.
func (t *Query) ToGormQuery() *gorm.DB {
	return t.toGormQuery(t.ToSquirrel())
}

// ToSql renders the full query to SQL text and its placeholder arguments.
func (t *Query) ToSql() (string, []interface{}, error) {
	return t.ToSquirrel().ToSql()
}

// toGormQuery renders squ and hands the SQL and arguments to Mapper, bound
// to t.db.
func (t *Query) toGormQuery(squ squirrel.SelectBuilder) *gorm.DB {
	return Mapper(squ.ToSql()).setDB(t.db).Query()
}

// Find loads all matching rows into list (a pointer to a slice).
func (t *Query) Find(list interface{}) error {
	return t.ToGormQuery().Find(list).Error
}

// Take loads a single row into info. It sets the query's limit to 1, which
// mutates the Query; in page mode the page size is used instead.
func (t *Query) Take(info interface{}) error {
	t.limit = 1
	return t.ToGormQuery().Take(info).Error
}

// aggregateSelect builds "SELECT expr FROM table" with the query's joins and
// WHERE conditions, but without paging, grouping or ordering. Joins are
// included so that conditions on joined tables stay valid and the aggregate
// covers the same rows Find would return.
func (t *Query) aggregateSelect(expr string) squirrel.SelectBuilder {
	squ := squirrel.Select(expr).From(t.getFrom())
	squ = t.mountJoin(squ)
	return t.mountWhere(squ)
}

// GetTotal returns COUNT(*) of the matching rows, ignoring paging, grouping
// and ordering. Errors are swallowed to keep the historical signature; the
// result is 0 on failure.
func (t *Query) GetTotal() int64 {
	var count int64
	t.toGormQuery(t.aggregateSelect("count(*)")).Scan(&count)
	return count
}

// GetRowCount counts with the query's GROUP BY applied, via gorm's Count.
// Errors are swallowed; the result is 0 on failure.
func (t *Query) GetRowCount() int64 {
	squ := t.aggregateSelect("count(*)")
	if t.groupBy != "" {
		squ = squ.GroupBy(t.groupBy)
	}
	var count int64
	t.toGormQuery(squ).Count(&count)
	return count
}

// GetSum returns SUM(column) over the matching rows as an int64. Errors are
// swallowed; the result is 0 on failure.
func (t *Query) GetSum(column string) int64 {
	// SUM over zero rows yields NULL, which cannot be scanned into an int64.
	squ := t.aggregateSelect(fmt.Sprintf("COALESCE(SUM(%s), 0)", quoteQueryIdent(column)))
	var sum int64
	t.toGormQuery(squ).Scan(&sum)
	return sum
}

// GetSumInterface scans SUM(column) over the matching rows into value, which
// must be a pointer (e.g. *float64). Errors are swallowed.
func (t *Query) GetSumInterface(column string, value interface{}) {
	squ := t.aggregateSelect(fmt.Sprintf("SUM(%s)", quoteQueryIdent(column)))
	// Scan into the caller's pointer itself, not the address of this
	// function's local interface copy.
	t.toGormQuery(squ).Scan(value)
}

// PageFind loads the current page into list and returns paging metadata.
// The returned error is Find's error; the total is computed even when Find
// fails.
func (t *Query) PageFind(list interface{}) (*PageStruct, error) {
	findErr := t.Find(list)
	total := t.GetTotal()
	p := &PageStruct{
		Total:     total,
		TotalPage: 0,
		Page:      t.page,
		PageSize:  t.pageSize,
	}
	// Exact integer ceiling; the previous float round-trip through "%.5f"
	// under-counted when the remainder was tiny relative to the page size.
	if t.pageSize > 0 {
		p.TotalPage = total / t.pageSize
		if total%t.pageSize != 0 {
			p.TotalPage++
		}
	}
	return p, findErr
}

// MountWhereForReflect adds conditions from the `where` tags of data's
// non-blank fields. data may be a struct or a pointer to one; anything else
// (including a nil pointer) is ignored. Unexported fields are skipped.
//
// Tag forms:
//
//	where:"col"            -> col = value
//	where:"col:like"       -> col LIKE value
//	where:"col:like_both"  -> col LIKE %value%
//	where:"col:like_left"  -> col LIKE %value
//	where:"col:like_right" -> col LIKE value%
//	where:"col:bool"       -> col = 1 for "yes", col = 0 for "no"
//
// The operation part is matched case-insensitively.
func (t *Query) MountWhereForReflect(data interface{}) *Query {
	v := reflect.Indirect(reflect.ValueOf(data))
	if v.Kind() != reflect.Struct {
		return t
	}
	typ := v.Type()
	for i := 0; i < v.NumField(); i++ {
		tag := typ.Field(i).Tag.Get("where")
		field := v.Field(i)
		// Interface() panics on unexported fields, so they are skipped.
		if tag == "" || !field.CanInterface() || isBlank(field) {
			continue
		}
		t.mountReflectWhere(tag, field.Interface())
	}
	return t
}

// mountReflectWhere adds the condition described by one `where` tag for val.
// NOTE(review): LIKE wildcards (% and _) inside val are not escaped.
func (t *Query) mountReflectWhere(tag string, val interface{}) {
	split := NewStringSplit(tag, ":")
	// One segment: plain equality.
	split.RunCount1Func(func(colName string) {
		t.WhereRaw(fmt.Sprintf("%s = ?", t.getColName(colName)), val)
	})
	// Two segments: column plus an operation such as LIKE_BOTH, LIKE or BOOL.
	split.RunCount2Func(func(colName, operation string) {
		col := t.getColName(colName)
		opSplit := NewStringSplit(strings.ToUpper(operation), "_")
		opSplit.RunCount1Func(func(op string) {
			switch op {
			case "LIKE":
				t.WhereRaw(fmt.Sprintf("%s LIKE ?", col), val)
			case "BOOL":
				switch val {
				case "yes":
					t.WhereRaw(fmt.Sprintf("%s = ?", col), 1)
				case "no":
					t.WhereRaw(fmt.Sprintf("%s = ?", col), 0)
				}
			}
		})
		// e.g. LIKE_BOTH: op is "LIKE", side is "BOTH".
		opSplit.RunCount2Func(func(op, side string) {
			if op != "LIKE" {
				return
			}
			var pattern string
			switch side {
			case "BOTH":
				pattern = fmt.Sprintf("%%%v%%", val)
			case "LEFT":
				pattern = fmt.Sprintf("%%%v", val)
			case "RIGHT":
				pattern = fmt.Sprintf("%v%%", val)
			default:
				return
			}
			t.WhereRaw(fmt.Sprintf("%s LIKE ?", col), pattern)
		})
	})
}

// SelectAllForReflect adds a column for every field of data that carries a
// `column` tag, using the tag's first ":"-separated segment as the column
// name. data may be a struct or a pointer to one; anything else is ignored.
func (t *Query) SelectAllForReflect(data interface{}) *Query {
	v := reflect.Indirect(reflect.ValueOf(data))
	if v.Kind() != reflect.Struct {
		return t
	}
	typ := v.Type()
	for i := 0; i < typ.NumField(); i++ {
		tag := typ.Field(i).Tag.Get("column")
		if tag == "" {
			continue
		}
		NewStringSplit(tag, ":").RunIndexFunc(0, func(col string) {
			t.AddSelect(col)
		})
	}
	return t
}