// Source: light-pkg/mySql/Query.go, snapshot of 2024-12-18 03:41:30 +08:00
// (the code viewer reported 421 lines, 9.3 KiB, Go).
//
// Viewer warning: this file contains Unicode characters that might be
// confused with other characters; if that is intentional, it can be ignored.

package mySql
import (
"fmt"
"github.com/Masterminds/squirrel"
"gorm.io/gorm"
"math"
"reflect"
"strconv"
"strings"
)
// Query is a chainable builder for a MySQL SELECT over one table (plus
// optional joins). It renders SQL with squirrel (ToSquirrel / ToSql) and
// executes it through the gorm handle in db.
type Query struct {
// tableName is the table named in the FROM clause (see getFrom).
tableName string
// rawColumns are SELECT expressions emitted verbatim (SelectRow / AddSelectRow).
rawColumns []string
// autoColumns are backtick-quoted column names (Select / AddSelect); when an
// alias is set, unqualified entries are prefixed with it in autoColumnsToRaw.
autoColumns []string
// rawWheres and autoWheres are WHERE conditions; mountWhere applies
// rawWheres first, then autoWheres, each as Formula + Values.
// NOTE(review): both are populated by methods not shown here (e.g. WhereRaw,
// WhereIsNull) — confirm which slice each one appends to.
rawWheres []*Where
autoWheres []*Where
// limit and offset are applied by mountOther only outside page mode, and only when > 0.
limit int64
offset int64
// groupBy is a single GROUP BY expression (GroupBy quotes it, GroupByRaw does not).
groupBy string
// orderBy is an ORDER BY expression; it is set by code outside this view.
orderBy string
// join holds the join specifications rendered by mountJoin.
join []*Join
// isPageMode, page and pageSize drive LIMIT/OFFSET when page mode is on
// (Page / Pages); page is 1-based. NewQuery defaults to page 1, pageSize 30.
isPageMode bool
page int64
pageSize int64
// alias is the optional table alias set by As.
alias string
// db is the gorm handle the generated SQL is executed with.
db *gorm.DB
}
// NewQuery starts a SELECT over tableName that will be executed with db.
//
// The returned query has no explicit columns (it renders SELECT *), no
// conditions, no joins and no limit/offset. Page mode is off; the paging
// defaults (page 1, 30 rows) are used for LIMIT/OFFSET only after Page or
// Pages turns page mode on.
func NewQuery(tableName string, db *gorm.DB) *Query {
	q := new(Query)
	q.tableName, q.db = tableName, db
	q.rawColumns, q.autoColumns = []string{}, []string{}
	q.rawWheres, q.autoWheres = []*Where{}, []*Where{}
	q.join = []*Join{}
	q.page, q.pageSize = 1, 30
	return q
}
// SetDB replaces the gorm handle the query is executed with and returns the
// same Query so calls can be chained.
func (t *Query) SetDB(db *gorm.DB) *Query {
	t.db = db

	return t
}
// As sets the table alias. With an alias, FROM renders as `table` AS `alias`
// and unqualified auto columns and getColName results are prefixed with it.
func (t *Query) As(alias string) *Query {
	t.alias = alias

	return t
}
// Select replaces the auto-quoted column list with columns.
//
// Each name is wrapped in backticks. A qualified name such as "u.name" is
// split on its first dot and each part is quoted separately ("`u`.`name`"),
// matching GroupBy. Quoting it whole as "`u.name`" would make MySQL look for
// a single column literally named "u.name". Unqualified names are prefixed
// with the query alias (if any) when the SQL is built. Use SelectRow for raw
// expressions that must not be quoted.
func (t *Query) Select(columns ...string) *Query {
	t.autoColumns = make([]string, 0, len(columns))
	for _, column := range columns {
		if dot := strings.Index(column, "."); dot >= 0 {
			t.autoColumns = append(t.autoColumns, fmt.Sprintf("`%s`.`%s`", column[:dot], column[dot+1:]))
			continue
		}
		t.autoColumns = append(t.autoColumns, fmt.Sprintf("`%s`", column))
	}
	return t
}
// SelectRow replaces the raw (unquoted) SELECT expressions with columns.
// The expressions are emitted verbatim after any auto-quoted columns.
func (t *Query) SelectRow(columns ...string) *Query {
	// Copy instead of storing the variadic slice. When the caller passes
	// cols..., the slice aliases the caller's array, and a later AddSelectRow
	// append could otherwise overwrite the caller's elements.
	t.rawColumns = append(make([]string, 0, len(columns)), columns...)
	return t
}
// AddSelect appends columns to the auto-quoted column list.
//
// Names are quoted like Select does: "name" becomes "`name`", and a qualified
// "u.name" becomes "`u`.`name`" (split on the first dot, consistent with
// GroupBy) instead of the invalid single identifier "`u.name`".
func (t *Query) AddSelect(columns ...string) *Query {
	for _, column := range columns {
		if dot := strings.Index(column, "."); dot >= 0 {
			t.autoColumns = append(t.autoColumns, fmt.Sprintf("`%s`.`%s`", column[:dot], column[dot+1:]))
			continue
		}
		t.autoColumns = append(t.autoColumns, fmt.Sprintf("`%s`", column))
	}
	return t
}
// AddSelectRow appends raw SELECT expressions, which are emitted verbatim.
func (t *Query) AddSelectRow(columns ...string) *Query {
	t.rawColumns = append(t.rawColumns, columns...)
	return t
}
// GroupBy sets a single, backtick-quoted GROUP BY column and replaces any
// previous grouping. A name containing a dot is split at its first dot, and
// each side is quoted separately ("t.col" -> "`t`.`col`").
func (t *Query) GroupBy(s string) *Query {
	dot := strings.Index(s, ".")
	if dot < 0 {
		t.groupBy = fmt.Sprintf("`%s`", s)
		return t
	}
	t.groupBy = fmt.Sprintf("`%s`.`%s`", s[:dot], s[dot+1:])
	return t
}
// GroupByRaw sets the GROUP BY clause to s verbatim (no quoting), replacing
// any previous grouping.
func (t *Query) GroupByRaw(s string) *Query {
	t.groupBy = s

	return t
}
// Limit stores the row limit. mountOther applies it only outside page mode
// and only when it is positive.
func (t *Query) Limit(s int64) *Query {
	t.limit = s

	return t
}
// Offset stores the row offset field, which mountOther consults only outside
// page mode and only when it is positive.
func (t *Query) Offset(s int64) *Query {
	t.offset = s

	return t
}
// Page turns on page mode and selects the 1-based page p.
//
// Any p below 1 (zero or negative) is treated as page 1. A negative page would
// otherwise make pageSize*(page-1) negative, and mountOther converts that
// offset to uint64, where it wraps to an enormous value.
func (t *Query) Page(p int64) *Query {
	t.isPageMode = true
	if p < 1 {
		p = 1
	}
	t.page = p
	return t
}
// PageSize sets the number of rows per page used in page mode. It does not
// enable page mode by itself; see Page / Pages.
//
// Any p below 1 falls back to the default of 30. Zero was already handled
// this way; a negative size would otherwise wrap to a huge uint64 LIMIT in
// mountOther.
func (t *Query) PageSize(p int64) *Query {
	if p < 1 {
		p = 30
	}
	t.pageSize = p
	return t
}
// Pages is shorthand for Page(p) followed by PageSize(s): it enables page
// mode and sets both the page number and the page size.
func (t *Query) Pages(p int64, s int64) *Query {
	t.Page(p)
	t.PageSize(s)
	return t
}
// MountQuery hands the query to f so a reusable set of modifications
// (columns, conditions, joins, ...) can be applied inline. It then returns the
// same Query to keep the chain going.
func (t *Query) MountQuery(f func(squ *Query)) *Query {
	apply := f
	apply(t)
	return t
}
// NotDelete adds an IS NULL condition on the "deleted_at" column through
// WhereIsNull. It presumably filters out soft-deleted rows; confirm that the
// table uses deleted_at that way.
func (t *Query) NotDelete() *Query {
	const deletedAtColumn = "deleted_at"
	return t.WhereIsNull(deletedAtColumn)
}
// autoColumnsToRaw returns the (already backtick-quoted) auto columns ready
// for SELECT. When an alias is set, columns that contain no "." are prefixed
// with "`alias`."; all other columns are returned unchanged. The result is
// always a fresh, non-nil slice.
func (t *Query) autoColumnsToRaw() []string {
	out := make([]string, 0, len(t.autoColumns))
	for _, col := range t.autoColumns {
		if t.alias == "" || strings.Contains(col, ".") {
			out = append(out, col)
			continue
		}
		out = append(out, fmt.Sprintf("`%s`.%s", t.alias, col))
	}
	return out
}
// getColumns returns the SELECT list: auto columns (alias-qualified), then raw
// expressions. It falls back to "*" when neither was configured.
func (t *Query) getColumns() []string {
	cols := t.autoColumnsToRaw()
	cols = append(cols, t.rawColumns...)
	if len(cols) > 0 {
		return cols
	}
	return []string{"*"}
}
// getFrom renders the FROM target: "`table`", or "`table` AS `alias`" when
// an alias is set.
func (t *Query) getFrom() string {
	from := fmt.Sprintf("`%s`", t.tableName)
	if t.alias == "" {
		return from
	}
	return fmt.Sprintf("%s AS `%s`", from, t.alias)
}
// getColName backtick-quotes colName and, when the query has an alias,
// qualifies it as "`alias`.`colName`".
func (t *Query) getColName(colName string) string {
	if t.alias == "" {
		return fmt.Sprintf("`%s`", colName)
	}
	return fmt.Sprintf("`%s`.`%s`", t.alias, colName)
}
// mountWhere adds every stored condition to squ: first the raw conditions,
// then the processed (auto) ones, each as Formula with its bound Values.
func (t *Query) mountWhere(squ squirrel.SelectBuilder) squirrel.SelectBuilder {
	for _, group := range [][]*Where{t.rawWheres, t.autoWheres} {
		for _, w := range group {
			squ = squ.Where(w.Formula, w.Values...)
		}
	}
	return squ
}
// mountOther applies LIMIT/OFFSET, GROUP BY and ORDER BY to squ.
//
// In page mode the window comes from page/pageSize:
// OFFSET pageSize*(page-1) and LIMIT pageSize. The plain limit/offset fields
// are ignored in that mode. Otherwise limit and offset are each applied only
// when positive.
func (t *Query) mountOther(squ squirrel.SelectBuilder) squirrel.SelectBuilder {
	if t.isPageMode {
		squ = squ.Offset(uint64(t.pageSize * (t.page - 1))).Limit(uint64(t.pageSize))
	} else {
		if t.limit > 0 {
			squ = squ.Limit(uint64(t.limit))
		}
		if t.offset > 0 {
			// The offset value belongs here; passing t.limit made Offset(n)
			// skip `limit` rows instead of n.
			// NOTE(review): MySQL only accepts OFFSET together with LIMIT, so
			// an offset without a limit still yields invalid SQL.
			squ = squ.Offset(uint64(t.offset))
		}
	}
	if t.groupBy != "" {
		squ = squ.GroupBy(t.groupBy)
	}
	if t.orderBy != "" {
		squ = squ.OrderBy(t.orderBy)
	}
	return squ
}
// mountJoin adds every stored join to squ, in order.
//
// option selects the join kind: "left" -> LEFT JOIN, "right" -> RIGHT JOIN,
// "inner" -> INNER JOIN, and anything else -> plain JOIN. Previously both
// "inner" and the fallback emitted RIGHT JOIN, apparently through
// copy-and-paste. join.join is the join clause text; join.rest, when
// non-empty, is passed as the bind arguments.
// NOTE(review): join.rest is passed as a single argument rather than
// expanded with "..."; if it is a []interface{} of separate placeholders, it
// should probably be join.rest... — confirm against the Join type.
func (t *Query) mountJoin(squ squirrel.SelectBuilder) squirrel.SelectBuilder {
	for _, join := range t.join {
		var addJoin func(string, ...interface{}) squirrel.SelectBuilder
		switch join.option {
		case "left":
			addJoin = squ.LeftJoin
		case "right":
			addJoin = squ.RightJoin
		case "inner":
			addJoin = squ.InnerJoin
		default:
			addJoin = squ.Join
		}
		if len(join.rest) == 0 {
			squ = addJoin(join.join)
		} else {
			squ = addJoin(join.join, join.rest)
		}
	}
	return squ
}
// ToSquirrel assembles the full squirrel SELECT builder. It starts with the
// columns and FROM, then adds the WHERE conditions, then LIMIT/OFFSET,
// GROUP BY and ORDER BY, and finally the joins.
func (t *Query) ToSquirrel() squirrel.SelectBuilder {
	base := squirrel.Select(t.getColumns()...).From(t.getFrom())
	return t.mountJoin(t.mountOther(t.mountWhere(base)))
}
// ToGormQuery builds the full SELECT and returns the gorm statement that will
// run it against the query's db.
func (t *Query) ToGormQuery() *gorm.DB {
	built := t.ToSquirrel()
	return t.toGormQuery(built)
}
// ToSql renders the full SELECT as SQL text plus its bind arguments, exactly
// as squirrel produces them for ToSquirrel.
func (t *Query) ToSql() (string, []interface{}, error) {
	builder := t.ToSquirrel()
	return builder.ToSql()
}
// toGormQuery renders squ to SQL and arguments and passes them, together
// with any rendering error, to Mapper. Mapper is bound to the query's db and
// turned into a gorm statement.
// NOTE(review): Mapper/setDB/Query are defined elsewhere; presumably this
// becomes a gorm Raw statement — confirm how a non-nil render error is handled.
func (t *Query) toGormQuery(squ squirrel.SelectBuilder) *gorm.DB {
	query, args, err := squ.ToSql()
	return Mapper(query, args, err).setDB(t.db).Query()
}
// Find executes the full query and scans all result rows into list, returning
// gorm's error.
func (t *Query) Find(list interface{}) error {
	result := t.ToGormQuery().Find(list)
	return result.Error
}
// Take executes the query with a LIMIT of 1 and loads the row into info,
// returning gorm's error.
//
// The limit of 1 applies only to this call. The previous limit is restored
// afterwards, so reusing the Query (for example a later Find) is not silently
// truncated to one row. In page mode mountOther ignores limit, so the page's
// LIMIT/OFFSET is used instead.
func (t *Query) Take(info interface{}) error {
	defer func(prev int64) { t.limit = prev }(t.limit)
	t.limit = 1
	return t.ToGormQuery().Take(info).Error
}
// GetTotal runs SELECT count(*) over the query's FROM, JOIN and WHERE clauses.
// It ignores the column list, GROUP BY, ORDER BY and LIMIT/OFFSET.
//
// Joins are included so that conditions on joined tables resolve. Without
// them the count query failed and 0 was returned. With them, the count covers
// the same joined rows Find reads before grouping and paging. Errors from the
// count query are not reported to the caller.
func (t *Query) GetTotal() int64 {
	squ := squirrel.Select("count(*)").From(t.getFrom())
	squ = t.mountWhere(squ)
	squ = t.mountJoin(squ)
	var count int64
	t.toGormQuery(squ).Scan(&count)
	return count
}
// GetRowCount builds SELECT count(*) with the query's FROM, JOIN, WHERE and
// GROUP BY clauses (no columns, ordering or paging) and hands it to gorm's
// Count. It appears intended to count grouped result rows.
//
// Joins are included so that conditions on joined tables resolve, matching
// the row set Find reads.
// NOTE(review): with GROUP BY, scanning this SQL directly would give the first
// group's count; this relies on gorm's Count handling of the raw statement to
// yield the number of groups — confirm against the gorm version in use.
// Errors are not reported to the caller.
func (t *Query) GetRowCount() int64 {
	squ := squirrel.Select("count(*)").From(t.getFrom())
	squ = t.mountWhere(squ)
	squ = t.mountJoin(squ)
	if t.groupBy != "" {
		squ = squ.GroupBy(t.groupBy)
	}
	var count int64
	t.toGormQuery(squ).Count(&count)
	return count
}
// GetSum returns SUM(`column`) over the query's FROM and WHERE clauses, as an
// int64. Scan errors are not reported. If the scan fails (for example,
// presumably, a NULL sum over zero rows), the result stays 0.
func (t *Query) GetSum(column string) int64 {
	var sum int64
	sumExpr := fmt.Sprintf("SUM(`%s`)", column)
	squ := t.mountWhere(squirrel.Select(sumExpr).From(t.getFrom()))
	t.toGormQuery(squ).Scan(&sum)
	return sum
}
// GetSumInterface runs SELECT SUM(`column`) over the query's FROM and WHERE
// clauses and scans the result into value. value should be a pointer owned
// by the caller (e.g. *int64, *float64). Errors are not reported.
func (t *Query) GetSumInterface(column string, value interface{}) {
	squ := squirrel.Select(fmt.Sprintf("SUM(`%s`)", column)).From(t.getFrom())
	squ = t.mountWhere(squ)
	// Pass the caller's pointer straight through. Scanning into &value targeted
	// this function's own interface variable, so a direct write landed in the
	// local copy rather than the caller's destination.
	t.toGormQuery(squ).Scan(value)
}
// PageFind runs Find into list and then GetTotal. It returns paging metadata
// (Total, TotalPage, Page, PageSize) together with Find's error; a failure
// inside GetTotal is not reported.
//
// NOTE(review): PageFind does not enable page mode itself. Unless Page/Pages
// was called, Find uses limit/offset, while Page/PageSize here still report
// the NewQuery defaults (1 and 30).
func (t *Query) PageFind(list interface{}) (*PageStruct, error) {
err1 := t.Find(list)
total := t.GetTotal()
p := &PageStruct{
Total: total,
TotalPage: 0,
Page: t.page,
PageSize: t.pageSize,
}
// TotalPage = ceil(total / pageSize).
// NOTE(review): rounding the quotient to 5 decimals before math.Ceil can
// under-count. For example, total=1000001 and pageSize=1000000 give
// 1.000001 -> "1.00000" -> 1 page instead of 2. The fallback branch
// over-counts exact multiples (total=60, pageSize=30 -> 3). Integer ceiling
// division (total/pageSize, plus 1 when there is a remainder) would be exact.
// math and strconv are imported only for this computation, so switching
// requires dropping those imports too.
totalPage, err := strconv.ParseFloat(fmt.Sprintf("%.5f", float64(total)/float64(t.pageSize)), 64)
if err == nil {
p.TotalPage = int64(math.Ceil(totalPage))
} else {
p.TotalPage = (total / p.PageSize) + 1
}
return p, err1
}
// MountWhereForReflect adds WHERE conditions derived from the `where` struct
// tags of data, which may be a struct or a pointer to a struct.
//
// It considers each field that has a `where` tag and whose value isBlank
// reports as non-blank. The tag is split on ":" and the optional operation
// is upper-cased and split on "_":
//
//	where:"col"            -> col = value
//	where:"col:like"       -> col LIKE value (value used as-is)
//	where:"col:like_both"  -> col LIKE '%value%'
//	where:"col:like_left"  -> col LIKE '%value'
//	where:"col:like_right" -> col LIKE 'value%'
//	where:"col:bool"       -> col = 1 for "yes", col = 0 for "no"
//
// Column names go through getColName, so they are alias-qualified when an
// alias is set. Unexported fields are skipped, because reading them through
// reflection panics. A nil pointer or a non-struct argument adds nothing;
// previously v.Elem() panicked on a non-pointer struct.
// NOTE(review): NewStringSplit/RunCount1Func/RunCount2Func are defined
// elsewhere; presumably each RunCountNFunc runs its callback only when the
// split yields exactly N parts.
func (t *Query) MountWhereForReflect(data interface{}) *Query {
	v := reflect.Indirect(reflect.ValueOf(data))
	if v.Kind() != reflect.Struct {
		return t
	}
	typ := v.Type()
	for i := 0; i < v.NumField(); i++ {
		field := v.Field(i)
		where := typ.Field(i).Tag.Get("where")
		if where == "" || !field.CanInterface() || isBlank(field) {
			continue
		}
		value := field.Interface()
		split := NewStringSplit(where, ":")
		// Tag with one part: plain equality on that column.
		split.RunCount1Func(func(whereColName string) {
			t.WhereRaw(fmt.Sprintf("%s = ?", t.getColName(whereColName)), value)
		})
		// Tag with two parts: the second names the operation, e.g. LIKE_BOTH, LIKE, BOOL.
		split.RunCount2Func(func(whereColName, whereOperation string) {
			col := t.getColName(whereColName)
			whereOperationSplit := NewStringSplit(strings.ToUpper(whereOperation), "_")
			// Single-word operation, e.g. LIKE or BOOL.
			whereOperationSplit.RunCount1Func(func(operation string) {
				switch operation {
				case "LIKE":
					t.WhereRaw(fmt.Sprintf("%s LIKE ?", col), value)
				case "BOOL":
					switch value {
					case "yes":
						t.WhereRaw(fmt.Sprintf("%s = ?", col), 1)
					case "no":
						t.WhereRaw(fmt.Sprintf("%s = ?", col), 0)
					}
				}
			})
			// Two-word operation: operation is LIKE, operation2 is BOTH/LEFT/RIGHT.
			whereOperationSplit.RunCount2Func(func(operation, operation2 string) {
				if operation != "LIKE" {
					return
				}
				switch operation2 {
				case "BOTH":
					t.WhereRaw(fmt.Sprintf("%s LIKE ?", col), fmt.Sprintf("%s%s%s", "%", value, "%"))
				case "LEFT":
					t.WhereRaw(fmt.Sprintf("%s LIKE ?", col), fmt.Sprintf("%s%s", "%", value))
				case "RIGHT":
					t.WhereRaw(fmt.Sprintf("%s LIKE ?", col), fmt.Sprintf("%s%s", value, "%"))
				}
			})
		})
	}
	return t
}
// SelectAllForReflect adds one auto-quoted column (via AddSelect) for every
// field of data that carries a `column` tag. The column name is the first
// ":"-separated segment of the tag.
//
// data may be a struct or a pointer to a struct. A nil pointer or a non-struct
// argument adds nothing; previously v.Elem() panicked on a non-pointer struct.
func (t *Query) SelectAllForReflect(data interface{}) *Query {
	v := reflect.Indirect(reflect.ValueOf(data))
	if v.Kind() != reflect.Struct {
		return t
	}
	typ := v.Type()
	for i := 0; i < typ.NumField(); i++ {
		column := typ.Field(i).Tag.Get("column")
		if column == "" {
			continue
		}
		// Only the first segment of the tag is the column name.
		NewStringSplit(column, ":").RunIndexFunc(0, func(name string) {
			t.AddSelect(name)
		})
	}
	return t
}