From c24137ba1f780ac495de74c122433beb502d6c31 Mon Sep 17 00:00:00 2001 From: zhucheer Date: Wed, 6 Nov 2019 19:19:03 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0redis=E8=B0=83=E7=94=A8?= =?UTF-8?q?=EF=BC=8C=E4=BC=98=E5=8C=96=E8=BF=9E=E6=8E=A5=E6=B1=A0=E5=9B=9E?= =?UTF-8?q?=E6=94=B6=E6=9C=BA=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/app.go | 4 + config/config.toml | 26 +++--- database/db.go | 76 +++++++++++++++++ database/mysql.go | 89 +++++++------------- database/redis.go | 167 +++++++++++++++++++++++++++++++++++++ outapp/app.go | 1 + outapp/controller/index.go | 31 +++++-- queue/list.go | 63 ++++++++++++++ queue/queue_test.go | 99 ++++++++++++++++++++++ request/request_test.go | 2 +- 10 files changed, 478 insertions(+), 80 deletions(-) create mode 100644 database/db.go create mode 100644 database/redis.go create mode 100644 queue/list.go create mode 100644 queue/queue_test.go diff --git a/app/app.go b/app/app.go index 5daf72b..20c761f 100644 --- a/app/app.go +++ b/app/app.go @@ -29,6 +29,7 @@ func AppStart(appSrv AppSrv) { // 注册mysql database.NewMysql().RegisterAll() + database.NewRedis().RegisterAll() // 启动http服务 startHttpSrv() @@ -110,6 +111,9 @@ func directOutput(writer http.ResponseWriter, code int, content []byte) { func httpAfterDo(c *Context) error { c.session.SessionRelease(c.response) + // 回收数据库连接 + go database.PullChanDB() + // 最后输出body c.response.Write(c.responseBody.Bytes()) c.responseBody.Reset() diff --git a/config/config.toml b/config/config.toml index 44e406d..ac96404 100644 --- a/config/config.toml +++ b/config/config.toml @@ -14,24 +14,22 @@ isOpen = true timeout = 1800 [database] - initCap = 10 - maxCap = 200 - idleTimeout = 10 + initCap = 2 + maxCap = 5 + idleTimeout = 5 debug = true - [database.conn] - [database.conn.default] + [database.mysql] + [database.mysql.default] addr = "192.168.137.100:3306" username = "zhuqi" password = "123456" dbname = "weixin" - -[redis] - initCap = 10 - maxCap = 200 - idleTimeout = 10 - [redis.conn] - [redis.conn.default] - addr = "192.168.137.100:6379" - password = "123456" + [database.redis] + [database.redis.default] + addr = "192.168.137.101:6379" + password = "rw:Ql46" dbnum = 5 + [database.redis.dxx] + addr = "192.168.137.100:6379" + dbnum = 7 diff --git a/database/db.go b/database/db.go new file mode 100644 index 0000000..619d2f0 --- /dev/null +++ b/database/db.go @@ -0,0 +1,76 @@ +package database + +import ( + "gitee.com/zhucheer/orange/cfg" + "gitee.com/zhucheer/orange/logger" + "github.com/zhuCheer/pool" +) + +type DataBase interface { + RegisterAll() + Register(string) + insertPool(string, pool.Pool) + getDB(string) (interface{}, error) + putDB(string, interface{}) error +} + +type dbChan struct { + dbType string + name string + conn interface{} +} + +// PullChanDB 持续监听将chan中的连接异步放回连接池 +func PullChanDB() { + if mysqlConn.count+redisConn.count == 0 { + return + } + + for { + if mysqlConn.connList.Len() == 0 { + break + } + conn := mysqlConn.connList.LPop() + if conn == nil { + break + } + + dbConn := conn.(dbChan) + err := mysqlConn.putDB(dbConn.name, dbConn.conn) + if err != nil { + logger.Error("mysql conn put back err:%v", err) + } + } + + for { + if redisConn.connList.Len() == 0 { + break + } + conn := redisConn.connList.LPop() + if conn == nil { + break + } + + dbConn := conn.(dbChan) + err := redisConn.putDB(dbConn.name, dbConn.conn) + if err != nil { + logger.Error("redis conn put back err:%v", err) + } + } +} + +func getDBIntConfig(dbtype, name, key string) int { + exists := cfg.Config.Exists("database." + dbtype + "." + name + "." + key) + if exists == false { + return cfg.Config.GetInt("database." + key) + } + return cfg.Config.GetInt("database." + dbtype + "." + name + "." + key) +} + +func getBoolConfig(dbtype, name, key string) bool { + exists := cfg.Config.Exists("database." + dbtype + "." + name + "." + key) + if exists == false { + return cfg.Config.GetBool("database." + key) + } + return cfg.Config.GetBool("database." + dbtype + "." + name + "." + key) +} diff --git a/database/mysql.go b/database/mysql.go index d2824c3..5e379ad 100644 --- a/database/mysql.go +++ b/database/mysql.go @@ -5,6 +5,7 @@ import ( "fmt" "gitee.com/zhucheer/orange/cfg" "gitee.com/zhucheer/orange/logger" + "gitee.com/zhucheer/orange/queue" "github.com/jinzhu/gorm" _ "github.com/jinzhu/gorm/dialects/mysql" "github.com/zhuCheer/pool" @@ -16,53 +17,45 @@ var mysqlConn *MysqlDB type MysqlDB struct { connPool map[string]pool.Pool - connChan chan gormChan + connList *queue.Queue + count int lock sync.Mutex } -type gormChan struct { - name string - conn *gorm.DB -} - // NewMysql 初始化mysql连接 -func NewMysql() *MysqlDB { +func NewMysql() DataBase { if mysqlConn != nil { return mysqlConn } mysqlConn = &MysqlDB{ connPool: make(map[string]pool.Pool, 0), - connChan: make(chan gormChan, 0), + connList: queue.NewQueue(), } return mysqlConn } // 注册所有已配置的mysql func (my *MysqlDB) RegisterAll() { - databaseConfig := cfg.Config.GetMap("database.conn") - - if len(databaseConfig) > 0 { - go pullChanDB() - } + databaseConfig := cfg.Config.GetMap("database.mysql") + my.count = len(databaseConfig) for dd := range databaseConfig { - my.RegisterMysql(dd) + my.Register(dd) } } // NewMysql 注册一个mysql配置 -func (my *MysqlDB) RegisterMysql(name string) { +func (my *MysqlDB) Register(name string) { + addr := cfg.Config.GetString("database.mysql." + name + ".addr") + username := cfg.Config.GetString("database.mysql." + name + ".username") + password := cfg.Config.GetString("database.mysql." + name + ".password") + dbname := cfg.Config.GetString("database.mysql." + name + ".dbname") - addr := cfg.Config.GetString("database.conn." + name + ".addr") - username := cfg.Config.GetString("database.conn." + name + ".username") - password := cfg.Config.GetString("database.conn." + name + ".password") - dbname := cfg.Config.GetString("database.conn." + name + ".dbname") - - initCap := getDBIntConfig(name, "initCap") - maxCap := getDBIntConfig(name, "maxCap") - idleTimeout := getDBIntConfig(name, "idleTimeout") - isDebug := getBoolConfig(name, "debug") + initCap := getDBIntConfig("mysql", name, "initCap") + maxCap := getDBIntConfig("mysql", name, "maxCap") + idleTimeout := getDBIntConfig("mysql", name, "idleTimeout") + isDebug := getBoolConfig("mysql", name, "debug") if initCap == 0 || maxCap == 0 || idleTimeout == 0 { logger.Error("database config is error initCap,maxCap,idleTimeout should be gt 0") @@ -119,25 +112,24 @@ func (my *MysqlDB) insertPool(name string, p pool.Pool) { } // getDB 从连接池获取一个连接 -func (my *MysqlDB) getDB(name string) (db *gorm.DB, err error) { +func (my *MysqlDB) getDB(name string) (conn interface{}, err error) { if _, ok := my.connPool[name]; !ok { return nil, errors.New("no mysql connect") } - conn, err := my.connPool[name].Get() + conn, err = my.connPool[name].Get() if err != nil { return nil, errors.New(fmt.Sprintf("mysql get connect err:%v", err)) } - db = conn.(*gorm.DB) go func() { - my.connChan <- gormChan{name, db} + my.connList.RPush(dbChan{"mysql", name, conn}) }() - return db, nil + return conn, nil } // putDB 将连接放回连接池 -func (my *MysqlDB) putDB(name string, db *gorm.DB) (err error) { +func (my *MysqlDB) putDB(name string, db interface{}) (err error) { if _, ok := my.connPool[name]; !ok { return errors.New("no mysql connect") } @@ -146,37 +138,16 @@ func (my *MysqlDB) putDB(name string, db *gorm.DB) (err error) { return } -func getDBIntConfig(name, key string) int { - exists := cfg.Config.Exists("database.conn." + name + "." + key) - if exists == false { - return cfg.Config.GetInt("database." + key) - } - return cfg.Config.GetInt("database.conn." + name + "." + key) -} - -func getBoolConfig(name, key string) bool { - exists := cfg.Config.Exists("database.conn." + name + "." + key) - if exists == false { - return cfg.Config.GetBool("database." + key) - } - return cfg.Config.GetBool("database.conn." + name + "." + key) -} - -// pullChanDB 持续监听将chan中的连接异步放回连接池 -func pullChanDB() { - for { - select { - case gorm := <-mysqlConn.connChan: - mysqlConn.putDB(gorm.name, gorm.conn) - } - } -} - -// GetDB 获取一个db连接 -func GetDB(name string) (db *gorm.DB, err error) { +// GetMysql 获取一个mysql db连接 +func GetMysql(name string) (db *gorm.DB, err error) { if mysqlConn == nil { return nil, errors.New("db connect is nil") } - return mysqlConn.getDB(name) + conn, err := mysqlConn.getDB(name) + if err != nil { + return nil, err + } + db = conn.(*gorm.DB) + return db, nil } diff --git a/database/redis.go b/database/redis.go new file mode 100644 index 0000000..bb12b0d --- /dev/null +++ b/database/redis.go @@ -0,0 +1,167 @@ +package database + +import ( + "errors" + "fmt" + "gitee.com/zhucheer/orange/cfg" + "gitee.com/zhucheer/orange/logger" + "gitee.com/zhucheer/orange/queue" + "github.com/garyburd/redigo/redis" + "github.com/zhuCheer/pool" + "sync" + "time" +) + +var redisConn *RedisDB + +type RedisDB struct { + connPool map[string]pool.Pool + connList *queue.Queue + count int + lock sync.Mutex +} + +// NewRedis 初始化 redis 连接 +func NewRedis() DataBase { + if redisConn != nil { + return redisConn + } + + redisConn = &RedisDB{ + connPool: make(map[string]pool.Pool, 0), + connList: queue.NewQueue(), + } + return redisConn +} + +// 注册所有已配置的mysql +func (re *RedisDB) RegisterAll() { + logger.Info("##register all redis##") + redisConfig := cfg.Config.GetMap("database.redis") + + re.count = len(redisConfig) + for dd := range redisConfig { + re.Register(dd) + } +} + +// RegisterRedis 注册一个redis配置 +func (re *RedisDB) Register(name string) { + addr := cfg.Config.GetString("database.redis." + name + ".addr") + password := cfg.Config.GetString("database.redis." + name + ".password") + dbnum := cfg.Config.GetInt("database.redis." + name + ".dbnum") + + initCap := getDBIntConfig("redis", name, "initCap") + maxCap := getDBIntConfig("redis", name, "maxCap") + idleTimeout := getDBIntConfig("redis", name, "idleTimeout") + + // connRedis 建立连接 + connRedis := func() (interface{}, error) { + conn, err := redis.Dial("tcp", addr) + if err != nil { + return nil, err + } + if password != "" { + _, err := conn.Do("AUTH", password) + if err != nil { + return nil, err + } + } + if dbnum > 0 { + _, err := conn.Do("SELECT", dbnum) + if err != nil { + return nil, err + } + } + return conn, err + } + + // closeRedis 关闭连接 + closeRedis := func(v interface{}) error { + return v.(redis.Conn).Close() + } + + // pingRedis 检测连接连通性 + pingRedis := func(v interface{}) error { + conn := v.(redis.Conn) + + val, err := redis.String(conn.Do("PING")) + + if err != nil { + return err + } + if val != "PONG" { + return errors.New("redis ping is error ping => " + val) + } + + return nil + } + + //创建一个连接池: 初始化5,最大连接30 + p, err := pool.NewChannelPool(&pool.Config{ + InitialCap: initCap, + MaxCap: maxCap, + Factory: connRedis, + Close: closeRedis, + Ping: pingRedis, + //连接最大空闲时间,超过该时间的连接 将会关闭,可避免空闲时连接EOF,自动失效的问题 + IdleTimeout: time.Duration(idleTimeout) * time.Second, + }) + if err != nil { + logger.Error("register redis conn [%s] error:%v", name, err) + return + } + re.insertPool(name, p) + +} + +// insertPool 将连接池插入map +func (re *RedisDB) insertPool(name string, p pool.Pool) { + if re.connPool == nil { + re.connPool = make(map[string]pool.Pool, 0) + } + + re.lock.Lock() + defer re.lock.Unlock() + re.connPool[name] = p +} + +// getDB 从连接池获取一个连接 +func (re *RedisDB) getDB(name string) (conn interface{}, err error) { + if _, ok := re.connPool[name]; !ok { + return nil, errors.New("no redis connect") + } + conn, err = re.connPool[name].Get() + if err != nil { + return nil, errors.New(fmt.Sprintf("redis get connect err:%v", err)) + } + + go func() { + re.connList.RPush(dbChan{"redis", name, conn}) + }() + + return conn, nil +} + +// putDB 将连接放回连接池 +func (re *RedisDB) putDB(name string, db interface{}) (err error) { + if _, ok := re.connPool[name]; !ok { + return errors.New("no redis connect") + } + err = re.connPool[name].Put(db) + + return +} + +// GetRedis 获取一个mysql db连接 +func GetRedis(name string) (db redis.Conn, err error) { + if redisConn == nil { + return nil, errors.New("db connect is nil") + } + conn, err := redisConn.getDB(name) + if err != nil { + return nil, err + } + db = conn.(redis.Conn) + return db, nil +} diff --git a/outapp/app.go b/outapp/app.go index dfb7cab..eed9fde 100644 --- a/outapp/app.go +++ b/outapp/app.go @@ -16,6 +16,7 @@ func (s *Route) ServeMux() { mm.ALL("/hello", controller.Test) mm.GET("/hello2", controller.Test2) + mm.GET("/redis", controller.RedisTT) mm.GET("/ttc", func(ctx *app.Context) error { logger.Info("tcc is commint") diff --git a/outapp/controller/index.go b/outapp/controller/index.go index c0c55ce..abefa51 100644 --- a/outapp/controller/index.go +++ b/outapp/controller/index.go @@ -1,6 +1,7 @@ package controller import ( + "errors" "fmt" "gitee.com/zhucheer/orange/app" "gitee.com/zhucheer/orange/database" @@ -23,10 +24,28 @@ func Test(c *app.Context) error { } func Test2(c *app.Context) error { + //database.RedisDo() - DoDB() + info := DoDB() return c.ToJson(map[string]interface{}{ - "tt": "222", + "tt": "222", + "cate": info.Name, + }) +} + +func RedisTT(c *app.Context) error { + + redisxxx, err := database.GetRedis("default") + + if err != nil { + fmt.Println(err) + return errors.New("xxxxxxx") + } + + redisxxx.Do("SET", "zhutttt", "1212211") + + return c.ToJson(map[string]interface{}{ + "tt": "redis", }) } @@ -34,15 +53,15 @@ type ImgCate struct { ID uint `gorm:"primary_key"` Name string } // 默认表名是`users` -func DoDB() { - db, err := database.GetDB("default") +func DoDB() *ImgCate { + db, err := database.GetMysql("default") if err != nil { fmt.Println("db connect error", err) - return + return nil } info := &ImgCate{} db.Table("qi_imgcate").Where("short = ?", "wzry").First(&info) - fmt.Println(info) + return info } diff --git a/queue/list.go b/queue/list.go new file mode 100644 index 0000000..898a25e --- /dev/null +++ b/queue/list.go @@ -0,0 +1,63 @@ +package queue + +import ( + "container/list" + "sync" +) + +type Queue struct { + l *list.List + m sync.Mutex +} + +func NewQueue() *Queue { + return &Queue{l: list.New()} +} + +func (q *Queue) LPush(v interface{}) { + if v == nil { + return + } + q.m.Lock() + defer q.m.Unlock() + q.l.PushFront(v) +} + +func (q *Queue) RPush(v interface{}) { + if v == nil { + return + } + q.m.Lock() + defer q.m.Unlock() + q.l.PushBack(v) +} + +func (q *Queue) LPop() interface{} { + q.m.Lock() + defer q.m.Unlock() + + element := q.l.Front() + if element == nil { + return nil + } + + q.l.Remove(element) + return element.Value +} + +func (q *Queue) RPop() interface{} { + q.m.Lock() + defer q.m.Unlock() + + element := q.l.Back() + if element == nil { + return nil + } + + q.l.Remove(element) + return element.Value +} + +func (q *Queue) Len() int { + return q.l.Len() +} diff --git a/queue/queue_test.go b/queue/queue_test.go new file mode 100644 index 0000000..b035f1c --- /dev/null +++ b/queue/queue_test.go @@ -0,0 +1,99 @@ +package queue + +import ( + "fmt" + "testing" + "time" +) + +func TestRPushQueue(t *testing.T) { + + ll := NewQueue() + + ll.RPush("1") + ll.RPush("2") + ll.RPush("3") + + go func() { + ll.RPush("4") + }() + go func() { + ll.RPush("5") + }() + + go func() { + ll.RPush("6") + }() + + time.Sleep(1 * time.Second) + + if ll.Len() != 6 { + t.Error("list Len() do error #1") + } + + listVal := fmt.Sprintf("num=>%v,%v,%v", ll.LPop(), ll.LPop(), ll.LPop()) + if listVal != "num=>1,2,3" { + t.Error("list do error #2") + } + + if ll.Len() != 3 { + t.Error("list Len() do error #3") + } + + ll.LPop() + ll.LPop() + ll.LPop() + c := ll.LPop() + + if c != nil { + t.Error("list LPop() do error #4") + } + + time.Sleep(1 * time.Second) +} + +func TestLPushQueue(t *testing.T) { + + ll := NewQueue() + + ll.LPush("1") + ll.LPush("2") + ll.LPush("3") + + go func() { + ll.LPush("4") + }() + go func() { + ll.LPush("5") + }() + + go func() { + ll.LPush("6") + }() + + time.Sleep(1 * time.Second) + + if ll.Len() != 6 { + t.Error("list Len() do error #1") + } + + listVal := fmt.Sprintf("num=>%v,%v,%v", ll.RPop(), ll.RPop(), ll.RPop()) + if listVal != "num=>1,2,3" { + t.Error("list do error #2") + } + + if ll.Len() != 3 { + t.Error("list Len() do error #3") + } + + ll.RPop() + ll.RPop() + ll.RPop() + c := ll.RPop() + + if c != nil { + t.Error("list RPop() do error #4") + } + + time.Sleep(1 * time.Second) +} diff --git a/request/request_test.go b/request/request_test.go index 0c4ae1d..1a03631 100644 --- a/request/request_test.go +++ b/request/request_test.go @@ -11,7 +11,7 @@ import ( func TestNewInput(t *testing.T) { httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - inputHandler := NewInput(r) + inputHandler := NewInput(r, 2048) if inputHandler.Protocol() != "HTTP/1.1" { t.Error("Protocol is error")