From 4551df34b6aa2b080075e0dad12d718f86aec51e Mon Sep 17 00:00:00 2001 From: Xlxinxi Date: Wed, 18 Dec 2024 23:38:28 +0800 Subject: [PATCH] =?UTF-8?q?=E5=90=88=E5=B9=B6pkg?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go.mod | 63 ++++- pkg/myAliMarket/index.go | 90 +++++++ pkg/myAliSms/index.go | 61 +++++ pkg/myCobra/cobra.go | 50 ++++ pkg/myGorm/gorm.go | 88 +++++++ pkg/myHttp/header.go | 32 +++ pkg/myHttp/index.go | 22 ++ pkg/myHttp/json.go | 57 +++++ pkg/myJwt/jwt.go | 115 +++++++++ pkg/myOss/index.go | 138 +++++++++++ pkg/myPay/alipay.go | 145 ++++++++++++ pkg/myPay/alipay_notify.go | 28 +++ pkg/myPay/wechat.go | 261 ++++++++++++++++++++ pkg/myPay/wechat_notify.go | 15 ++ pkg/myRedis/Ierator.go | 23 ++ pkg/myRedis/IntResult.go | 33 +++ pkg/myRedis/InterfaceResult.go | 25 ++ pkg/myRedis/OperationAttr.go | 52 ++++ pkg/myRedis/SimpleCache.go | 132 +++++++++++ pkg/myRedis/SliceResult.go | 29 +++ pkg/myRedis/StringCache.go | 52 ++++ pkg/myRedis/StringOperation.go | 48 ++++ pkg/myRedis/StringResult.go | 33 +++ pkg/myRedis/cli.go | 33 +++ pkg/myRedis/index.go | 60 +++++ pkg/mySql/Create.go | 9 + pkg/mySql/Delete.go | 55 +++++ pkg/mySql/Facade.go | 32 +++ pkg/mySql/Insert.go | 83 +++++++ pkg/mySql/InsertColumn.go | 10 + pkg/mySql/Join.go | 11 + pkg/mySql/PageStruct.go | 8 + pkg/mySql/Query.go | 420 +++++++++++++++++++++++++++++++++ pkg/mySql/QueryJoin.go | 53 +++++ pkg/mySql/QueryOrderBy.go | 42 ++++ pkg/mySql/QueryWhere.go | 110 +++++++++ pkg/mySql/Reflect.go | 24 ++ pkg/mySql/SqlMapper.go | 78 ++++++ pkg/mySql/StringSplit.go | 56 +++++ pkg/mySql/Update.go | 127 ++++++++++ pkg/mySql/Where.go | 115 +++++++++ pkg/myUrl/index.go | 42 ++++ pkg/myViper/viper.go | 31 +++ 43 files changed, 2978 insertions(+), 13 deletions(-) create mode 100644 pkg/myAliMarket/index.go create mode 100644 pkg/myAliSms/index.go create mode 100644 pkg/myCobra/cobra.go create mode 100644 pkg/myGorm/gorm.go create mode 100644 pkg/myHttp/header.go create mode 100644 pkg/myHttp/index.go create mode 100644 pkg/myHttp/json.go create mode 100644 pkg/myJwt/jwt.go create mode 100644 pkg/myOss/index.go create mode 100644 pkg/myPay/alipay.go create mode 100644 pkg/myPay/alipay_notify.go create mode 100644 pkg/myPay/wechat.go create mode 100644 pkg/myPay/wechat_notify.go create mode 100644 pkg/myRedis/Ierator.go create mode 100644 pkg/myRedis/IntResult.go create mode 100644 pkg/myRedis/InterfaceResult.go create mode 100644 pkg/myRedis/OperationAttr.go create mode 100644 pkg/myRedis/SimpleCache.go create mode 100644 pkg/myRedis/SliceResult.go create mode 100644 pkg/myRedis/StringCache.go create mode 100644 pkg/myRedis/StringOperation.go create mode 100644 pkg/myRedis/StringResult.go create mode 100644 pkg/myRedis/cli.go create mode 100644 pkg/myRedis/index.go create mode 100644 pkg/mySql/Create.go create mode 100644 pkg/mySql/Delete.go create mode 100644 pkg/mySql/Facade.go create mode 100644 pkg/mySql/Insert.go create mode 100644 pkg/mySql/InsertColumn.go create mode 100644 pkg/mySql/Join.go create mode 100644 pkg/mySql/PageStruct.go create mode 100644 pkg/mySql/Query.go create mode 100644 pkg/mySql/QueryJoin.go create mode 100644 pkg/mySql/QueryOrderBy.go create mode 100644 pkg/mySql/QueryWhere.go create mode 100644 pkg/mySql/Reflect.go create mode 100644 pkg/mySql/SqlMapper.go create mode 100644 pkg/mySql/StringSplit.go create mode 100644 pkg/mySql/Update.go create mode 100644 pkg/mySql/Where.go create mode 100644 pkg/myUrl/index.go create mode 100644 pkg/myViper/viper.go diff --git a/go.mod b/go.mod index 87d0710..12c41d5 100644 --- a/go.mod +++ b/go.mod @@ -3,41 +3,78 @@ module code.zhecent.com/gopkg/light-core go 1.23.1 require ( + github.com/Masterminds/squirrel v1.5.4 + github.com/aliyun/alibaba-cloud-sdk-go v1.63.71 + github.com/aliyun/aliyun-oss-go-sdk v3.0.2+incompatible github.com/gin-gonic/gin v1.10.0 + github.com/go-pay/gopay v1.5.106 + github.com/golang-jwt/jwt/v5 v5.2.1 + github.com/redis/go-redis/v9 v9.7.0 github.com/shopspring/decimal v1.4.0 + github.com/spf13/cobra v1.8.1 + github.com/spf13/viper v1.19.0 golang.org/x/crypto v0.31.0 google.golang.org/protobuf v1.36.0 + gorm.io/driver/mysql v1.5.7 + gorm.io/gorm v1.25.12 ) require ( - github.com/bytedance/sonic v1.12.2 // indirect - github.com/bytedance/sonic/loader v0.2.0 // indirect + filippo.io/edwards25519 v1.1.0 // indirect + github.com/bytedance/sonic v1.12.6 // indirect + github.com/bytedance/sonic/loader v0.2.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect - github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/gabriel-vasile/mimetype v1.4.5 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/fsnotify/fsnotify v1.8.0 // indirect + github.com/gabriel-vasile/mimetype v1.4.7 // indirect github.com/gin-contrib/sse v0.1.0 // indirect + github.com/go-pay/crypto v0.0.1 // indirect + github.com/go-pay/errgroup v0.0.2 // indirect + github.com/go-pay/util v0.0.4 // indirect + github.com/go-pay/xlog v0.0.3 // indirect + github.com/go-pay/xtime v0.0.2 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/go-playground/validator/v10 v10.22.0 // indirect - github.com/goccy/go-json v0.10.3 // indirect + github.com/go-playground/validator/v10 v10.23.0 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/goccy/go-json v0.10.4 // indirect github.com/google/go-cmp v0.6.0 // indirect + github.com/hashicorp/hcl v1.0.0 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/klauspost/cpuid/v2 v2.2.8 // indirect - github.com/kr/text v0.2.0 // indirect + github.com/klauspost/cpuid/v2 v2.2.9 // indirect + github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 // indirect + github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 // indirect github.com/leodido/go-urn v1.4.0 // indirect + github.com/magiconair/properties v1.8.9 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect - github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect + github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect - github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/sagikazarmark/locafero v0.6.0 // indirect + github.com/sagikazarmark/slog-shim v0.1.0 // indirect + github.com/sourcegraph/conc v0.3.0 // indirect + github.com/spf13/afero v1.11.0 // indirect + github.com/spf13/cast v1.7.1 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/subosito/gotenv v1.6.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect - golang.org/x/arch v0.9.0 // indirect - golang.org/x/net v0.28.0 // indirect + go.uber.org/atomic v1.11.0 // indirect + go.uber.org/multierr v1.11.0 // indirect + golang.org/x/arch v0.12.0 // indirect + golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67 // indirect + golang.org/x/net v0.32.0 // indirect golang.org/x/sys v0.28.0 // indirect golang.org/x/text v0.21.0 // indirect - gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect + golang.org/x/time v0.8.0 // indirect + gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/pkg/myAliMarket/index.go b/pkg/myAliMarket/index.go new file mode 100644 index 0000000..58f5cba --- /dev/null +++ b/pkg/myAliMarket/index.go @@ -0,0 +1,90 @@ +package myAliMarket + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/url" +) + +type Client struct { + url string + appCode string + param map[string]string + fullUrl string + response string +} + +func New(url string, appCode string) *Client { + return &Client{ + url: url, + appCode: appCode, + } +} + +func (t *Client) SetParam(param map[string]string) { + t.param = param +} + +func (t *Client) getFullUrl() (string, error) { + u, err := url.Parse(t.url) + if err != nil { + return "", err + } + q := u.Query() + for k, v := range t.param { + q.Add(k, v) + } + u.RawQuery = q.Encode() + t.fullUrl = u.String() + return t.fullUrl, nil +} + +func (t *Client) GetFullUrl() string { + return t.fullUrl +} + +func (t *Client) GetResponse() string { + return t.response +} + +func (t *Client) GetRequest(respData interface{}) error { + fullUrl, err := t.getFullUrl() + if err != nil { + return err + } + + client := &http.Client{} + request, err := http.NewRequest("GET", fullUrl, nil) + if err != nil { + return errors.New("http客户端初始化失败") + } + request.Header.Add("Authorization", fmt.Sprintf("APPCODE %s", t.appCode)) + + response, err := client.Do(request) + if err != nil { + return errors.New("认证服务器连接失败1") + } + + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return errors.New("认证服务器连接失败2") + } + + var response2 []byte + response2, err = ioutil.ReadAll(response.Body) + if err != nil { + return errors.New("数据解析失败") + } + + t.response = string(response2) + + err = json.Unmarshal(response2, respData) + if err != nil { + return err + } + return nil +} diff --git a/pkg/myAliSms/index.go b/pkg/myAliSms/index.go new file mode 100644 index 0000000..dea07fa --- /dev/null +++ b/pkg/myAliSms/index.go @@ -0,0 +1,61 @@ +package myAliSms + +import ( + "encoding/json" + "errors" + "github.com/aliyun/alibaba-cloud-sdk-go/services/dysmsapi" +) + +type Client struct { + accessKeyId string + accessSecret string + signName string + templateCode string +} + +func NewClient(accessKeyId string, accessSecret string) *Client { + return &Client{accessKeyId: accessKeyId, accessSecret: accessSecret} +} + +func (t *Client) SetSignName(signName string) *Client { + t.signName = signName + return t +} + +func (t *Client) SetTemplateCode(code string) *Client { + t.templateCode = code + return t +} + +func (t *Client) SendSms(m map[string]interface{}, phone string) error { + if t.signName == "" { + return errors.New("签名不能为空") + } + client, err := dysmsapi.NewClientWithAccessKey("cn-hangzhou", t.accessKeyId, t.accessSecret) + if err != nil { + return err + } + + request := dysmsapi.CreateSendSmsRequest() + request.Scheme = "https" + request.PhoneNumbers = phone + request.SignName = t.signName + request.TemplateCode = t.templateCode + request.TemplateParam = t.mapToJson(m) + + response, err := client.SendSms(request) + if err != nil { + return err + } else { + if response.Code == "OK" { + return nil + } else { + return errors.New(response.Message) + } + } +} + +func (t *Client) mapToJson(TemplateParamMap map[string]interface{}) string { + mjson, _ := json.Marshal(TemplateParamMap) + return string(mjson) +} diff --git a/pkg/myCobra/cobra.go b/pkg/myCobra/cobra.go new file mode 100644 index 0000000..7f349bb --- /dev/null +++ b/pkg/myCobra/cobra.go @@ -0,0 +1,50 @@ +package myCobra + +import ( + "github.com/spf13/cobra" +) + +type SimpleCmd struct { + Use string + Short string + Example string + PreRun func() + RunE func() error + cobraModel *cobra.Command +} + +func (t *SimpleCmd) GetCobra() *cobra.Command { + if t.cobraModel == nil { + t.cobraModel = &cobra.Command{ + Use: t.Use, + Short: t.Short, + Example: t.Example, + SilenceUsage: true, + PreRun: func(cmd *cobra.Command, args []string) { + t.PreRun() + }, + RunE: func(cmd *cobra.Command, args []string) error { + return t.RunE() + }, + } + } + return t.cobraModel +} + +func (t *SimpleCmd) SetArgsFunc(argsFunc func(args []string) error) { + t.GetCobra().Args = func(cmd *cobra.Command, args []string) error { + return argsFunc(args) + } +} + +func (t *SimpleCmd) SetStringVar(p *string, name, shorthand string, value string, usage string) { + t.GetCobra().PersistentFlags().StringVarP(p, name, shorthand, value, usage) +} + +func (t *SimpleCmd) AddCommand(cmd *SimpleCmd) { + t.GetCobra().AddCommand(cmd.GetCobra()) +} + +func (t *SimpleCmd) Execute() error { + return t.GetCobra().Execute() +} diff --git a/pkg/myGorm/gorm.go b/pkg/myGorm/gorm.go new file mode 100644 index 0000000..ce5ff6b --- /dev/null +++ b/pkg/myGorm/gorm.go @@ -0,0 +1,88 @@ +package myGorm + +import ( + "fmt" + "gorm.io/driver/mysql" + "gorm.io/gorm" + "gorm.io/gorm/logger" + "log" + "os" + "strings" + "time" +) + +type SimpleORM struct { + user string + password string + host string + name string + charset string + runMode string + loggerConfig logger.Config + loggerLevel string +} + +func NewSimpleORM(user string, password string, host string, name string, charset string, runMode string) *SimpleORM { + orm := &SimpleORM{ + user: user, + password: password, + host: host, + name: name, + charset: charset, + runMode: runMode, + loggerConfig: logger.Config{ + SlowThreshold: time.Second, // 慢 SQL 阈值 + Colorful: false, // 禁用彩色打印 + LogLevel: logger.Info, + }, + } + if runMode == "release" { + orm.loggerConfig.LogLevel = logger.Silent + } + + return orm +} + +func (t *SimpleORM) getDSN() string { + if t.charset == "" { + t.charset = "utf8mb4" + } + return fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=%s&parseTime=True&loc=Local", + t.user, + t.password, + t.host, + t.name, + t.charset, + ) +} + +func (t *SimpleORM) SetLoggerConfig(l logger.Config) { + t.loggerConfig = l +} + +func (t *SimpleORM) SetLoggerLevel(level string) { + if strings.ToLower(level) == "silent" { + t.loggerConfig.LogLevel = logger.Silent + } else if strings.ToLower(level) == "error" { + t.loggerConfig.LogLevel = logger.Error + } else if strings.ToLower(level) == "warn" { + t.loggerConfig.LogLevel = logger.Warn + } else if strings.ToLower(level) == "info" { + t.loggerConfig.LogLevel = logger.Info + } +} + +func (t *SimpleORM) ConnectMysql() *gorm.DB { + Db, err := gorm.Open( + mysql.Open(t.getDSN()), &gorm.Config{ + Logger: logger.New( + log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer + t.loggerConfig, + ), + }, + ) + if err != nil { + panic("failed to connect database") + } + return Db +} diff --git a/pkg/myHttp/header.go b/pkg/myHttp/header.go new file mode 100644 index 0000000..191f9c3 --- /dev/null +++ b/pkg/myHttp/header.go @@ -0,0 +1,32 @@ +package myHttp + +func (t *Client) HasHeader(key string) bool { + for k, _ := range t.headers { + if key == k { + return true + } + } + return false +} + +func (t *Client) SetHeaders(header map[string]string) *Client { + t.headers = header + return t +} + +func (t *Client) AddHeader(key, value string) *Client { + t.headers[key] = value + return t +} + +func (t *Client) AddHeaders(header map[string]string) *Client { + for k, v := range header { + t.headers[k] = v + } + return t +} + +func (t *Client) DelHeader(key string) *Client { + delete(t.headers, key) + return t +} diff --git a/pkg/myHttp/index.go b/pkg/myHttp/index.go new file mode 100644 index 0000000..97af464 --- /dev/null +++ b/pkg/myHttp/index.go @@ -0,0 +1,22 @@ +package myHttp + +import ( + "code.zhecent.com/gopkg/light-core/pkg/myUrl" +) + +type Client struct { + url *myUrl.UrlCli + headers map[string]string +} + +func NewClient(url string) (*Client, error) { + urlObj, err := myUrl.NewUrlCliWithParse(url) + if err != nil { + return nil, err + } + return &Client{url: urlObj}, nil +} + +func (t *Client) GetMyUrl() *myUrl.UrlCli { + return t.url +} diff --git a/pkg/myHttp/json.go b/pkg/myHttp/json.go new file mode 100644 index 0000000..3b96c13 --- /dev/null +++ b/pkg/myHttp/json.go @@ -0,0 +1,57 @@ +package myHttp + +import ( + "bytes" + "encoding/json" + "errors" + "net/http" +) + +func (t *Client) PostJson(request interface{}, response interface{}) error { + var buf bytes.Buffer + encoder := json.NewEncoder(&buf) + encoder.SetEscapeHTML(false) + if err := encoder.Encode(request); err != nil { + return err + } + + // HTTP请求 + req, err := http.NewRequest("POST", t.url.String(), &buf) + if err != nil { + return err + } + + //添加请求头 + req.Header.Set("Content-Type", "application/json; charset=utf-8") + if len(t.headers) > 0 { + for key, value := range t.headers { + req.Header.Set(key, value) + } + } + + // 发送请求 + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return errors.New("http.Status:" + resp.Status) + } + return json.NewDecoder(resp.Body).Decode(response) +} + +func (t *Client) GetJson(response interface{}) error { + httpResp, err := http.Get(t.url.String()) + if err != nil { + return err + } + defer httpResp.Body.Close() + + if httpResp.StatusCode != http.StatusOK { + return errors.New("http.Status:" + httpResp.Status) + } + return json.NewDecoder(httpResp.Body).Decode(response) +} diff --git a/pkg/myJwt/jwt.go b/pkg/myJwt/jwt.go new file mode 100644 index 0000000..801eeec --- /dev/null +++ b/pkg/myJwt/jwt.go @@ -0,0 +1,115 @@ +package myJwt + +import ( + "errors" + "github.com/golang-jwt/jwt/v5" + "time" +) + +type Claims struct { + UserId int `json:"uid"` + jwt.RegisteredClaims +} + +func (t Claims) toSimpleJwt() *SimpleJwt { + return &SimpleJwt{ + Audience: t.Audience, + ExpiresAt: t.ExpiresAt, + Id: t.ID, + IssuedAt: t.IssuedAt, + Issuer: t.Issuer, + NotBefore: t.NotBefore, + Subject: t.Subject, + userId: t.UserId, + jwtSecret: "", + } + +} + +type SimpleJwt struct { + Audience jwt.ClaimStrings + ExpiresAt *jwt.NumericDate + Id string + IssuedAt *jwt.NumericDate + Issuer string + NotBefore *jwt.NumericDate + Subject string + + userId int + jwtSecret string +} + +func NewSimpleJwt(userId int, jwtSecret string) *SimpleJwt { + return &SimpleJwt{ + userId: userId, + jwtSecret: jwtSecret, + + ExpiresAt: jwt.NewNumericDate(time.Now().Add(90 * 24 * time.Hour)), + Issuer: "zhecent", + } +} + +func (t *SimpleJwt) SetExpiresAt(expiresAt time.Time) { + t.ExpiresAt = jwt.NewNumericDate(expiresAt) +} + +func (t *SimpleJwt) makeClaims() Claims { + return Claims{ + UserId: t.userId, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: t.ExpiresAt, + Issuer: t.Issuer, + }, + } +} + +func (t *SimpleJwt) GetUserId() int { + return t.userId +} + +func (t *SimpleJwt) GenerateToken() (string, error) { + tokenClaims := jwt.NewWithClaims(jwt.SigningMethodHS256, t.makeClaims()) + token, err := tokenClaims.SignedString([]byte(t.jwtSecret)) + return token, err +} + +func (t *SimpleJwt) MustGenerateToken() string { + s, e := t.GenerateToken() + if e != nil { + panic(e.Error()) + } + return s +} + +func ParseToken(token string, jwtSecret string) (*SimpleJwt, error) { + tokenClaims, err := jwt.ParseWithClaims(token, &Claims{}, func(token *jwt.Token) (interface{}, error) { + return []byte(jwtSecret), nil + }) + + if tokenClaims != nil { + if claims, ok := tokenClaims.Claims.(*Claims); ok && tokenClaims.Valid { + return claims.toSimpleJwt(), nil + } + } + + switch { + case errors.Is(err, jwt.ErrTokenMalformed): + return nil, errors.New("that's not even a token") + case errors.Is(err, jwt.ErrTokenSignatureInvalid): + // Invalid signature + return nil, errors.New("invalid signature") + case errors.Is(err, jwt.ErrTokenExpired) || errors.Is(err, jwt.ErrTokenNotValidYet): + // Token is either expired or not active yet + return nil, errors.New("timing is everything") + default: + return nil, errors.New("token校验失败") + } +} + +func MustParseToken(token string, jwtSecret string) *SimpleJwt { + s, e := ParseToken(token, jwtSecret) + if e != nil { + panic(e.Error()) + } + return s +} diff --git a/pkg/myOss/index.go b/pkg/myOss/index.go new file mode 100644 index 0000000..08a894f --- /dev/null +++ b/pkg/myOss/index.go @@ -0,0 +1,138 @@ +package myOss + +import ( + "fmt" + "github.com/aliyun/aliyun-oss-go-sdk/oss" + "io" + "net/http" + "strings" +) + +type Model struct { + Path string + Bucket string + ACLType string + Endpoint string + AccessKeyID string + SecretAccessKey string + StorageClassType string + CdnUrl string + ossClient *oss.Client + ossBucket *oss.Bucket +} + +func (t *Model) MustGetOssClient() *oss.Client { + if t.ossClient == nil { + client, err := oss.New(t.Endpoint, t.AccessKeyID, t.SecretAccessKey, oss.Timeout(10, 120)) + if err != nil { + panic("function oss.New() Filed, err:" + err.Error()) + } + t.ossClient = client + } + return t.ossClient +} + +func (t *Model) MustGetDefaultOssBucket() *oss.Bucket { + if t.ossBucket == nil { + client := t.MustGetOssClient() + bucket, err := client.Bucket(t.Bucket) + + if err != nil { + panic("function client.Bucket() Filed, err:" + err.Error()) + } + t.ossBucket = bucket + } + return t.ossBucket +} + +func (t *Model) PutObject(savePath, fileName string, reader io.Reader, contentType string) (*RespPath, error) { + err := t.MustGetDefaultOssBucket().PutObject( + fmt.Sprintf("%s/%s", savePath, fileName), + reader, + oss.ContentType(contentType), + t.GetObjectStorageClass(), + t.GetObjectAcl(), + ) + return &RespPath{ + Path: savePath, + Name: fileName, + Host: t.CdnUrl, + }, err + +} + +func (t *Model) MustPutObject(savePath, fileName string, reader io.Reader, contentType string) *RespPath { + p, err := t.PutObject(savePath, fileName, reader, contentType) + if err != nil { + panic(err.Error()) + } + return p +} + +func (t *Model) GetObjectMeta(key string) (http.Header, error) { + if strings.HasPrefix(key, "/") { + key = strings.TrimLeft(key, "/") + } + return t.MustGetDefaultOssBucket().GetObjectMeta(key) +} + +func (t *Model) DeleteObject(key string) error { + // 删除单个文件。objectName表示删除OSS文件时需要指定包含文件后缀在内的完整路径,例如abc/efg/123.jpg。 + // 如需删除文件夹,请将objectName设置为对应的文件夹名称。如果文件夹非空,则需要将文件夹下的所有object删除后才能删除该文件夹。 + //这里需要处理一下路径,如果第一个是/,要去掉 + + if strings.HasPrefix(key, "/") { + key = strings.TrimLeft(key, "/") + } + if t.IsFileExist(key) { + return t.MustGetDefaultOssBucket().DeleteObject(key) + } + return fmt.Errorf("文件不存在") +} + +func (t *Model) IsFileExist(key string) bool { + if strings.HasPrefix(key, "/") { + key = strings.TrimLeft(key, "/") + } + exist, err := t.MustGetDefaultOssBucket().IsObjectExist(key) + if err != nil { + panic(err.Error()) + } + return exist +} + +func (t *Model) GetObjectStorageClass() oss.Option { + switch t.StorageClassType { // 根据配置文件进行指定存储类型 + case "Standard": // 指定存储类型为标准存储 + return oss.ObjectStorageClass(oss.StorageStandard) + case "IA": // 指定存储类型为很少访问存储 + return oss.ObjectStorageClass(oss.StorageIA) + case "Archive": // 指定存储类型为归档存储。 + return oss.ObjectStorageClass(oss.StorageArchive) + case "ColdArchive": // 指定存储类型为归档存储。 + return oss.ObjectStorageClass(oss.StorageColdArchive) + default: // 无匹配结果就是标准存储 + return oss.ObjectStorageClass(oss.StorageStandard) + } +} + +func (t *Model) GetObjectAcl() oss.Option { + switch t.ACLType { // 根据配置文件进行指定访问权限 + case "private": // 指定访问权限为私有读写 + return oss.ObjectACL(oss.ACLPrivate) // 指定访问权限为公共读 + case "public-read": + return oss.ObjectACL(oss.ACLPublicRead) // 指定访问权限为公共读 + case "public-read-write": + return oss.ObjectACL(oss.ACLPublicReadWrite) // 指定访问权限为公共读写 + case "default": + return oss.ObjectACL(oss.ACLDefault) // 指定访问权限为公共读 + default: + return oss.ObjectACL(oss.ACLPrivate) // 默认为访问权限为公共读 + } +} + +type RespPath struct { + Path string + Name string + Host string +} diff --git a/pkg/myPay/alipay.go b/pkg/myPay/alipay.go new file mode 100644 index 0000000..be0f9db --- /dev/null +++ b/pkg/myPay/alipay.go @@ -0,0 +1,145 @@ +package myPay + +import ( + "context" + "errors" + "github.com/go-pay/gopay" + "github.com/go-pay/gopay/alipay" + "github.com/shopspring/decimal" + "net/http" +) + +type AliPay struct { + ctx context.Context + cli *alipay.Client + params *AliPayImpl + orderNo string + amount int + description string + notifyUrl string + returnUrl string +} + +func NewAliPay(params *AliPayImpl) *AliPay { + client, err := alipay.NewClient(params.AppId, params.PrivateKey, params.IsProd) + if err != nil { + panic(err.Error()) + } + + // 打开Debug开关,输出日志,默认是关闭的 + if params.IsProd { + client.DebugSwitch = gopay.DebugOff + } else { + client.DebugSwitch = gopay.DebugOn + } + + client.SetLocation(alipay.LocationShanghai). + SetCharset(alipay.UTF8). + SetSignType(alipay.RSA2). + AutoVerifySign([]byte(params.PublicKey)) + + if err := client.SetCertSnByContent([]byte(params.AppCertContent), []byte(params.AliPayRootCertContent), []byte(params.PublicKey)); err != nil { + panic(err.Error()) + } + + return &AliPay{params: params, cli: client, ctx: context.Background()} +} + +func (t *AliPay) GetCli() *alipay.Client { + return t.cli +} + +func (t *AliPay) SetCommonConfig(notifyUrl string, returnUrl string) *AliPay { + t.notifyUrl = notifyUrl + t.returnUrl = returnUrl + return t +} + +func (t *AliPay) SetOrderInfo(orderNo string, amount int, description string) *AliPay { + t.orderNo = orderNo + t.amount = amount + t.description = description + return t +} + +func (t *AliPay) SetCertSnByContent(appCertContent, aliPayRootCertContent []byte) error { + return t.cli.SetCertSnByContent(appCertContent, aliPayRootCertContent, []byte(t.params.PublicKey)) +} + +func (t *AliPay) SetCertSnByPath(appCertPath, aliPayRootCertPath, aliPayPublicCertPath string) error { + return t.cli.SetCertSnByPath(appCertPath, aliPayRootCertPath, aliPayPublicCertPath) +} + +func (t *AliPay) setBodyMap() gopay.BodyMap { + if t.description == "" || t.orderNo == "" || t.notifyUrl == "" { + panic("param is empty") + } + if t.amount == 0 { + panic("amount is zero") + } + + // 配置公共参数 + t.cli.SetNotifyUrl(t.notifyUrl). + SetReturnUrl(t.returnUrl) + + bm := make(gopay.BodyMap) + bm.Set("subject", t.description) + bm.Set("out_trade_no", t.orderNo) + bm.Set("total_amount", t.fen2Yuan(uint64(t.amount))) + return bm +} + +func (t *AliPay) GetWapPay(quitUrl string) (string, error) { + bm := t.setBodyMap() + bm.Set("quit_url", quitUrl) + bm.Set("product_code", "QUICK_WAP_WAY") + return t.cli.TradeWapPay(t.ctx, bm) +} + +func (t *AliPay) GetPagePay() (string, error) { + bm := t.setBodyMap() + bm.Set("product_code", "FAST_INSTANT_TRADE_PAY") + return t.cli.TradePagePay(t.ctx, bm) +} + +func (t *AliPay) GetAppPay() (string, error) { + bm := t.setBodyMap() + bm.Set("product_code", "QUICK_MSECURITY_PAY") + return t.cli.TradeAppPay(t.ctx, bm) +} + +func (t *AliPay) Notify(req *http.Request) (*AlipayNotifyResp, error) { + notifyReq, err := alipay.ParseNotifyToBodyMap(req) + if err != nil { + return nil, errors.New("解析回调失败") + } + + if ok, err := alipay.VerifySign(t.params.PublicKey, notifyReq); err != nil || ok == false { + return nil, errors.New("sign Error") + } + + return &AlipayNotifyResp{resp: notifyReq}, nil +} + +func (t *AliPay) GetOrderNo() string { + return t.orderNo +} + +func (t *AliPay) GetAmount() int { + return t.amount +} + +type AliPayImpl struct { + AppId string + PublicKey string + PrivateKey string + IsProd bool + AppCertContent string + AliPayRootCertContent string +} + +func (t *AliPay) fen2Yuan(price uint64) string { + d := decimal.New(1, 2) + result := decimal.NewFromInt(int64(price)).DivRound(d, 2).String() + return result +} diff --git a/pkg/myPay/alipay_notify.go b/pkg/myPay/alipay_notify.go new file mode 100644 index 0000000..7cb1da5 --- /dev/null +++ b/pkg/myPay/alipay_notify.go @@ -0,0 +1,28 @@ +package myPay + +import "github.com/go-pay/gopay" + +type AlipayNotifyResp struct { + resp gopay.BodyMap +} +type AlipayNotifyRespInfo struct { + TradeStatus string + OutTradeNo string + SellerId string + TradeNo string + GmtPayment string +} + +func (t *AlipayNotifyResp) IsSuccess() bool { + return t.resp.Get("trade_status") == "TRADE_SUCCESS" +} + +func (t *AlipayNotifyResp) GetResult() *AlipayNotifyRespInfo { + return &AlipayNotifyRespInfo{ + TradeStatus: t.resp.Get("trade_status"), + OutTradeNo: t.resp.Get("out_trade_no"), + SellerId: t.resp.Get("seller_id"), + TradeNo: t.resp.Get("trade_no"), + GmtPayment: t.resp.Get("gmt_payment"), + } +} diff --git a/pkg/myPay/wechat.go b/pkg/myPay/wechat.go new file mode 100644 index 0000000..8156fe8 --- /dev/null +++ b/pkg/myPay/wechat.go @@ -0,0 +1,261 @@ +package myPay + +import ( + "context" + "errors" + "github.com/go-pay/gopay" + "github.com/go-pay/gopay/wechat/v3" + "net/http" + "time" +) + +type Wechat struct { + ctx context.Context + cli *wechat.ClientV3 + params *WechatPayV3Impl + orderNo string + amount int + description string + notifyUrl string + profitSharing bool +} + +func NewWechat(params *WechatPayV3Impl) *Wechat { + // NewClientV3 初始化微信客户端 v3 + // mchid:商户ID 或者服务商模式的 sp_mchid + // serialNo:商户证书的证书序列号 + // apiV3Key:apiV3Key,商户平台获取 + // privateKey:私钥 apiclient_key.pem 读取后的内容 + client, err := wechat.NewClientV3(params.MchId, params.SerialNo, params.ApiV3Key, params.PKContent) + if err != nil { + panic(err.Error()) + } + + // 设置微信平台API证书和序列号(推荐开启自动验签,无需手动设置证书公钥等信息) + //client.SetPlatformCert([]byte(""), "") + + // 启用自动同步返回验签,并定时更新微信平台API证书(开启自动验签时,无需单独设置微信平台API证书和序列号) + client.SetPlatformCert([]byte(params.WxPkContent), params.WxPkSerialNo) + // 启用自动同步返回验签,并定时更新微信平台API证书 + //err = client.AutoVerifySign() + //if err != nil { + // return nil, err + //} + + // 打开Debug开关,输出日志,默认是关闭的 + if params.IsProd { + client.DebugSwitch = gopay.DebugOff + } else { + client.DebugSwitch = gopay.DebugOn + } + + return &Wechat{params: params, cli: client, ctx: context.Background()} +} + +func (t *Wechat) GetCli() *wechat.ClientV3 { + return t.cli +} + +func (t *Wechat) SetCommonConfig(notifyUrl string) *Wechat { + t.notifyUrl = notifyUrl + return t +} + +func (t *Wechat) SetOrderInfo(orderNo string, amount int, description string) *Wechat { + t.orderNo = orderNo + t.amount = amount + t.description = description + return t +} + +func (t *Wechat) setBodyMap() gopay.BodyMap { + if t.description == "" || t.orderNo == "" || t.notifyUrl == "" { + panic("param is empty") + } + if t.amount == 0 { + panic("amount is zero") + } + + bm := make(gopay.BodyMap) + bm.Set("description", t.description). + Set("out_trade_no", t.orderNo). + Set("time_expire", time.Now().Add(10*time.Minute).Format(time.RFC3339)). + Set("notify_url", t.notifyUrl). + SetBodyMap("amount", func(bm gopay.BodyMap) { + bm.Set("total", t.amount).Set("currency", "CNY") + }) + if t.profitSharing { + bm.SetBodyMap("settle_info", func(bm gopay.BodyMap) { + bm.Set("profit_sharing", true) + }) + } + + return bm +} + +func (t *Wechat) SetProfitSharing(b bool) *Wechat { + t.profitSharing = b + return t +} + +func (t *Wechat) GetApp() (*wechat.AppPayParams, error) { + wxRsp, err := t.cli.V3TransactionApp(t.ctx, t.setBodyMap().Set("appid", t.params.AppAppid)) + if err != nil { + return nil, err + } + + if wxRsp.Code == wechat.Success { + //校验签名 + if err2 := wechat.V3VerifySignByPK(wxRsp.SignInfo.HeaderTimestamp, wxRsp.SignInfo.HeaderNonce, wxRsp.SignInfo.SignBody, wxRsp.SignInfo.HeaderSignature, t.cli.WxPublicKey()); err != nil { + return nil, err2 + } + + //获取调起参数 + return t.cli.PaySignOfApp(t.params.AppAppid, wxRsp.Response.PrepayId) + } else { + return nil, errors.New(wxRsp.Error) + } +} + +func (t *Wechat) GetJsapi(openId string) (*wechat.JSAPIPayParams, error) { + wxRsp, err := t.cli.V3TransactionJsapi(t.ctx, + t.setBodyMap().Set("appid", t.params.MpAppid).SetBodyMap("payer", func(bm gopay.BodyMap) { + bm.Set("openid", openId) + }), + ) + if err != nil { + return nil, err + } + if wxRsp.Code == wechat.Success { + //校验签名 + if err2 := wechat.V3VerifySignByPK(wxRsp.SignInfo.HeaderTimestamp, wxRsp.SignInfo.HeaderNonce, wxRsp.SignInfo.SignBody, wxRsp.SignInfo.HeaderSignature, t.cli.WxPublicKey()); err != nil { + return nil, err2 + } + + //获取调起参数 + return t.cli.PaySignOfJSAPI(t.params.MpAppid, wxRsp.Response.PrepayId) + } else { + return nil, errors.New(wxRsp.Error) + } +} + +func (t *Wechat) GetJsapiForMini(openId string) (*wechat.JSAPIPayParams, error) { + wxRsp, err := t.cli.V3TransactionJsapi(t.ctx, + t.setBodyMap().Set("appid", t.params.MiniAppid).SetBodyMap("payer", func(bm gopay.BodyMap) { + bm.Set("openid", openId) + }), + ) + if err != nil { + return nil, err + } + + if wxRsp.Code == wechat.Success { + //校验签名 + if err2 := wechat.V3VerifySignByPK(wxRsp.SignInfo.HeaderTimestamp, wxRsp.SignInfo.HeaderNonce, wxRsp.SignInfo.SignBody, wxRsp.SignInfo.HeaderSignature, t.cli.WxPublicKey()); err != nil { + return nil, err2 + } + + //获取调起参数 + return t.cli.PaySignOfJSAPI(t.params.MiniAppid, wxRsp.Response.PrepayId) + } else { + return nil, errors.New(wxRsp.Error) + } +} + +func (t *Wechat) GetNative() (*wechat.Native, error) { + wxRsp, err := t.cli.V3TransactionNative(t.ctx, t.setBodyMap().Set("appid", t.params.MpAppid)) + if err != nil { + return nil, err + } + + if wxRsp.Code == wechat.Success { + //校验签名 + if err2 := wechat.V3VerifySignByPK(wxRsp.SignInfo.HeaderTimestamp, wxRsp.SignInfo.HeaderNonce, wxRsp.SignInfo.SignBody, wxRsp.SignInfo.HeaderSignature, t.cli.WxPublicKey()); err != nil { + return nil, err2 + } + + //获取调起参数 + return wxRsp.Response, err + } else { + return nil, errors.New(wxRsp.Error) + } +} + +func (t *Wechat) GetH5(ip string, appName string, appUrl string) (*wechat.H5Url, error) { + wxRsp, err := t.cli.V3TransactionH5(t.ctx, t.setBodyMap().Set("appid", t.params.MpAppid).SetBodyMap("scene_info", func(bm gopay.BodyMap) { + bm.Set("payer_client_ip", ip) + bm.SetBodyMap("h5_info", func(bm gopay.BodyMap) { + bm.Set("type", "Wap") + bm.Set("app_url", appUrl) + bm.Set("app_name", appName) + }) + })) + + if err != nil { + return nil, err + } + + if wxRsp.Code == wechat.Success { + //校验签名 + if err2 := wechat.V3VerifySignByPK(wxRsp.SignInfo.HeaderTimestamp, wxRsp.SignInfo.HeaderNonce, wxRsp.SignInfo.SignBody, wxRsp.SignInfo.HeaderSignature, t.cli.WxPublicKey()); err != nil { + return nil, err2 + } + + //获取调起参数 + return wxRsp.Response, nil + } else { + return nil, errors.New(wxRsp.Error) + } +} + +func (t *Wechat) Notify(req *http.Request) (*WechatNotifyResp, error) { + notifyReq, err := wechat.V3ParseNotify(req) + if err != nil { + return nil, errors.New("解析回调失败") + } + + err = notifyReq.VerifySignByPK(t.cli.WxPublicKey()) + if err != nil { + return nil, errors.New("sign Error") + } + + if notifyReq.EventType == "TRANSACTION.SUCCESS" { + result, err := notifyReq.DecryptPayCipherText(t.params.ApiV3Key) + if err != nil { + return nil, errors.New("解密错误") + } else { + return &WechatNotifyResp{resp: result}, nil + } + } + + return nil, errors.New(notifyReq.EventType) +} + +func (t *Wechat) GetOrderNo() string { + return t.orderNo +} + +func (t *Wechat) GetAmount() int { + return t.amount +} + +func NotifySuccess(msg string) (int, *wechat.V3NotifyRsp) { + return http.StatusOK, &wechat.V3NotifyRsp{Code: gopay.SUCCESS, Message: msg} +} + +func NotifyFail(msg string) (int, *wechat.V3NotifyRsp) { + return http.StatusBadRequest, &wechat.V3NotifyRsp{Code: gopay.FAIL, Message: msg} +} + +type WechatPayV3Impl struct { + MpAppid string + AppAppid string + MiniAppid string + MchId string + ApiV3Key string + SerialNo string + PKContent string + WxPkSerialNo string + WxPkContent string + IsProd bool +} diff --git a/pkg/myPay/wechat_notify.go b/pkg/myPay/wechat_notify.go new file mode 100644 index 0000000..3e288bd --- /dev/null +++ b/pkg/myPay/wechat_notify.go @@ -0,0 +1,15 @@ +package myPay + +import "github.com/go-pay/gopay/wechat/v3" + +type WechatNotifyResp struct { + resp *wechat.V3DecryptPayResult +} + +func (t *WechatNotifyResp) IsSuccess() bool { + return t.resp.TradeState == "SUCCESS" +} + +func (t *WechatNotifyResp) GetResult() *wechat.V3DecryptPayResult { + return t.resp +} diff --git a/pkg/myRedis/Ierator.go b/pkg/myRedis/Ierator.go new file mode 100644 index 0000000..c069263 --- /dev/null +++ b/pkg/myRedis/Ierator.go @@ -0,0 +1,23 @@ +package myRedis + +type Iterator struct { + data []interface{} + index int +} + +func NewIterator(data []interface{}) *Iterator { + return &Iterator{data: data} +} + +func (t *Iterator) HasNext() bool { + if t.data == nil || len(t.data) == 0 { + return false + } + return t.index < len(t.data) +} + +func (t *Iterator) Next() (ret interface{}) { + ret = t.data[t.index] + t.index = t.index + 1 + return +} diff --git a/pkg/myRedis/IntResult.go b/pkg/myRedis/IntResult.go new file mode 100644 index 0000000..6d5ed83 --- /dev/null +++ b/pkg/myRedis/IntResult.go @@ -0,0 +1,33 @@ +package myRedis + +type IntResult struct { + Result int64 + Err error +} + +func NewIntResult(result int64, err error) *IntResult { + return &IntResult{Result: result, Err: err} +} + +func (t *IntResult) Unwrap() int64 { + if t.Err != nil { + panic(t.Err) + } + + return t.Result +} + +func (t *IntResult) UnwrapOr(str int64) int64 { + if t.Err != nil { + return str + } else { + return t.Result + } +} + +func (t *IntResult) UnwrapOrElse(f func() int64) int64 { + if t.Err != nil { + return f() + } + return t.Result +} diff --git a/pkg/myRedis/InterfaceResult.go b/pkg/myRedis/InterfaceResult.go new file mode 100644 index 0000000..4a2aef0 --- /dev/null +++ b/pkg/myRedis/InterfaceResult.go @@ -0,0 +1,25 @@ +package myRedis + +type InterfaceResult struct { + Result interface{} + Err error +} + +func NewInterfaceResult(result interface{}, err error) *InterfaceResult { + return &InterfaceResult{Result: result, Err: err} +} + +func (t *InterfaceResult) Unwrap() interface{} { + if t.Err != nil { + panic(t.Err) + } + + return t.Result +} + +func (t *InterfaceResult) UnwrapOr(a interface{}) interface{} { + if t.Err != nil { + return a + } + return t.Result +} diff --git a/pkg/myRedis/OperationAttr.go b/pkg/myRedis/OperationAttr.go new file mode 100644 index 0000000..4e2e064 --- /dev/null +++ b/pkg/myRedis/OperationAttr.go @@ -0,0 +1,52 @@ +package myRedis + +import ( + "fmt" + "time" +) + +const ( + AttrExpr = "expr" + AttrNx = "nx" + AttrXx = "xx" +) + +type empty struct { +} + +type OperationAttr struct { + Name string + Value interface{} +} + +type OperationAttrs []*OperationAttr + +func (t OperationAttrs) Find(name string) *InterfaceResult { + for _, attr := range t { + if attr.Name == name { + return NewInterfaceResult(attr.Value, nil) + } + } + return NewInterfaceResult(nil, fmt.Errorf("OperationAttrs found error:%s", name)) +} + +func WithExpire(t time.Duration) *OperationAttr { + return &OperationAttr{ + Name: AttrExpr, + Value: t, + } +} + +func WithNX() *OperationAttr { + return &OperationAttr{ + Name: AttrNx, + Value: empty{}, + } +} + +func WithXX() *OperationAttr { + return &OperationAttr{ + Name: AttrXx, + Value: empty{}, + } +} diff --git a/pkg/myRedis/SimpleCache.go b/pkg/myRedis/SimpleCache.go new file mode 100644 index 0000000..9b49aac --- /dev/null +++ b/pkg/myRedis/SimpleCache.go @@ -0,0 +1,132 @@ +package myRedis + +import ( + "bytes" + "encoding/gob" + "encoding/json" + "time" +) + +const ( + SerializerNot = "" + SerializerJson = "json" + SerializerGob = "gob" +) + +type CacheGetterFunc func() interface{} + +type SimpleCache struct { + Operation *StringOperation + Expire time.Duration + CacheGetter CacheGetterFunc + Serializer string //序列化方式 +} + +func NewSimpleCache(operation *StringOperation, expire time.Duration, serializer string) *SimpleCache { + return &SimpleCache{Operation: operation, Expire: expire, Serializer: serializer} +} + +func (t *SimpleCache) SetCacheGetterFunc(f CacheGetterFunc) *SimpleCache { + t.CacheGetter = f + return t +} + +// 设置缓存 +func (t *SimpleCache) SetCache(key string, value interface{}) { + //if t.Serializer == SerializerNot { + //} + if t.Serializer == SerializerJson { + f := func() string { + j, e := json.Marshal(value) + if e != nil { + return e.Error() + } else { + return string(j) + } + } + t.Operation.Set(key, f(), WithExpire(t.Expire)).Unwrap() + } else if t.Serializer == SerializerGob { + f := func() string { + var buf = &bytes.Buffer{} + enc := gob.NewEncoder(buf) + if err := enc.Encode(value); err != nil { + return "" + } + return buf.String() + } + t.Operation.Set(key, f(), WithExpire(t.Expire)).Unwrap() + } else { + t.Operation.Set(key, value, WithExpire(t.Expire)).Unwrap() + } +} + +func (t *SimpleCache) GetCache(key string) (ret interface{}) { + //如果没有设置的话 + if t.CacheGetter == nil { + panic("没有设置CacheGetter") + } + + if t.Serializer == SerializerNot { + } + if t.Serializer == SerializerJson { + f := func() string { + j, e := json.Marshal(t.CacheGetter()) + if e != nil { + return e.Error() + } else { + return string(j) + } + } + ret = t.Operation.Get(key).UnwrapOrElse(func() string { + data := f() + t.Operation.Set(key, data, WithExpire(t.Expire)).Unwrap() + return data + }) + } + + if t.Serializer == SerializerGob { + f := func() string { + var buf = &bytes.Buffer{} + enc := gob.NewEncoder(buf) + if err := enc.Encode(t.CacheGetter()); err != nil { + return "" + } + return buf.String() + } + ret = t.Operation.Get(key).UnwrapOrElse(func() string { + data := f() + t.Operation.Set(key, data, WithExpire(t.Expire)).Unwrap() + return data + }) + } + + return +} + +func (t *SimpleCache) DelCache(key string) int64 { + return t.Operation.Del(key).UnwrapOr(0) +} + +func (t *SimpleCache) GetCacheForObject(key string, obj interface{}) interface{} { + ret := t.GetCache(key) + if ret == nil { + return nil + } + if t.Serializer == SerializerNot { + obj = ret + } else if t.Serializer == SerializerJson { + err := json.Unmarshal([]byte(ret.(string)), obj) + if err != nil { + return nil + } + } else if t.Serializer == SerializerGob { + + var buf = &bytes.Buffer{} + buf.WriteString(ret.(string)) + dec := gob.NewDecoder(buf) + if dec.Decode(obj) != nil { + return nil + } + } + return nil +} diff --git a/pkg/myRedis/SliceResult.go b/pkg/myRedis/SliceResult.go new file mode 100644 index 0000000..03faee9 --- /dev/null +++ b/pkg/myRedis/SliceResult.go @@ -0,0 +1,29 @@ +package myRedis + +type SliceResult struct { + Result []interface{} + Err error +} + +func NewSliceResult(result []interface{}, err error) *SliceResult { + return &SliceResult{Result: result, Err: err} +} + +func (t *SliceResult) Unwrap() []interface{} { + if t.Err != nil { + panic(t.Err) + } + return t.Result +} + +func (t *SliceResult) UnwrapOr(strs []interface{}) []interface{} { + if t.Err != nil { + return strs + } else { + return t.Result + } +} + +func (t *SliceResult) Iter() *Iterator { + return NewIterator(t.Result) +} diff --git a/pkg/myRedis/StringCache.go b/pkg/myRedis/StringCache.go new file mode 100644 index 0000000..b3f0eeb --- /dev/null +++ b/pkg/myRedis/StringCache.go @@ -0,0 +1,52 @@ +package myRedis + +import ( + "github.com/redis/go-redis/v9" + "time" +) + +type StringCache struct { + Operation *StringOperation + Expire time.Duration + DefaultString string +} + +func NewStringCache(redisClient *redis.Client) *StringCache { + return &StringCache{ + Operation: NewStringOperation(redisClient), + Expire: time.Second * 0, + DefaultString: "", + } +} + +func (t *StringCache) SetExpire(expire time.Duration) *StringCache { + t.Expire = expire + return t +} + +func (t *StringCache) SetDefaultString(defaultString string) *StringCache { + t.DefaultString = defaultString + return t +} + +func (t *StringCache) SetCache(key string, value string) { + t.Operation.Set(key, value, WithExpire(t.Expire)) +} + +func (t *StringCache) GetCache(key string) (ret string) { + ret = t.Operation.Get(key).UnwrapOrElse(func() string { + if t.DefaultString != "" { + t.SetCache(key, t.DefaultString) + } + return t.DefaultString + }) + return +} + +func (t *StringCache) IsExist(key string) bool { + return t.Operation.Exist(key).UnwrapOr(0) != 0 +} + +func (t *StringCache) DelCache(key string) int64 { + return t.Operation.Del(key).UnwrapOr(0) +} diff --git a/pkg/myRedis/StringOperation.go b/pkg/myRedis/StringOperation.go new file mode 100644 index 0000000..09e5434 --- /dev/null +++ b/pkg/myRedis/StringOperation.go @@ -0,0 +1,48 @@ +package myRedis + +import ( + "context" + "github.com/redis/go-redis/v9" + "time" +) + +type StringOperation struct { + ctx context.Context + client *redis.Client +} + +func NewStringOperation(client *redis.Client) *StringOperation { + return &StringOperation{ctx: context.Background(), client: client} +} + +func (t *StringOperation) Set(key string, value interface{}, attrs ...*OperationAttr) *InterfaceResult { + exp := OperationAttrs(attrs).Find(AttrExpr).UnwrapOr(0 * time.Second).(time.Duration) + + nx := OperationAttrs(attrs).Find(AttrNx).UnwrapOr(nil) + if nx != nil { + return NewInterfaceResult(t.client.SetNX(t.ctx, key, value, exp).Result()) + } + + xx := OperationAttrs(attrs).Find(AttrXx).UnwrapOr(nil) + if xx != nil { + return NewInterfaceResult(t.client.SetXX(t.ctx, key, value, exp).Result()) + } + + return NewInterfaceResult(t.client.Set(t.ctx, key, value, exp).Result()) +} + +func (t *StringOperation) Get(key string) *StringResult { + return NewStringResult(t.client.Get(t.ctx, key).Result()) +} + +func (t *StringOperation) MGet(key ...string) *SliceResult { + return NewSliceResult(t.client.MGet(t.ctx, key...).Result()) +} + +func (t *StringOperation) Del(key string) *IntResult { + return NewIntResult(t.client.Del(t.ctx, key).Result()) +} + +func (t *StringOperation) Exist(key string) *IntResult { + return NewIntResult(t.client.Exists(t.ctx, key).Result()) +} diff --git a/pkg/myRedis/StringResult.go b/pkg/myRedis/StringResult.go new file mode 100644 index 0000000..643e3da --- /dev/null +++ b/pkg/myRedis/StringResult.go @@ -0,0 +1,33 @@ +package myRedis + +type StringResult struct { + Result string + Err error +} + +func NewStringResult(result string, err error) *StringResult { + return &StringResult{Result: result, Err: err} +} + +func (t *StringResult) Unwrap() string { + if t.Err != nil { + panic(t.Err) + } + + return t.Result +} + +func (t *StringResult) UnwrapOr(str string) string { + if t.Err != nil { + return str + } else { + return t.Result + } +} + +func (t *StringResult) UnwrapOrElse(f func() string) string { + if t.Err != nil { + return f() + } + return t.Result +} diff --git a/pkg/myRedis/cli.go b/pkg/myRedis/cli.go new file mode 100644 index 0000000..a364d4b --- /dev/null +++ b/pkg/myRedis/cli.go @@ -0,0 +1,33 @@ +package myRedis + +import ( + "github.com/redis/go-redis/v9" + "time" +) + +type Client struct { + client *redis.Client +} + +func NewClient(client *redis.Client) *Client { + if client == nil { + panic("redis client is nil") + } + return &Client{client: client} +} + +func (t *Client) GetClient() *redis.Client { + return t.client +} + +func (t *Client) NewStringCache() *StringCache { + return NewStringCache(t.client) +} + +func (t *Client) NewJsonCache(expire time.Duration) *SimpleCache { + return NewSimpleCache(NewStringOperation(t.client), expire, SerializerJson) +} + +func (t *Client) NewGobCache(expire time.Duration) *SimpleCache { + return NewSimpleCache(NewStringOperation(t.client), expire, SerializerGob) +} diff --git a/pkg/myRedis/index.go b/pkg/myRedis/index.go new file mode 100644 index 0000000..1ecd8fa --- /dev/null +++ b/pkg/myRedis/index.go @@ -0,0 +1,60 @@ +package myRedis + +import ( + "context" + "fmt" + "github.com/redis/go-redis/v9" + "log" + "sync" +) + +type SimpleRedis struct { + Host string + Password string + hosts map[int]*Hosts +} + +type Hosts struct { + clientOnce sync.Once + client *redis.Client +} + +func NewSimpleRedis(host string, password string) *SimpleRedis { + return &SimpleRedis{Host: host, Password: password, hosts: map[int]*Hosts{}} +} + +func (t *SimpleRedis) connectRedis(index int) *redis.Client { + if t.hosts[index] == nil { + t.hosts[index] = &Hosts{ + clientOnce: sync.Once{}, + client: nil, + } + } + + t.hosts[index].clientOnce.Do(func() { + redisClient := redis.NewClient(&redis.Options{ + Addr: t.Host, + Password: t.Password, // no password set + DB: index, // use default DB + //连接池容量以闲置链接数量 + PoolSize: 15, + MinIdleConns: 10, + }) + pong, err := redisClient.Ping(context.Background()).Result() + if err != nil { + panic(fmt.Errorf("connect error:%s", err)) + } + log.Println(fmt.Sprintf("redis newClient success, index:%d, pong: %s", index, pong)) + + t.hosts[index].client = redisClient + }) + return t.hosts[index].client +} + +func (t *SimpleRedis) ConnectDefaultRedis() *Client { + return t.GetRedisClient(0) +} + +func (t *SimpleRedis) GetRedisClient(index int) *Client { + return NewClient(t.connectRedis(index)) +} diff --git a/pkg/mySql/Create.go b/pkg/mySql/Create.go new file mode 100644 index 0000000..d2fd129 --- /dev/null +++ b/pkg/mySql/Create.go @@ -0,0 +1,9 @@ +package mySql + +import ( + "gorm.io/gorm" +) + +func Create(v interface{}, db *gorm.DB) error { + return db.Create(v).Error +} diff --git a/pkg/mySql/Delete.go b/pkg/mySql/Delete.go new file mode 100644 index 0000000..1ccf25b --- /dev/null +++ b/pkg/mySql/Delete.go @@ -0,0 +1,55 @@ +package mySql + +import ( + "errors" + "fmt" + "github.com/Masterminds/squirrel" + "gorm.io/gorm" +) + +type Delete struct { + tableName string + wheres []*Where + db *gorm.DB +} + +func NewDelete(tableName string, db *gorm.DB) *Delete { + return &Delete{tableName: tableName, db: db} +} + +func (t *Delete) WhereRaw(formula string, values ...interface{}) *Delete { + t.wheres = append(t.wheres, NewWhere(formula, values...)) + return t +} + +func (t *Delete) WhereColumn(formula string, values ...interface{}) *Delete { + t.WhereRaw(fmt.Sprintf("`%s` = ?", formula), values...) + return t +} + +func (t *Delete) WhereColumnIn(formula string, values ...interface{}) *Delete { + t.WhereRaw(fmt.Sprintf("`%s` IN ?", formula), values...) + return t +} + +func (t *Delete) Delete() error { + mapper, err := t.Mapper() + if err != nil { + return err + } + return mapper.Exec().Error +} +func (t *Delete) Mapper() (*SqlMapper, error) { + if len(t.wheres) == 0 || t.wheres[0].Formula == "" { + return nil, errors.New("没有where条件不被允许") + } + + squ := squirrel.Delete(fmt.Sprintf("`%s`", t.tableName)) + + //拼接where + for _, where := range t.wheres { + squ = squ.Where(where.Formula, where.Values...) + } + + return Mapper(squ.ToSql()).setDB(t.db), nil +} diff --git a/pkg/mySql/Facade.go b/pkg/mySql/Facade.go new file mode 100644 index 0000000..3508929 --- /dev/null +++ b/pkg/mySql/Facade.go @@ -0,0 +1,32 @@ +package mySql + +import "gorm.io/gorm" + +type Facade struct { + tableName string + dbFunc func() *gorm.DB +} + +func NewFacade(tableName string, dbFunc func() *gorm.DB) *Facade { + return &Facade{tableName: tableName, dbFunc: dbFunc} +} + +func (t *Facade) NewQuery() *Query { + return NewQuery(t.tableName, t.dbFunc()) +} + +func (t *Facade) NewUpdate() *Update { + return NewUpdate(t.tableName, t.dbFunc()) +} + +func (t *Facade) NewInsert() *Insert { + return NewInsert(t.tableName, t.dbFunc()) +} + +func (t *Facade) NewDelete() *Delete { + return NewDelete(t.tableName, t.dbFunc()) +} + +func (t *Facade) Create(val interface{}) error { + return Create(val, t.dbFunc()) +} diff --git a/pkg/mySql/Insert.go b/pkg/mySql/Insert.go new file mode 100644 index 0000000..297d3b9 --- /dev/null +++ b/pkg/mySql/Insert.go @@ -0,0 +1,83 @@ +package mySql + +import ( + "fmt" + "github.com/Masterminds/squirrel" + "gorm.io/gorm" + "reflect" + "strings" + "time" +) + +type Insert struct { + tableName string + columns []*Column + db *gorm.DB +} + +func NewInsert(tableName string, db *gorm.DB) *Insert { + return &Insert{tableName: tableName, db: db} +} + +func (t *Insert) AddColumn(column string, value interface{}) *Insert { + t.columns = append(t.columns, NewColumn(fmt.Sprintf("`%s`", column), value)) + return t +} + +func (t *Insert) AddCreatedColumn() *Insert { + return t.AddColumn("created_at", time.Now()) +} + +func (t *Insert) AddUpdatedColumn() *Insert { + return t.AddColumn("updated_at", time.Now()) +} + +func (t *Insert) AddCreatedAndUpdatedColumns() *Insert { + return t.AddCreatedColumn().AddUpdatedColumn() +} + +func (t *Insert) AddColumns(value map[string]interface{}) *Insert { + for s, i := range value { + t.AddColumn(s, i) + } + return t +} + +func (t *Insert) MountColumn(data interface{}) *Insert { + v := reflect.ValueOf(data) + v = v.Elem() + for i := 0; i < v.NumField(); i++ { + column := v.Type().Field(i).Tag.Get("column") + if column != "" { + split := NewStringSplit(strings.ToUpper(column), ":") + split.RunCount1Func(func(str string) { + t.AddColumn(column, v.Field(i).Interface()) + }) + split.RunCount2Func(func(str1, str2 string) { + if str2 == "KEY" { + t.AddColumn(str1, nil) + } else { + t.AddColumn(str1, v.Field(i).Interface()) + } + }) + } + } + return t +} + +func (t *Insert) Insert() error { + return t.Mapper().Exec().Error +} + +func (t *Insert) Mapper() *SqlMapper { + columns := make([]string, 0) + values := make([]interface{}, 0) + + for _, v := range t.columns { + columns = append(columns, v.name) + values = append(values, v.value) + } + + squ := squirrel.Insert(fmt.Sprintf("`%s`", t.tableName)).Columns(columns...).Values(values...) + return Mapper(squ.ToSql()).setDB(t.db) +} diff --git a/pkg/mySql/InsertColumn.go b/pkg/mySql/InsertColumn.go new file mode 100644 index 0000000..cf64eb6 --- /dev/null +++ b/pkg/mySql/InsertColumn.go @@ -0,0 +1,10 @@ +package mySql + +type Column struct { + name string + value interface{} +} + +func NewColumn(name string, value interface{}) *Column { + return &Column{name: name, value: value} +} diff --git a/pkg/mySql/Join.go b/pkg/mySql/Join.go new file mode 100644 index 0000000..4e9ea54 --- /dev/null +++ b/pkg/mySql/Join.go @@ -0,0 +1,11 @@ +package mySql + +type Join struct { + option string + join string + rest []interface{} +} + +func NewJoin(option string, join string, rest []interface{}) *Join { + return &Join{option: option, join: join, rest: rest} +} diff --git a/pkg/mySql/PageStruct.go b/pkg/mySql/PageStruct.go new file mode 100644 index 0000000..a3c110a --- /dev/null +++ b/pkg/mySql/PageStruct.go @@ -0,0 +1,8 @@ +package mySql + +type PageStruct struct { + Total int64 + TotalPage int64 + Page int64 + PageSize int64 +} diff --git a/pkg/mySql/Query.go b/pkg/mySql/Query.go new file mode 100644 index 0000000..5371f89 --- /dev/null +++ b/pkg/mySql/Query.go @@ -0,0 +1,420 @@ +package mySql + +import ( + "fmt" + "github.com/Masterminds/squirrel" + "gorm.io/gorm" + "math" + "reflect" + "strconv" + "strings" +) + +type Query struct { + tableName string + rawColumns []string + autoColumns []string + rawWheres []*Where + autoWheres []*Where + limit int64 + offset int64 + groupBy string + orderBy string + join []*Join + isPageMode bool + page int64 + pageSize int64 + alias string + db *gorm.DB +} + +func NewQuery(tableName string, db *gorm.DB) *Query { + return &Query{ + tableName: tableName, + rawColumns: []string{}, + autoColumns: []string{}, + rawWheres: make([]*Where, 0), + autoWheres: make([]*Where, 0), + limit: 0, + offset: 0, + join: make([]*Join, 0), + isPageMode: false, + page: 1, + pageSize: 30, + db: db, + } +} + +func (t *Query) SetDB(db *gorm.DB) *Query { + t.db = db + return t +} + +func (t *Query) As(alias string) *Query { + t.alias = alias + return t +} + +func (t *Query) Select(columns ...string) *Query { + t.autoColumns = []string{} + for _, column := range columns { + t.autoColumns = append(t.autoColumns, fmt.Sprintf("`%s`", column)) + } + return t +} + +func (t *Query) SelectRow(columns ...string) *Query { + t.rawColumns = columns + return t +} + +func (t *Query) AddSelect(columns ...string) *Query { + for _, v := range columns { + t.autoColumns = append(t.autoColumns, fmt.Sprintf("`%s`", v)) + } + return t +} +func (t *Query) AddSelectRow(columns ...string) *Query { + for _, v := range columns { + t.rawColumns = append(t.rawColumns, v) + } + return t +} + +func (t *Query) GroupBy(s string) *Query { + if find := strings.Contains(s, "."); find { + countSplit := strings.SplitN(s, ".", 2) + if len(countSplit) == 2 { + t.groupBy = fmt.Sprintf("`%s`.`%s`", countSplit[0], countSplit[1]) + } else { + t.groupBy = fmt.Sprintf("`%s`", s) + } + } else { + t.groupBy = fmt.Sprintf("`%s`", s) + } + return t +} +func (t *Query) GroupByRaw(s string) *Query { + t.groupBy = s + return t +} + +func (t *Query) Limit(s int64) *Query { + t.limit = s + return t +} +func (t *Query) Offset(s int64) *Query { + t.offset = s + return t +} + +func (t *Query) Page(p int64) *Query { + t.isPageMode = true + if p == 0 { + p = 1 + } + t.page = p + return t +} + +func (t *Query) PageSize(p int64) *Query { + if p == 0 { + p = 30 + } + t.pageSize = p + return t +} + +func (t *Query) Pages(p int64, s int64) *Query { + return t.Page(p).PageSize(s) +} + +func (t *Query) MountQuery(f func(squ *Query)) *Query { + f(t) + return t +} + +func (t *Query) NotDelete() *Query { + return t.WhereIsNull("deleted_at") +} + +func (t *Query) autoColumnsToRaw() []string { + newColumns := make([]string, 0) + for _, v := range t.autoColumns { + if find := strings.Contains(v, "."); !find { + if t.alias != "" { + newColumns = append(newColumns, fmt.Sprintf("`%s`.%s", t.alias, v)) + } else { + newColumns = append(newColumns, v) + } + } else { + newColumns = append(newColumns, v) + } + } + return newColumns +} + +func (t *Query) getColumns() []string { + all := append(t.autoColumnsToRaw(), t.rawColumns...) + if len(all) == 0 { + return []string{"*"} + } + return all +} + +func (t *Query) getFrom() string { + if t.alias != "" { + return fmt.Sprintf("`%s` AS `%s`", t.tableName, t.alias) + } + return fmt.Sprintf("`%s`", t.tableName) +} + +func (t *Query) getColName(colName string) string { + if t.alias != "" { + return fmt.Sprintf("`%s`.`%s`", t.alias, colName) + } else { + return fmt.Sprintf("`%s`", colName) + } +} + +func (t *Query) mountWhere(squ squirrel.SelectBuilder) squirrel.SelectBuilder { + //先挂载原生数据 + for _, where := range t.rawWheres { + squ = squ.Where(where.Formula, where.Values...) + } + + //挂载需要加工的数据 + for _, where := range t.autoWheres { + squ = squ.Where(where.Formula, where.Values...) + } + + return squ +} + +func (t *Query) mountOther(squ squirrel.SelectBuilder) squirrel.SelectBuilder { + if t.isPageMode { + squ = squ.Offset(uint64(t.pageSize * (t.page - 1))).Limit(uint64(t.pageSize)) + } else { + if t.limit > 0 { + squ = squ.Limit(uint64(t.limit)) + } + if t.offset > 0 { + squ = squ.Offset(uint64(t.limit)) + } + } + + if t.groupBy != "" { + squ = squ.GroupBy(t.groupBy) + } + if t.orderBy != "" { + squ = squ.OrderBy(t.orderBy) + } + + return squ +} +func (t *Query) mountJoin(squ squirrel.SelectBuilder) squirrel.SelectBuilder { + for _, join := range t.join { + if join.option == "left" { + if len(join.rest) == 0 { + squ = squ.LeftJoin(join.join) + } else { + squ = squ.LeftJoin(join.join, join.rest) + } + } else if join.option == "right" { + if len(join.rest) == 0 { + squ = squ.RightJoin(join.join) + } else { + squ = squ.RightJoin(join.join, join.rest) + } + } else if join.option == "inner" { + if len(join.rest) == 0 { + squ = squ.RightJoin(join.join) + } else { + squ = squ.RightJoin(join.join, join.rest) + } + } else { + if len(join.rest) == 0 { + squ = squ.RightJoin(join.join) + } else { + squ = squ.RightJoin(join.join, join.rest) + } + } + } + return squ +} + +func (t *Query) ToSquirrel() squirrel.SelectBuilder { + //获取select + squ := squirrel.Select(t.getColumns()...).From(t.getFrom()) + //组装where + squ = t.mountWhere(squ) + + //组装limit和offset,groupBy和orderBy + squ = t.mountOther(squ) + + //组装join + squ = t.mountJoin(squ) + return squ +} + +func (t *Query) ToGormQuery() *gorm.DB { + return t.toGormQuery(t.ToSquirrel()) +} + +func (t *Query) ToSql() (string, []interface{}, error) { + return t.ToSquirrel().ToSql() +} + +func (t *Query) toGormQuery(squ squirrel.SelectBuilder) *gorm.DB { + return Mapper(squ.ToSql()).setDB(t.db).Query() +} + +func (t *Query) Find(list interface{}) error { + return t.ToGormQuery().Find(list).Error +} + +func (t *Query) Take(info interface{}) error { + t.limit = 1 + return t.ToGormQuery().Take(info).Error +} + +func (t *Query) GetTotal() int64 { + //只组装where + squ := squirrel.Select("count(*)").From(t.getFrom()) + //组装where + squ = t.mountWhere(squ) + + var count int64 + t.toGormQuery(squ).Scan(&count) + return count +} + +func (t *Query) GetRowCount() int64 { + //只组装where + squ := squirrel.Select("count(*)").From(t.getFrom()) + //组装where + squ = t.mountWhere(squ) + //组装groupBy + if t.groupBy != "" { + squ = squ.GroupBy(t.groupBy) + } + var count int64 + t.toGormQuery(squ).Count(&count) + return count +} + +func (t *Query) GetSum(column string) int64 { + //只组装where + squ := squirrel.Select(fmt.Sprintf("SUM(`%s`)", column)).From(t.getFrom()) + //组装where + squ = t.mountWhere(squ) + var sum int64 + t.toGormQuery(squ).Scan(&sum) + return sum +} + +func (t *Query) GetSumInterface(column string, value interface{}) { + //只组装where + squ := squirrel.Select(fmt.Sprintf("SUM(`%s`)", column)).From(t.getFrom()) + //组装where + squ = t.mountWhere(squ) + t.toGormQuery(squ).Scan(&value) +} + +func (t *Query) PageFind(list interface{}) (*PageStruct, error) { + err1 := t.Find(list) + total := t.GetTotal() + p := &PageStruct{ + Total: total, + TotalPage: 0, + Page: t.page, + PageSize: t.pageSize, + } + + totalPage, err := strconv.ParseFloat(fmt.Sprintf("%.5f", float64(total)/float64(t.pageSize)), 64) + if err == nil { + p.TotalPage = int64(math.Ceil(totalPage)) + } else { + p.TotalPage = (total / p.PageSize) + 1 + } + + return p, err1 +} + +func (t *Query) MountWhereForReflect(data interface{}) *Query { + v := reflect.ValueOf(data) + v = v.Elem() + for i := 0; i < v.NumField(); i++ { + where := v.Type().Field(i).Tag.Get("where") + if where != "" && !isBlank(v.Field(i)) { + split := NewStringSplit(where, ":") + //如果是1段 + split.RunCount1Func(func(whereColName string) { + t.WhereRaw(fmt.Sprintf("%s = ?", t.getColName(whereColName)), v.Field(i).Interface()) + }) + + //如果是2段 + split.RunCount2Func(func(whereColName, whereOperation string) { + //str2的值LIKE_BOTH、LIKE、BOOL等 + whereOperationSplit := NewStringSplit(strings.ToUpper(whereOperation), "_") + whereOperationSplit.RunCount1Func(func(operation string) { + //这里的operation是LIKE_BOTH的LIKE + if operation == "LIKE" { + t.WhereRaw( + fmt.Sprintf("%s LIKE ?", t.getColName(whereColName)), + v.Field(i).Interface(), + ) + } + if operation == "BOOL" { + if v.Field(i).Interface() == "yes" { + t.WhereRaw(fmt.Sprintf("%s = ?", t.getColName(whereColName)), 1) + } + if v.Field(i).Interface() == "no" { + t.WhereRaw(fmt.Sprintf("%s = ?", t.getColName(whereColName)), 0) + } + } + }) + whereOperationSplit.RunCount2Func(func(operation, operation2 string) { + //这里的operation是LIKE_BOTH的LIKE,operation2是BOTH + if operation == "LIKE" { + if operation2 == "BOTH" { + t.WhereRaw( + fmt.Sprintf("%s LIKE ?", t.getColName(whereColName)), + fmt.Sprintf("%s%s%s", "%", v.Field(i).Interface(), "%"), + ) + } + if operation2 == "LEFT" { + t.WhereRaw( + fmt.Sprintf("%s LIKE ?", t.getColName(whereColName)), + fmt.Sprintf("%s%s", "%", v.Field(i).Interface()), + ) + } + if operation2 == "RIGHT" { + t.WhereRaw( + fmt.Sprintf("%s LIKE ?", t.getColName(whereColName)), + fmt.Sprintf("%s%s", v.Field(i).Interface(), "%"), + ) + } + } + }) + }) + + } + } + return t +} +func (t *Query) SelectAllForReflect(data interface{}) *Query { + v := reflect.ValueOf(data) + v = v.Elem() + for i := 0; i < v.NumField(); i++ { + column := v.Type().Field(i).Tag.Get("column") + if column != "" { + //取第一个添加到列 + split := NewStringSplit(column, ":") + split.RunIndexFunc(0, func(str string) { + t.AddSelect(str) + }) + } + } + return t +} diff --git a/pkg/mySql/QueryJoin.go b/pkg/mySql/QueryJoin.go new file mode 100644 index 0000000..d171112 --- /dev/null +++ b/pkg/mySql/QueryJoin.go @@ -0,0 +1,53 @@ +package mySql + +import ( + "fmt" + "strings" +) + +func (t *Query) Join(join string, rest ...interface{}) *Query { + t.join = append(t.join, NewJoin("", join, rest)) + return t +} + +func (t *Query) LeftJoin(join string, rest ...interface{}) *Query { + t.join = append(t.join, NewJoin("left", join, rest)) + return t +} + +func (t *Query) RightJoin(join string, rest ...interface{}) *Query { + t.join = append(t.join, NewJoin("right", join, rest)) + return t +} + +func (t *Query) InnerJoin(join string, rest ...interface{}) *Query { + t.join = append(t.join, NewJoin("inner", join, rest)) + return t +} + +func (t *Query) LeftJoinOn(tableName string, pk string, pk2 string, columns ...string) *Query { + as := string(rune(len(t.join) + 98)) + + for _, column := range columns { + //看一下有没有as + if find := strings.Contains(column, " as "); find { + //劈开前面一段和后面一段 + countSplit := strings.SplitN(column, " as ", 2) + if len(countSplit) == 2 { + t.AddSelectRow(fmt.Sprintf("`%s`.`%s` as `%s`", as, countSplit[0], countSplit[1])) + } else { + t.AddSelectRow(column) + } + + } else { + t.AddSelectRow(fmt.Sprintf("`%s`.`%s`", as, column)) + } + + } + if t.alias == "" { + t.LeftJoin(fmt.Sprintf("`%s` `%s` ON `%s`.`%s` = `%s`.`%s`", tableName, as, as, pk, t.tableName, pk2)) + } else { + t.LeftJoin(fmt.Sprintf("`%s` `%s` ON `%s`.`%s` = `%s`.`%s`", tableName, as, as, pk, t.alias, pk2)) + } + return t +} diff --git a/pkg/mySql/QueryOrderBy.go b/pkg/mySql/QueryOrderBy.go new file mode 100644 index 0000000..d4bd18d --- /dev/null +++ b/pkg/mySql/QueryOrderBy.go @@ -0,0 +1,42 @@ +package mySql + +import ( + "fmt" + "strings" +) + +func (t *Query) OrderByRaw(s string) *Query { + t.orderBy = s + return t +} +func (t *Query) OrderBy(s string, b string) *Query { + //看有没有. + countSplit := strings.SplitN(s, ".", 2) + if len(countSplit) == 2 { + t.orderBy = fmt.Sprintf("`%s`.`%s` %s", countSplit[0], countSplit[1], b) + } else { + t.orderBy = fmt.Sprintf("`%s` %s", s, b) + } + return t +} + +func (t *Query) OrderByDesc(s string) *Query { + t.OrderBy(s, "DESC") + return t +} + +func (t *Query) OrderByAsc(s string) *Query { + t.OrderBy(s, "ASC") + return t +} + +func (t *Query) OrderIsByDesc(isDesc bool, s string) *Query { + if s != "" { + if isDesc { + t.OrderByDesc(s) + } else { + t.OrderByAsc(s) + } + } + return t +} diff --git a/pkg/mySql/QueryWhere.go b/pkg/mySql/QueryWhere.go new file mode 100644 index 0000000..475f59d --- /dev/null +++ b/pkg/mySql/QueryWhere.go @@ -0,0 +1,110 @@ +package mySql + +import ( + "fmt" + "strings" +) + +func (t *Query) makeColName(colName string) string { + //这里的colName其实还没有``包裹 + //如果存在.那就是复合字段 + if find := strings.Contains(colName, "."); find { + countSplit := strings.SplitN(colName, ".", 2) + if len(countSplit) == 2 { + //因为到了where会无脑加上``包裹,所以两侧``是不需要加的 + return fmt.Sprintf("%s`.`%s", countSplit[0], countSplit[1]) + } + } + return colName +} + +func (t *Query) WhereRaw(formula string, values ...interface{}) *Query { + t.rawWheres = append(t.rawWheres, NewWhere(formula, values...)) + return t +} + +func (t *Query) Where(colName string, value interface{}) *Query { + t.autoWheres = append(t.autoWheres, NewEqWhere(t.makeColName(colName), value)) + return t +} + +func (t *Query) WhereEq(colName string, value interface{}) *Query { + t.autoWheres = append(t.autoWheres, NewEqWhere(t.makeColName(colName), value)) + return t +} + +func (t *Query) WhereNeq(colName string, value interface{}) *Query { + t.autoWheres = append(t.autoWheres, NewNeqWhere(t.makeColName(colName), value)) + return t +} + +func (t *Query) WhereGt(colName string, value interface{}) *Query { + t.autoWheres = append(t.autoWheres, NewGtWhere(t.makeColName(colName), value)) + return t +} +func (t *Query) WhereEgt(colName string, value interface{}) *Query { + t.autoWheres = append(t.autoWheres, NewEgtWhere(t.makeColName(colName), value)) + return t +} + +func (t *Query) WhereLt(colName string, value interface{}) *Query { + t.autoWheres = append(t.autoWheres, NewLtWhere(t.makeColName(colName), value)) + return t +} +func (t *Query) WhereElt(colName string, value interface{}) *Query { + t.autoWheres = append(t.autoWheres, NewEltWhere(t.makeColName(colName), value)) + return t +} + +func (t *Query) WhereNotLike(colName string, value interface{}) *Query { + t.autoWheres = append(t.autoWheres, NewNotLikeWhere(t.makeColName(colName), value)) + return t +} + +func (t *Query) WhereLike(colName string, value interface{}) *Query { + t.autoWheres = append(t.autoWheres, NewLikeWhere(t.makeColName(colName), value)) + return t +} + +func (t *Query) WhereBetween(colName string, value1 interface{}, value2 interface{}) *Query { + t.autoWheres = append(t.autoWheres, NewBetweenWhere(t.makeColName(colName), value1, value2)) + return t +} + +func (t *Query) WhereNotBetween(colName string, value1 interface{}, value2 interface{}) *Query { + t.autoWheres = append(t.autoWheres, NewNotBetweenWhere(t.makeColName(colName), value1, value2)) + return t +} + +func (t *Query) WhereIn(colName string, value interface{}) *Query { + t.autoWheres = append(t.autoWheres, NewInWhere(t.makeColName(colName), value)) + return t +} + +func (t *Query) WhereNotIn(colName string, value interface{}) *Query { + t.autoWheres = append(t.autoWheres, NewNotInWhere(t.makeColName(colName), value)) + return t +} + +func (t *Query) WhereIsNull(colName string) *Query { + t.autoWheres = append(t.autoWheres, NewIsNullWhere(t.makeColName(colName))) + return t +} + +func (t *Query) WhereIsNotNull(colName string) *Query { + t.autoWheres = append(t.autoWheres, NewIsNotNullWhere(t.makeColName(colName))) + return t +} + +func (t *Query) Wheres(value map[string]interface{}) *Query { + for s, i := range value { + t.Where(s, i) + } + return t +} + +func (t *Query) WhereBetweenFunc(colName string, f func() (start interface{}, end interface{})) *Query { + start, end := f() + t.WhereBetween(colName, start, end) + return t +} diff --git a/pkg/mySql/Reflect.go b/pkg/mySql/Reflect.go new file mode 100644 index 0000000..3d4ce52 --- /dev/null +++ b/pkg/mySql/Reflect.go @@ -0,0 +1,24 @@ +package mySql + +import ( + "reflect" +) + +// 判断是不是空值 +func isBlank(value reflect.Value) bool { + switch value.Kind() { + case reflect.String: + return value.Len() == 0 + case reflect.Bool: + return !value.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return value.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return value.Uint() == 0 + case reflect.Float32, reflect.Float64: + return value.Float() == 0 + case reflect.Interface, reflect.Ptr: + return value.IsNil() + } + return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) +} diff --git a/pkg/mySql/SqlMapper.go b/pkg/mySql/SqlMapper.go new file mode 100644 index 0000000..1b40a5c --- /dev/null +++ b/pkg/mySql/SqlMapper.go @@ -0,0 +1,78 @@ +package mySql + +import ( + "errors" + "fmt" + "gorm.io/gorm" +) + +type SqlMapper struct { + Sql string + Args []interface{} + db *gorm.DB +} + +func (t *SqlMapper) setDB(db *gorm.DB) *SqlMapper { + t.db = db + return t +} + +// 查询 +func (t *SqlMapper) Query() *gorm.DB { + return t.db.Raw(t.Sql, t.Args...) +} + +func (t *SqlMapper) Exec() *gorm.DB { + return t.db.Exec(t.Sql, t.Args...) +} + +func NewSqlMapper(sql string, args []interface{}) *SqlMapper { + return &SqlMapper{Sql: sql, Args: args} +} + +func Mapper(sql string, args []interface{}, err error) *SqlMapper { + if err != nil { + panic(err.Error()) + } + return NewSqlMapper(sql, args) +} + +type SqlMappers []*SqlMapper + +func Mappers(sqlMappers ...*SqlMapper) (list SqlMappers) { + list = sqlMappers + return +} +func (t SqlMappers) apply(tx *gorm.DB) { + for _, sql := range t { + sql.setDB(tx) + } +} +func (t SqlMappers) Exec(f func() error) error { + if len(t) == 0 { + return errors.New("无Mapper") + } + //其实是以第一个为准 + return t[0].db.Transaction(func(tx *gorm.DB) error { + t.apply(tx) + return f() + }) +} + +func (t SqlMappers) ExecTransaction() error { + //其实是以第一个为准 + return t[0].db.Transaction(func(tx *gorm.DB) error { + fmt.Println("事务开始") + for _, sql := range t { + sql.setDB(tx) + err := sql.Exec().Error + if err != nil { + fmt.Println("事务结束(失败)") + return err + } + } + fmt.Println("事务成功") + return nil + }) + +} diff --git a/pkg/mySql/StringSplit.go b/pkg/mySql/StringSplit.go new file mode 100644 index 0000000..1090e0c --- /dev/null +++ b/pkg/mySql/StringSplit.go @@ -0,0 +1,56 @@ +package mySql + +import ( + "strings" +) + +type StringSplit struct { + str string + sep string +} + +func NewStringSplit(str string, sep string) *StringSplit { + return &StringSplit{str: str, sep: sep} +} + +func (t *StringSplit) chooseIndexStr(index int) (string, bool) { + //劈开 + countSplit := strings.Split(t.str, t.sep) + + //获取对应的字符串,index是从0开始算。 + if len(countSplit) <= index { + return "", false + } + + return countSplit[index], true +} + +// RunIndexFunc 取第index个,从0开始数,执行。 +func (t *StringSplit) RunIndexFunc(index int, f func(str string)) { + str, exist := t.chooseIndexStr(index) + if exist { + f(str) + } +} + +// RunCountFunc 如果分割出来是count个则执行 +func (t *StringSplit) RunCountFunc(count int, f func(strArr []string)) { + countSplit := strings.Split(t.str, t.sep) + if len(countSplit) == count { + f(countSplit) + } +} + +// RunCount1Func 常用方法封装,快捷方法,count=1 +func (t *StringSplit) RunCount1Func(f func(str string)) { + t.RunCountFunc(1, func(strArr []string) { + f(strArr[0]) + }) +} + +// RunCount2Func 常用方法封装,快捷方法,count=2 +func (t *StringSplit) RunCount2Func(f func(str1, str2 string)) { + t.RunCountFunc(2, func(strArr []string) { + f(strArr[0], strArr[1]) + }) +} diff --git a/pkg/mySql/Update.go b/pkg/mySql/Update.go new file mode 100644 index 0000000..9cb6e4f --- /dev/null +++ b/pkg/mySql/Update.go @@ -0,0 +1,127 @@ +package mySql + +import ( + "errors" + "fmt" + "github.com/Masterminds/squirrel" + "gorm.io/gorm" + "reflect" + "strings" + "time" +) + +type Update struct { + tableName string + columns []*Column + wheres []*Where + db *gorm.DB +} + +func NewUpdate(tableName string, db *gorm.DB) *Update { + return &Update{tableName: tableName, db: db} +} + +func (t *Update) MountColumn(data interface{}) *Update { + v := reflect.ValueOf(data) + v = v.Elem() + for i := 0; i < v.NumField(); i++ { + column := v.Type().Field(i).Tag.Get("column") + if column != "" { + countSplit := strings.SplitN(column, ":", 2) + if len(countSplit) == 2 { + if countSplit[1] != "key" { + t.AddColumn(countSplit[0], v.Field(i).Interface()) + } else { + t.Where(countSplit[0], v.Field(i).Interface()) + } + } else { + t.AddColumn(column, v.Field(i).Interface()) + } + } + } + return t +} + +func (t *Update) AddColumn(column string, value interface{}) *Update { + t.columns = append(t.columns, NewColumn(fmt.Sprintf("`%s`", column), value)) + return t +} + +func (t *Update) AddColumns(m map[string]interface{}) *Update { + for s, i := range m { + t.AddColumn(s, i) + } + return t +} + +func (t *Update) AddColumnInc(column string, inc uint) *Update { + t.columns = append(t.columns, NewColumn( + fmt.Sprintf("`%s`", column), + squirrel.Expr(fmt.Sprintf("`%s` + %d", column, inc)), + )) + return t +} + +func (t *Update) AddColumnDec(column string, dec uint) *Update { + t.columns = append(t.columns, NewColumn( + fmt.Sprintf("`%s`", column), + squirrel.Expr(fmt.Sprintf("`%s` - %d", column, dec)), + )) + return t +} + +func (t *Update) AddCreatedColumn() *Update { + t.AddColumn("created_at", time.Now()) + return t +} + +func (t *Update) AddUpdatedColumn() *Update { + t.AddColumn("updated_at", time.Now()) + return t +} +func (t *Update) AddDeletedColumn() *Update { + t.AddColumn("deleted_at", time.Now()) + return t +} +func (t *Update) WhereRaw(formula string, values ...interface{}) *Update { + t.wheres = append(t.wheres, NewWhere(formula, values...)) + return t +} + +func (t *Update) Where(formula string, values ...interface{}) *Update { + t.WhereRaw(fmt.Sprintf("`%s` = ?", formula), values...) + return t +} + +func (t *Update) Update() error { + mapper, err := t.Mapper() + if err != nil { + return err + } + return mapper.Exec().Error +} + +func (t *Update) toSqu() (squirrel.UpdateBuilder, error) { + if len(t.wheres) == 0 || t.wheres[0].Formula == "" { + return squirrel.UpdateBuilder{}, errors.New("没有where条件不被允许") + } + + squ := squirrel.Update(fmt.Sprintf("`%s`", t.tableName)) + for _, v := range t.columns { + squ = squ.Set(v.name, v.value) + } + + //拼接where + for _, where := range t.wheres { + squ = squ.Where(where.Formula, where.Values...) + } + return squ, nil +} + +func (t *Update) Mapper() (*SqlMapper, error) { + squ, err := t.toSqu() + if err != nil { + return nil, err + } + return Mapper(squ.ToSql()).setDB(t.db), nil +} diff --git a/pkg/mySql/Where.go b/pkg/mySql/Where.go new file mode 100644 index 0000000..af761af --- /dev/null +++ b/pkg/mySql/Where.go @@ -0,0 +1,115 @@ +package mySql + +import "fmt" + +type Where struct { + Formula string + Values []interface{} +} + +func NewWhere(formula string, values ...interface{}) *Where { + return &Where{ + Formula: formula, + Values: values, + } +} + +func NewWheres(where ...*Where) []*Where { + return where +} + +func NewEqWhere(colName string, value interface{}) *Where { + return &Where{ + Formula: fmt.Sprintf("`%s` = ?", colName), + Values: []interface{}{value}, + } +} + +func NewNeqWhere(colName string, value interface{}) *Where { + return &Where{ + Formula: fmt.Sprintf("`%s` <> ?", colName), + Values: []interface{}{value}, + } +} + +func NewGtWhere(colName string, value interface{}) *Where { + return &Where{ + Formula: fmt.Sprintf("`%s` > ?", colName), + Values: []interface{}{value}, + } +} +func NewEgtWhere(colName string, value interface{}) *Where { + return &Where{ + Formula: fmt.Sprintf("`%s` >= ?", colName), + Values: []interface{}{value}, + } +} + +func NewLtWhere(colName string, value interface{}) *Where { + return &Where{ + Formula: fmt.Sprintf("`%s` < ?", colName), + Values: []interface{}{value}, + } +} +func NewEltWhere(colName string, value interface{}) *Where { + return &Where{ + Formula: fmt.Sprintf("`%s` <= ?", colName), + Values: []interface{}{value}, + } +} + +func NewNotLikeWhere(colName string, value interface{}) *Where { + return &Where{ + Formula: fmt.Sprintf("`%s` NOT LIKE ?", colName), + Values: []interface{}{value}, + } +} + +func NewLikeWhere(colName string, value interface{}) *Where { + return &Where{ + Formula: fmt.Sprintf("`%s` LIKE ?", colName), + Values: []interface{}{value}, + } +} + +func NewBetweenWhere(colName string, value1 interface{}, value2 interface{}) *Where { + return &Where{ + Formula: fmt.Sprintf("(`%s` BETWEEN ? AND ?)", colName), + Values: []interface{}{value1, value2}, + } +} + +func NewNotBetweenWhere(colName string, value1 interface{}, value2 interface{}) *Where { + return &Where{ + Formula: fmt.Sprintf("(`%s` NOT BETWEEN ? AND ?)", colName), + Values: []interface{}{value1, value2}, + } +} + +func NewInWhere(colName string, value interface{}) *Where { + return &Where{ + Formula: fmt.Sprintf("`%s` IN ?", colName), + Values: []interface{}{value}, + } +} + +func NewNotInWhere(colName string, value interface{}) *Where { + return &Where{ + Formula: fmt.Sprintf("`%s` NOT IN ?", colName), + Values: []interface{}{value}, + } +} + +func NewIsNullWhere(colName string) *Where { + return &Where{ + Formula: fmt.Sprintf("`%s` IS NULL", colName), + Values: []interface{}{}, + } +} + +func NewIsNotNullWhere(colName string) *Where { + return &Where{ + Formula: fmt.Sprintf("`%s` IS NOT NULL ?", colName), + Values: []interface{}{}, + } +} diff --git a/pkg/myUrl/index.go b/pkg/myUrl/index.go new file mode 100644 index 0000000..306e421 --- /dev/null +++ b/pkg/myUrl/index.go @@ -0,0 +1,42 @@ +package myUrl + +import "net/url" + +type UrlCli struct { + url url.URL + params url.Values +} + +func NewUrlCli(scheme string, host string) *UrlCli { + return &UrlCli{url: url.URL{Scheme: scheme, Host: host}, params: url.Values{}} +} + +func NewUrlCliWithParse(data string) (*UrlCli, error) { + urlObj, err := url.Parse(data) + if err != nil { + return nil, err + } + return &UrlCli{ + url: *urlObj, + params: urlObj.Query(), + }, nil +} + +func (t *UrlCli) Set(m map[string]string) { + params := url.Values{} + for s, s2 := range m { + params.Set(s, s2) + } + t.params = params +} + +func (t *UrlCli) Add(key, value string) *UrlCli { + t.params.Set(key, value) + return t +} + +func (t *UrlCli) String() string { + baseURL := t.url + baseURL.RawQuery = t.params.Encode() + return baseURL.String() +} diff --git a/pkg/myViper/viper.go b/pkg/myViper/viper.go new file mode 100644 index 0000000..59d2231 --- /dev/null +++ b/pkg/myViper/viper.go @@ -0,0 +1,31 @@ +package myViper + +import ( + "fmt" + "github.com/spf13/viper" +) + +type SimpleViper struct { + config interface{} + configType string + configName string + configPath string +} + +func NewSimpleViper(config interface{}, configType string, configName string, configPath string) *SimpleViper { + return &SimpleViper{config: config, configType: configType, configName: configName, configPath: configPath} +} + +func (t *SimpleViper) Apply() { + v := viper.New() + v.SetConfigFile(fmt.Sprintf("%s/%s.yaml", t.configPath, t.configName)) + v.SetConfigType(t.configType) + err := v.ReadInConfig() + if err != nil { + panic(fmt.Errorf("Fatal error config file: %s \n", err)) + } + + if err := v.Unmarshal(t.config); err != nil { + panic(fmt.Errorf("Fatal Unmarshal error config file: %s \n", err)) + } +}