diff --git a/global/global.go b/global/global.go index 19d5a18..3bf96df 100644 --- a/global/global.go +++ b/global/global.go @@ -8,11 +8,11 @@ package global import ( "github.com/go-redis/redis" - "github.com/jinzhu/gorm" "github.com/opentracing/opentracing-go" "github.com/spf13/viper" "github.com/viletyy/potato/config" "go.uber.org/zap" + "gorm.io/gorm" ) var ( diff --git a/go.mod b/go.mod index f75b93f..3367669 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,6 @@ go 1.15 require ( github.com/HdrHistogram/hdrhistogram-go v1.1.0 // indirect github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 - github.com/eddycjy/opentracing-gorm v0.0.0-20200209122056-516a807d2182 github.com/fsnotify/fsnotify v1.4.9 github.com/gin-gonic/gin v1.7.7 github.com/go-playground/locales v0.13.0 @@ -13,7 +12,6 @@ require ( github.com/go-playground/validator/v10 v10.6.1 github.com/go-redis/redis v6.15.9+incompatible github.com/golang-jwt/jwt v3.2.2+incompatible - github.com/jinzhu/gorm v1.9.16 github.com/juju/ratelimit v1.0.1 github.com/lestrrat-go/file-rotatelogs v2.4.0+incompatible github.com/lestrrat-go/strftime v1.0.4 // indirect @@ -21,7 +19,6 @@ require ( github.com/onsi/gomega v1.13.0 // indirect github.com/opentracing/opentracing-go v1.2.0 github.com/robfig/cron/v3 v3.0.0 - github.com/smacker/opentracing-gorm v0.0.0-20181207094635-cd4974441042 // indirect github.com/spf13/viper v1.7.1 github.com/swaggo/gin-swagger v1.3.0 github.com/swaggo/swag v1.7.0 @@ -29,10 +26,14 @@ require ( github.com/uber/jaeger-lib v2.4.1+incompatible // indirect github.com/viletyy/yolk v1.0.1 go.uber.org/atomic v1.8.0 // indirect - go.uber.org/zap v1.17.0 + go.uber.org/zap v1.21.0 golang.org/x/net v0.0.0-20210614182718-04defd469f4e // indirect google.golang.org/grpc v1.38.0 google.golang.org/protobuf v1.26.0 gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df + gorm.io/driver/mysql v1.3.5 + gorm.io/driver/postgres v1.3.8 + gorm.io/gorm v1.23.8 + moul.io/zapgorm2 v1.1.3 ) diff --git a/initialize/gorm.go b/initialize/gorm.go index 40f6006..e99ba15 100644 --- a/initialize/gorm.go +++ b/initialize/gorm.go @@ -9,13 +9,12 @@ package initialize import ( "fmt" - otgorm "github.com/eddycjy/opentracing-gorm" - "github.com/jinzhu/gorm" - _ "github.com/jinzhu/gorm/dialects/mysql" - _ "github.com/jinzhu/gorm/dialects/postgres" "github.com/viletyy/potato/global" - "github.com/viletyy/potato/internal/model" - "github.com/viletyy/potato/internal/model/basic" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/gorm" + gormlogger "gorm.io/gorm/logger" + "moul.io/zapgorm2" ) func Gorm() *gorm.DB { @@ -30,7 +29,8 @@ func Gorm() *gorm.DB { } func GormMysql() *gorm.DB { - db, err := gorm.Open("mysql", fmt.Sprintf("%s:%s@(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", global.GO_CONFIG.Database.User, global.GO_CONFIG.Database.Password, global.GO_CONFIG.Database.Host, global.GO_CONFIG.Database.Port, global.GO_CONFIG.Database.Name)) + dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", global.GO_CONFIG.Database.User, global.GO_CONFIG.Database.Password, global.GO_CONFIG.Database.Host, global.GO_CONFIG.Database.Port, global.GO_CONFIG.Database.Name) + db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{}) if err != nil { global.GO_LOG.Error(fmt.Sprintf("Mysql Gorm Open Error: %v", err)) } @@ -39,7 +39,8 @@ func GormMysql() *gorm.DB { } func GormPostgresql() *gorm.DB { - db, err := gorm.Open("postgres", fmt.Sprintf("host=%s user=%s dbname=%s port=%d sslmode=disable password=%s", global.GO_CONFIG.Database.Host, global.GO_CONFIG.Database.User, global.GO_CONFIG.Database.Name, global.GO_CONFIG.Database.Port, global.GO_CONFIG.Database.Password)) + dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=disable TimeZone=Asia/Shanghai", global.GO_CONFIG.Database.Host, global.GO_CONFIG.Database.User, global.GO_CONFIG.Database.Password, global.GO_CONFIG.Database.Name, global.GO_CONFIG.Database.Port) + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) if err != nil { global.GO_LOG.Error(fmt.Sprintf("Postgresql Gorm Open Error: %v", err)) } @@ -48,29 +49,21 @@ func GormPostgresql() *gorm.DB { } func GormSet(db *gorm.DB) { - // 设置表前缀 - gorm.DefaultTableNameHandler = func(db *gorm.DB, defaultTableName string) string { - return global.GO_CONFIG.Database.TablePrefix + defaultTableName + if global.GO_CONFIG.App.RunMode != "debug" { + logger := zapgorm2.New(global.GO_LOG) + logger.SetAsDefault() + logger.LogLevel = gormlogger.Info + db.Logger = logger } - // 设置日志 - if global.GO_CONFIG.App.RunMode == "debug" { - db.LogMode(true) + sqlDB, err := db.DB() + if err != nil { + global.GO_LOG.Error(fmt.Sprintf("Gorm setting db.DB(): %v ", err)) } - // 设置迁移 - db.AutoMigrate( - &basic.Vendor{}, - &model.User{}, - &model.Auth{}, - ) - // 设置空闲连接池中的最大连接数 - db.DB().SetMaxIdleConns(10) + sqlDB.SetConnMaxIdleTime(10) // 设置打开数据库连接的最大数量 - db.DB().SetMaxOpenConns(100) - - // 设置链路追踪 - otgorm.AddGormCallbacks(db) + sqlDB.SetMaxOpenConns(100) } diff --git a/internal/controller/api/v1/basic/vendor.go b/internal/controller/api/v1/basic/vendor.go index a788d8c..452ed3f 100644 --- a/internal/controller/api/v1/basic/vendor.go +++ b/internal/controller/api/v1/basic/vendor.go @@ -8,12 +8,12 @@ package basic import ( "github.com/gin-gonic/gin" - "github.com/jinzhu/gorm" "github.com/viletyy/potato/global" "github.com/viletyy/potato/internal/service" "github.com/viletyy/potato/pkg/app" "github.com/viletyy/potato/pkg/errcode" "github.com/viletyy/yolk/convert" + "gorm.io/gorm" ) type Vendor struct{} diff --git a/internal/controller/api/v1/user.go b/internal/controller/api/v1/user.go index c7f2963..da976a1 100644 --- a/internal/controller/api/v1/user.go +++ b/internal/controller/api/v1/user.go @@ -8,12 +8,12 @@ package v1 import ( "github.com/gin-gonic/gin" - "github.com/jinzhu/gorm" "github.com/viletyy/potato/global" "github.com/viletyy/potato/internal/service" "github.com/viletyy/potato/pkg/app" "github.com/viletyy/potato/pkg/errcode" "github.com/viletyy/yolk/convert" + "gorm.io/gorm" ) type User struct{} diff --git a/internal/dao/dao.go b/internal/dao/dao.go index 763c722..de65bd8 100644 --- a/internal/dao/dao.go +++ b/internal/dao/dao.go @@ -6,7 +6,7 @@ */ package dao -import "github.com/jinzhu/gorm" +import "gorm.io/gorm" type Dao struct { Engine *gorm.DB diff --git a/internal/dao/user.go b/internal/dao/user.go index b4d2a2b..bb193cf 100644 --- a/internal/dao/user.go +++ b/internal/dao/user.go @@ -31,7 +31,7 @@ func (d *Dao) LoginUser(username string, password string) (model.User, error) { return user.GetByUsernameAndPassword(d.Engine) } -func (d *Dao) CountUser(username, nickname string) (int, error) { +func (d *Dao) CountUser(username, nickname string) (int64, error) { vendor := model.User{Username: username, Nickname: nickname} return vendor.Count(d.Engine) } diff --git a/internal/dao/vendor.go b/internal/dao/vendor.go index 750d786..5594575 100644 --- a/internal/dao/vendor.go +++ b/internal/dao/vendor.go @@ -12,7 +12,7 @@ import ( "github.com/viletyy/potato/pkg/app" ) -func (d *Dao) CountVendor(name string, uuid int) (int, error) { +func (d *Dao) CountVendor(name string, uuid int) (int64, error) { vendor := basic.Vendor{Name: name, Uuid: uuid} return vendor.Count(d.Engine) } diff --git a/internal/model/auth.go b/internal/model/auth.go index b9427a2..8f3e481 100644 --- a/internal/model/auth.go +++ b/internal/model/auth.go @@ -6,7 +6,7 @@ */ package model -import "github.com/jinzhu/gorm" +import "gorm.io/gorm" type Auth struct { *Model @@ -15,8 +15,8 @@ type Auth struct { } func (a Auth) Get(db *gorm.DB) (auth Auth, err error) { - if notFound := db.Where("app_key = ? AND app_secret = ?", a.AppKey, a.AppSecret).First(&auth).RecordNotFound(); notFound { - return a, gorm.ErrRecordNotFound + if err := db.Where("app_key = ? AND app_secret = ?", a.AppKey, a.AppSecret).First(&auth).Error; err != nil { + return a, err } return auth, nil diff --git a/internal/model/basic/vendor.go b/internal/model/basic/vendor.go index 0ecbcc1..2f94ed3 100644 --- a/internal/model/basic/vendor.go +++ b/internal/model/basic/vendor.go @@ -7,8 +7,8 @@ package basic import ( - "github.com/jinzhu/gorm" "github.com/viletyy/potato/internal/model" + "gorm.io/gorm" ) type Vendor struct { @@ -18,8 +18,8 @@ type Vendor struct { Uuid int `json:"uuid"` } -func (v Vendor) Count(db *gorm.DB) (int, error) { - var count int +func (v Vendor) Count(db *gorm.DB) (int64, error) { + var count int64 if v.Name != "" { db = db.Where("name = ?", v.Name) } @@ -51,8 +51,8 @@ func (v Vendor) List(db *gorm.DB, pageOffset, pageSize int) (vendors []Vendor, e } func (v Vendor) Get(db *gorm.DB) (vendor Vendor, err error) { - if notFound := db.Where("id = ?", v.ID).First(&vendor).RecordNotFound(); notFound { - return v, gorm.ErrRecordNotFound + if err := db.Where("id = ?", v.ID).First(&vendor).Error; err != nil { + return v, err } return vendor, nil diff --git a/internal/model/user.go b/internal/model/user.go index d2acee0..2dba47d 100644 --- a/internal/model/user.go +++ b/internal/model/user.go @@ -6,7 +6,7 @@ */ package model -import "github.com/jinzhu/gorm" +import "gorm.io/gorm" type User struct { *Model @@ -17,8 +17,8 @@ type User struct { IsAdmin bool `json:"is_admin" gorm:"default: false"` } -func (u User) Count(db *gorm.DB) (int, error) { - var count int +func (u User) Count(db *gorm.DB) (int64, error) { + var count int64 if u.Username != "" { db = db.Where("username = ?", u.Username) } @@ -51,16 +51,16 @@ func (u User) List(db *gorm.DB, pageOffset, pageSize int) (users []User, err err } func (u User) GetByUsernameAndPassword(db *gorm.DB) (user User, err error) { - if notFound := db.Where("username = ? AND password = ?", u.Username, u.Password).First(&user).RecordNotFound(); notFound { - return u, gorm.ErrRecordNotFound + if err := db.Where("username = ? AND password = ?", u.Username, u.Password).First(&user).Error; err != nil { + return u, err } return user, nil } func (u User) Get(db *gorm.DB) (user User, err error) { - if notFound := db.Where("id = ?", u.ID).First(&user).RecordNotFound(); notFound { - return u, gorm.ErrRecordNotFound + if err := db.Where("id = ?", u.ID).First(&user).Error; err != nil { + return u, err } return user, nil @@ -71,8 +71,7 @@ func (u *User) Create(db *gorm.DB) error { } func (u *User) Update(db *gorm.DB) error { - err := db.Model(&User{}).Update(u).Error - return err + return db.Save(u).Error } func (u *User) Delete(db *gorm.DB) error { diff --git a/internal/service/service.go b/internal/service/service.go index 7a5698e..b48397d 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -9,7 +9,6 @@ package service import ( "context" - otgorm "github.com/eddycjy/opentracing-gorm" "github.com/viletyy/potato/global" "github.com/viletyy/potato/internal/dao" ) @@ -21,7 +20,7 @@ type Service struct { func New(ctx context.Context) Service { svc := Service{Ctx: ctx} - svc.Dao = dao.New(otgorm.WithContext(svc.Ctx, global.GO_DB)) + svc.Dao = dao.New(global.GO_DB) return svc } diff --git a/internal/service/user.go b/internal/service/user.go index 954debc..2384134 100644 --- a/internal/service/user.go +++ b/internal/service/user.go @@ -63,7 +63,7 @@ func (svc *Service) LoginUser(param *UserLoginRequest) (model.User, error) { return svc.Dao.LoginUser(param.Username, param.Password) } -func (svc *Service) CountUser(param *CountUserRequest) (int, error) { +func (svc *Service) CountUser(param *CountUserRequest) (int64, error) { return svc.Dao.CountUser(param.Username, param.Nickname) } diff --git a/internal/service/vendor.go b/internal/service/vendor.go index 562571a..7823f22 100644 --- a/internal/service/vendor.go +++ b/internal/service/vendor.go @@ -40,7 +40,7 @@ type DeleteVendorRequest struct { ID int64 `json:"id" validate:"required,gte=1"` } -func (svc *Service) CountVendor(param *CountVendorRequest) (int, error) { +func (svc *Service) CountVendor(param *CountVendorRequest) (int64, error) { return svc.Dao.CountVendor(param.Name, param.Uuid) } diff --git a/main.go b/main.go index dcd27e9..a94b1b1 100644 --- a/main.go +++ b/main.go @@ -11,6 +11,7 @@ import ( "github.com/viletyy/potato/global" "github.com/viletyy/potato/initialize" + "github.com/viletyy/potato/migrations" "github.com/viletyy/yolk/convert" ) @@ -29,7 +30,15 @@ func main() { global.GO_TRACER = initialize.Tracer() go initialize.Cron() - defer global.GO_DB.Close() + if err := migrations.Migrate(global.GO_DB); err != nil { + global.GO_LOG.Sugar().Fatalf("migrations.Migrate: %v", err) + } + + sqlDB, err := global.GO_DB.DB() + if err != nil { + global.GO_LOG.Sugar().Fatalf("global.GO_DB.DB err: %v", err) + } + defer sqlDB.Close() defer global.GO_REDIS.Close() flag.StringVar(&grpcPort, "grpc_port", convert.ToString(global.GO_CONFIG.Server.GrpcPort), "启动grpc服务端口号") @@ -37,6 +46,21 @@ func main() { flag.Parse() + // defaultMailer := email.NewEmail(&email.SMTPInfo{ + // Host: global.GO_CONFIG.Email.Host, + // Port: global.GO_CONFIG.Email.Port, + // IsSSL: global.GO_CONFIG.Email.IsSSL, + // UserName: global.GO_CONFIG.Email.UserName, + // Password: global.GO_CONFIG.Email.Password, + // From: global.GO_CONFIG.Email.From, + // }) + + // _ = defaultMailer.SendMail( + // global.GO_CONFIG.Email.To, + // fmt.Sprintf("异常抛出,发生时间:%d", time.Now().Unix()), + // fmt.Sprintf("错误信息:heheh%s", "dfds"), + // ) + errs := make(chan error) go func() { err := initialize.RunHttpServer(httpPort) diff --git a/migrations/migration.go b/migrations/migration.go new file mode 100644 index 0000000..a24f656 --- /dev/null +++ b/migrations/migration.go @@ -0,0 +1,125 @@ +package migrations + +import ( + "fmt" + "log" + "os" + + "github.com/viletyy/potato/global" + "gorm.io/gorm" +) + +const minDBVersion = 0 + +type Migration interface { + Description() string + Migrate(*gorm.DB) error +} + +type migration struct { + description string + migrate func(*gorm.DB) error +} + +func NewMigration(desc string, fn func(*gorm.DB) error) Migration { + return &migration{desc, fn} +} + +func (m *migration) Description() string { + return m.description +} + +func (m *migration) Migrate(db *gorm.DB) error { + return m.migrate(db) +} + +type Version struct { + ID int64 `gorm:"primary_key"` + Version int64 +} + +var migrations = []Migration{ + NewMigration("create table users", createTableUsers), + NewMigration("create table auths", createTableAuths), + NewMigration("create table vendors", createTableVendors), +} + +func GetCurrentDBVersion(db *gorm.DB) (int64, error) { + if err := db.Debug().AutoMigrate(&Version{}); err != nil { + return -1, fmt.Errorf("db.AutoMigrate: %v", err) + } + + currentVersion := &Version{ID: 1} + if err := db.Debug().First(currentVersion).Error; err != nil { + return -1, fmt.Errorf("db.First: %v", err) + } + + return currentVersion.Version, nil +} + +func ExpectedVersion() int64 { + return int64(minDBVersion + len(migrations)) +} + +func EnsureUpTodate(db *gorm.DB) error { + currentDB, err := GetCurrentDBVersion(db) + if err != nil { + return err + } + + if currentDB < 0 { + return fmt.Errorf("Database has not been initialised") + } + + if minDBVersion > currentDB { + return fmt.Errorf("DB version %d (<= %d) is too old for auto-migration.", currentDB, minDBVersion) + } + + expected := ExpectedVersion() + + if currentDB != expected { + return fmt.Errorf(`Current database version %d is not equal to the expected version %d. `, currentDB, expected) + } + + return nil +} + +func Migrate(db *gorm.DB) error { + if err := db.AutoMigrate(&Version{}); err != nil { + return fmt.Errorf("db.AutoMigrate: %v", err) + } + + currentVersion := &Version{ID: 1} + if err := db.First(currentVersion).Error; err != nil { + currentVersion.Version = 0 + if err := db.Debug().Create(currentVersion).Error; err != nil { + return fmt.Errorf("db.Create: %v", err) + } + } + + v := currentVersion.Version + if minDBVersion > v { + global.GO_LOG.Fatal("Please upgrade the latest code.") + } + + if int(v-minDBVersion) > len(migrations) { + msg := fmt.Sprintf("Downgrading database version from '%d' to '%d' is not supported and may result in loss of data integrity.\nIf you really know what you're doing, execute `UPDATE version SET version=%d WHERE id=1;`\n", + v, minDBVersion+len(migrations), minDBVersion+len(migrations)) + fmt.Fprint(os.Stderr, msg) + log.Fatal(msg) + return nil + } + + for i, m := range migrations[v-minDBVersion:] { + global.GO_LOG.Sugar().Infof("Migration[%d]: %s", v+int64(i), m.Description()) + if err := m.Migrate(db); err != nil { + return fmt.Errorf("Migrate: %v", err) + } + currentVersion.Version = v + int64(i) + 1 + if err := db.Save(currentVersion).Error; err != nil { + return err + } + } + + return nil +} diff --git a/migrations/v1.go b/migrations/v1.go new file mode 100644 index 0000000..b757782 --- /dev/null +++ b/migrations/v1.go @@ -0,0 +1,25 @@ +package migrations + +import ( + "fmt" + + "github.com/viletyy/potato/internal/model" + "gorm.io/gorm" +) + +func createTableUsers(db *gorm.DB) (err error) { + type User struct { + *model.Model + + Username string `json:"username"` + Password string `json:"-"` + Nickname string `json:"nickname"` + IsAdmin bool `json:"is_admin" gorm:"default: false"` + } + + if err := db.Debug().AutoMigrate(&User{}); err != nil { + return fmt.Errorf("migrations: create table users err: %v", err) + } + + return nil +} diff --git a/migrations/v2.go b/migrations/v2.go new file mode 100644 index 0000000..652ff17 --- /dev/null +++ b/migrations/v2.go @@ -0,0 +1,23 @@ +package migrations + +import ( + "fmt" + + "github.com/viletyy/potato/internal/model" + "gorm.io/gorm" +) + +func createTableAuths(db *gorm.DB) (err error) { + type Auth struct { + *model.Model + + AppKey string `json:"app_key"` + AppSecret string `json:"app_secret"` + } + + if err := db.Debug().AutoMigrate(&Auth{}); err != nil { + return fmt.Errorf("migrations: create table auths err: %v", err) + } + + return nil +} diff --git a/migrations/v3.go b/migrations/v3.go new file mode 100644 index 0000000..e5a1467 --- /dev/null +++ b/migrations/v3.go @@ -0,0 +1,23 @@ +package migrations + +import ( + "fmt" + + "github.com/viletyy/potato/internal/model" + "gorm.io/gorm" +) + +func createTableVendors(db *gorm.DB) (err error) { + type Vendor struct { + *model.Model + + Name string `json:"name"` + Uuid int `json:"uuid"` + } + + if err := db.Debug().AutoMigrate(&Vendor{}); err != nil { + return fmt.Errorf("migrations: create table auths err: %v", err) + } + + return nil +} diff --git a/pkg/app/app.go b/pkg/app/app.go index 14f299b..fde0a74 100644 --- a/pkg/app/app.go +++ b/pkg/app/app.go @@ -16,9 +16,9 @@ type Response struct { } type Pager struct { - Page int `json:"page"` - PageSize int `json:"page_size"` - Total int `json:"total"` + Page int `json:"page"` + PageSize int `json:"page_size"` + Total int64 `json:"total"` } func NewResponse(ctx *gin.Context) *Response { @@ -38,7 +38,7 @@ func (r *Response) ToResponseErrors(data interface{}) { r.ToErrorResponse(err) } -func (r *Response) ToResponseList(list interface{}, total int) { +func (r *Response) ToResponseList(list interface{}, total int64) { err := errcode.Success err.WithData(map[string]interface{}{ "list": list,