421 lines
9.3 KiB
Go
421 lines
9.3 KiB
Go
package mySql
|
||
|
||
import (
|
||
"fmt"
|
||
"github.com/Masterminds/squirrel"
|
||
"gorm.io/gorm"
|
||
"math"
|
||
"reflect"
|
||
"strconv"
|
||
"strings"
|
||
)
|
||
|
||
type Query struct {
|
||
tableName string
|
||
rawColumns []string
|
||
autoColumns []string
|
||
rawWheres []*Where
|
||
autoWheres []*Where
|
||
limit int64
|
||
offset int64
|
||
groupBy string
|
||
orderBy string
|
||
join []*Join
|
||
isPageMode bool
|
||
page int64
|
||
pageSize int64
|
||
alias string
|
||
db *gorm.DB
|
||
}
|
||
|
||
func NewQuery(tableName string, db *gorm.DB) *Query {
|
||
return &Query{
|
||
tableName: tableName,
|
||
rawColumns: []string{},
|
||
autoColumns: []string{},
|
||
rawWheres: make([]*Where, 0),
|
||
autoWheres: make([]*Where, 0),
|
||
limit: 0,
|
||
offset: 0,
|
||
join: make([]*Join, 0),
|
||
isPageMode: false,
|
||
page: 1,
|
||
pageSize: 30,
|
||
db: db,
|
||
}
|
||
}
|
||
|
||
func (t *Query) SetDB(db *gorm.DB) *Query {
|
||
t.db = db
|
||
return t
|
||
}
|
||
|
||
func (t *Query) As(alias string) *Query {
|
||
t.alias = alias
|
||
return t
|
||
}
|
||
|
||
func (t *Query) Select(columns ...string) *Query {
|
||
t.autoColumns = []string{}
|
||
for _, column := range columns {
|
||
t.autoColumns = append(t.autoColumns, fmt.Sprintf("`%s`", column))
|
||
}
|
||
return t
|
||
}
|
||
|
||
func (t *Query) SelectRow(columns ...string) *Query {
|
||
t.rawColumns = columns
|
||
return t
|
||
}
|
||
|
||
func (t *Query) AddSelect(columns ...string) *Query {
|
||
for _, v := range columns {
|
||
t.autoColumns = append(t.autoColumns, fmt.Sprintf("`%s`", v))
|
||
}
|
||
return t
|
||
}
|
||
func (t *Query) AddSelectRow(columns ...string) *Query {
|
||
for _, v := range columns {
|
||
t.rawColumns = append(t.rawColumns, v)
|
||
}
|
||
return t
|
||
}
|
||
|
||
func (t *Query) GroupBy(s string) *Query {
|
||
if find := strings.Contains(s, "."); find {
|
||
countSplit := strings.SplitN(s, ".", 2)
|
||
if len(countSplit) == 2 {
|
||
t.groupBy = fmt.Sprintf("`%s`.`%s`", countSplit[0], countSplit[1])
|
||
} else {
|
||
t.groupBy = fmt.Sprintf("`%s`", s)
|
||
}
|
||
} else {
|
||
t.groupBy = fmt.Sprintf("`%s`", s)
|
||
}
|
||
return t
|
||
}
|
||
func (t *Query) GroupByRaw(s string) *Query {
|
||
t.groupBy = s
|
||
return t
|
||
}
|
||
|
||
func (t *Query) Limit(s int64) *Query {
|
||
t.limit = s
|
||
return t
|
||
}
|
||
func (t *Query) Offset(s int64) *Query {
|
||
t.offset = s
|
||
return t
|
||
}
|
||
|
||
func (t *Query) Page(p int64) *Query {
|
||
t.isPageMode = true
|
||
if p == 0 {
|
||
p = 1
|
||
}
|
||
t.page = p
|
||
return t
|
||
}
|
||
|
||
func (t *Query) PageSize(p int64) *Query {
|
||
if p == 0 {
|
||
p = 30
|
||
}
|
||
t.pageSize = p
|
||
return t
|
||
}
|
||
|
||
func (t *Query) Pages(p int64, s int64) *Query {
|
||
return t.Page(p).PageSize(s)
|
||
}
|
||
|
||
func (t *Query) MountQuery(f func(squ *Query)) *Query {
|
||
f(t)
|
||
return t
|
||
}
|
||
|
||
func (t *Query) NotDelete() *Query {
|
||
return t.WhereIsNull("deleted_at")
|
||
}
|
||
|
||
func (t *Query) autoColumnsToRaw() []string {
|
||
newColumns := make([]string, 0)
|
||
for _, v := range t.autoColumns {
|
||
if find := strings.Contains(v, "."); !find {
|
||
if t.alias != "" {
|
||
newColumns = append(newColumns, fmt.Sprintf("`%s`.%s", t.alias, v))
|
||
} else {
|
||
newColumns = append(newColumns, v)
|
||
}
|
||
} else {
|
||
newColumns = append(newColumns, v)
|
||
}
|
||
}
|
||
return newColumns
|
||
}
|
||
|
||
func (t *Query) getColumns() []string {
|
||
all := append(t.autoColumnsToRaw(), t.rawColumns...)
|
||
if len(all) == 0 {
|
||
return []string{"*"}
|
||
}
|
||
return all
|
||
}
|
||
|
||
func (t *Query) getFrom() string {
|
||
if t.alias != "" {
|
||
return fmt.Sprintf("`%s` AS `%s`", t.tableName, t.alias)
|
||
}
|
||
return fmt.Sprintf("`%s`", t.tableName)
|
||
}
|
||
|
||
func (t *Query) getColName(colName string) string {
|
||
if t.alias != "" {
|
||
return fmt.Sprintf("`%s`.`%s`", t.alias, colName)
|
||
} else {
|
||
return fmt.Sprintf("`%s`", colName)
|
||
}
|
||
}
|
||
|
||
func (t *Query) mountWhere(squ squirrel.SelectBuilder) squirrel.SelectBuilder {
|
||
//先挂载原生数据
|
||
for _, where := range t.rawWheres {
|
||
squ = squ.Where(where.Formula, where.Values...)
|
||
}
|
||
|
||
//挂载需要加工的数据
|
||
for _, where := range t.autoWheres {
|
||
squ = squ.Where(where.Formula, where.Values...)
|
||
}
|
||
|
||
return squ
|
||
}
|
||
|
||
func (t *Query) mountOther(squ squirrel.SelectBuilder) squirrel.SelectBuilder {
|
||
if t.isPageMode {
|
||
squ = squ.Offset(uint64(t.pageSize * (t.page - 1))).Limit(uint64(t.pageSize))
|
||
} else {
|
||
if t.limit > 0 {
|
||
squ = squ.Limit(uint64(t.limit))
|
||
}
|
||
if t.offset > 0 {
|
||
squ = squ.Offset(uint64(t.limit))
|
||
}
|
||
}
|
||
|
||
if t.groupBy != "" {
|
||
squ = squ.GroupBy(t.groupBy)
|
||
}
|
||
if t.orderBy != "" {
|
||
squ = squ.OrderBy(t.orderBy)
|
||
}
|
||
|
||
return squ
|
||
}
|
||
func (t *Query) mountJoin(squ squirrel.SelectBuilder) squirrel.SelectBuilder {
|
||
for _, join := range t.join {
|
||
if join.option == "left" {
|
||
if len(join.rest) == 0 {
|
||
squ = squ.LeftJoin(join.join)
|
||
} else {
|
||
squ = squ.LeftJoin(join.join, join.rest)
|
||
}
|
||
} else if join.option == "right" {
|
||
if len(join.rest) == 0 {
|
||
squ = squ.RightJoin(join.join)
|
||
} else {
|
||
squ = squ.RightJoin(join.join, join.rest)
|
||
}
|
||
} else if join.option == "inner" {
|
||
if len(join.rest) == 0 {
|
||
squ = squ.RightJoin(join.join)
|
||
} else {
|
||
squ = squ.RightJoin(join.join, join.rest)
|
||
}
|
||
} else {
|
||
if len(join.rest) == 0 {
|
||
squ = squ.RightJoin(join.join)
|
||
} else {
|
||
squ = squ.RightJoin(join.join, join.rest)
|
||
}
|
||
}
|
||
}
|
||
return squ
|
||
}
|
||
|
||
func (t *Query) ToSquirrel() squirrel.SelectBuilder {
|
||
//获取select
|
||
squ := squirrel.Select(t.getColumns()...).From(t.getFrom())
|
||
//组装where
|
||
squ = t.mountWhere(squ)
|
||
|
||
//组装limit和offset,groupBy和orderBy
|
||
squ = t.mountOther(squ)
|
||
|
||
//组装join
|
||
squ = t.mountJoin(squ)
|
||
return squ
|
||
}
|
||
|
||
func (t *Query) ToGormQuery() *gorm.DB {
|
||
return t.toGormQuery(t.ToSquirrel())
|
||
}
|
||
|
||
func (t *Query) ToSql() (string, []interface{}, error) {
|
||
return t.ToSquirrel().ToSql()
|
||
}
|
||
|
||
func (t *Query) toGormQuery(squ squirrel.SelectBuilder) *gorm.DB {
|
||
return Mapper(squ.ToSql()).setDB(t.db).Query()
|
||
}
|
||
|
||
func (t *Query) Find(list interface{}) error {
|
||
return t.ToGormQuery().Find(list).Error
|
||
}
|
||
|
||
func (t *Query) Take(info interface{}) error {
|
||
t.limit = 1
|
||
return t.ToGormQuery().Take(info).Error
|
||
}
|
||
|
||
func (t *Query) GetTotal() int64 {
|
||
//只组装where
|
||
squ := squirrel.Select("count(*)").From(t.getFrom())
|
||
//组装where
|
||
squ = t.mountWhere(squ)
|
||
|
||
var count int64
|
||
t.toGormQuery(squ).Scan(&count)
|
||
return count
|
||
}
|
||
|
||
func (t *Query) GetRowCount() int64 {
|
||
//只组装where
|
||
squ := squirrel.Select("count(*)").From(t.getFrom())
|
||
//组装where
|
||
squ = t.mountWhere(squ)
|
||
//组装groupBy
|
||
if t.groupBy != "" {
|
||
squ = squ.GroupBy(t.groupBy)
|
||
}
|
||
var count int64
|
||
t.toGormQuery(squ).Count(&count)
|
||
return count
|
||
}
|
||
|
||
func (t *Query) GetSum(column string) int64 {
|
||
//只组装where
|
||
squ := squirrel.Select(fmt.Sprintf("SUM(`%s`)", column)).From(t.getFrom())
|
||
//组装where
|
||
squ = t.mountWhere(squ)
|
||
var sum int64
|
||
t.toGormQuery(squ).Scan(&sum)
|
||
return sum
|
||
}
|
||
|
||
func (t *Query) GetSumInterface(column string, value interface{}) {
|
||
//只组装where
|
||
squ := squirrel.Select(fmt.Sprintf("SUM(`%s`)", column)).From(t.getFrom())
|
||
//组装where
|
||
squ = t.mountWhere(squ)
|
||
t.toGormQuery(squ).Scan(&value)
|
||
}
|
||
|
||
func (t *Query) PageFind(list interface{}) (*PageStruct, error) {
|
||
err1 := t.Find(list)
|
||
total := t.GetTotal()
|
||
p := &PageStruct{
|
||
Total: total,
|
||
TotalPage: 0,
|
||
Page: t.page,
|
||
PageSize: t.pageSize,
|
||
}
|
||
|
||
totalPage, err := strconv.ParseFloat(fmt.Sprintf("%.5f", float64(total)/float64(t.pageSize)), 64)
|
||
if err == nil {
|
||
p.TotalPage = int64(math.Ceil(totalPage))
|
||
} else {
|
||
p.TotalPage = (total / p.PageSize) + 1
|
||
}
|
||
|
||
return p, err1
|
||
}
|
||
|
||
func (t *Query) MountWhereForReflect(data interface{}) *Query {
|
||
v := reflect.ValueOf(data)
|
||
v = v.Elem()
|
||
for i := 0; i < v.NumField(); i++ {
|
||
where := v.Type().Field(i).Tag.Get("where")
|
||
if where != "" && !isBlank(v.Field(i)) {
|
||
split := NewStringSplit(where, ":")
|
||
//如果是1段
|
||
split.RunCount1Func(func(whereColName string) {
|
||
t.WhereRaw(fmt.Sprintf("%s = ?", t.getColName(whereColName)), v.Field(i).Interface())
|
||
})
|
||
|
||
//如果是2段
|
||
split.RunCount2Func(func(whereColName, whereOperation string) {
|
||
//str2的值LIKE_BOTH、LIKE、BOOL等
|
||
whereOperationSplit := NewStringSplit(strings.ToUpper(whereOperation), "_")
|
||
whereOperationSplit.RunCount1Func(func(operation string) {
|
||
//这里的operation是LIKE_BOTH的LIKE
|
||
if operation == "LIKE" {
|
||
t.WhereRaw(
|
||
fmt.Sprintf("%s LIKE ?", t.getColName(whereColName)),
|
||
v.Field(i).Interface(),
|
||
)
|
||
}
|
||
if operation == "BOOL" {
|
||
if v.Field(i).Interface() == "yes" {
|
||
t.WhereRaw(fmt.Sprintf("%s = ?", t.getColName(whereColName)), 1)
|
||
}
|
||
if v.Field(i).Interface() == "no" {
|
||
t.WhereRaw(fmt.Sprintf("%s = ?", t.getColName(whereColName)), 0)
|
||
}
|
||
}
|
||
})
|
||
whereOperationSplit.RunCount2Func(func(operation, operation2 string) {
|
||
//这里的operation是LIKE_BOTH的LIKE,operation2是BOTH
|
||
if operation == "LIKE" {
|
||
if operation2 == "BOTH" {
|
||
t.WhereRaw(
|
||
fmt.Sprintf("%s LIKE ?", t.getColName(whereColName)),
|
||
fmt.Sprintf("%s%s%s", "%", v.Field(i).Interface(), "%"),
|
||
)
|
||
}
|
||
if operation2 == "LEFT" {
|
||
t.WhereRaw(
|
||
fmt.Sprintf("%s LIKE ?", t.getColName(whereColName)),
|
||
fmt.Sprintf("%s%s", "%", v.Field(i).Interface()),
|
||
)
|
||
}
|
||
if operation2 == "RIGHT" {
|
||
t.WhereRaw(
|
||
fmt.Sprintf("%s LIKE ?", t.getColName(whereColName)),
|
||
fmt.Sprintf("%s%s", v.Field(i).Interface(), "%"),
|
||
)
|
||
}
|
||
}
|
||
})
|
||
})
|
||
|
||
}
|
||
}
|
||
return t
|
||
}
|
||
func (t *Query) SelectAllForReflect(data interface{}) *Query {
|
||
v := reflect.ValueOf(data)
|
||
v = v.Elem()
|
||
for i := 0; i < v.NumField(); i++ {
|
||
column := v.Type().Field(i).Tag.Get("column")
|
||
if column != "" {
|
||
//取第一个添加到列
|
||
split := NewStringSplit(column, ":")
|
||
split.RunIndexFunc(0, func(str string) {
|
||
t.AddSelect(str)
|
||
})
|
||
}
|
||
}
|
||
return t
|
||
}
|