add: limitable and timeout middleware

This commit is contained in:
viletyy 2021-06-13 22:47:19 +08:00
parent ed7cfb418c
commit 17b4f99a50
7 changed files with 164 additions and 16 deletions

1
go.mod
View File

@ -12,6 +12,7 @@ require (
github.com/go-playground/validator/v10 v10.6.1
github.com/go-redis/redis v6.15.9+incompatible
github.com/jinzhu/gorm v1.9.16
github.com/juju/ratelimit v1.0.1
github.com/lestrrat-go/file-rotatelogs v2.4.0+incompatible
github.com/lestrrat-go/strftime v1.0.4 // indirect
github.com/onsi/ginkgo v1.16.4 // indirect

View File

@ -1,10 +1,10 @@
/*
* @Date: 2021-06-10 18:58:25
* @LastEditors: viletyy
* @LastEditTime: 2021-06-11 15:44:18
* @FilePath: /potato/internal/controller/api/v1/auth.go
* @LastEditTime: 2021-06-13 22:41:01
* @FilePath: /potato/internal/controller/api/auth.go
*/
package v1
package api
import (
"github.com/gin-gonic/gin"
@ -21,7 +21,7 @@ import (
// @Param app_key formData string true "app key"
// @Param app_secret formData string true "app secret"
// @Success 200 {object} errcode.Error "请求成功"
// @Router /v1/auth [post]
// @Router /auth [post]
func GetAuth(c *gin.Context) {
param := service.AuthRequest{}
response := app.NewResponse(c)

View File

@ -0,0 +1,24 @@
/*
* @Date: 2021-06-13 22:35:30
* @LastEditors: viletyy
* @LastEditTime: 2021-06-13 22:37:09
* @FilePath: /potato/internal/middleware/context_timeout.go
*/
package middleware
import (
"context"
"time"
"github.com/gin-gonic/gin"
)
func ContextTimeout(t time.Duration) func(c *gin.Context) {
return func(c *gin.Context) {
ctx, cancel := context.WithTimeout(c.Request.Context(), t)
defer cancel()
c.Request = c.Request.WithContext(ctx)
c.Next()
}
}

View File

@ -0,0 +1,31 @@
/*
* @Date: 2021-06-13 22:27:24
* @LastEditors: viletyy
* @LastEditTime: 2021-06-13 22:29:27
* @FilePath: /potato/internal/middleware/limiter.go
*/
package middleware
import (
"github.com/gin-gonic/gin"
"github.com/viletyy/potato/pkg/app"
"github.com/viletyy/potato/pkg/errcode"
"github.com/viletyy/potato/pkg/limiter"
)
func RateLimiter(l limiter.LimiterInterface) gin.HandlerFunc {
return func(c *gin.Context) {
key := l.Key(c)
if bucket, ok := l.GetBucket(key); ok {
count := bucket.TakeAvailable(1)
if count == 0 {
response := app.NewResponse(c)
response.ToErrorResponse(errcode.TooManyRequests)
c.Abort()
return
}
}
c.Next()
}
}

View File

@ -1,13 +1,14 @@
/*
* @Date: 2021-03-21 19:54:57
* @LastEditors: viletyy
* @LastEditTime: 2021-06-13 22:03:27
* @LastEditTime: 2021-06-13 22:46:46
* @FilePath: /potato/internal/routers/router.go
*/
package routers
import (
"net/http"
"time"
"github.com/gin-gonic/gin"
_ "github.com/swaggo/gin-swagger"
@ -16,29 +17,42 @@ import (
_ "github.com/swaggo/gin-swagger/swaggerFiles"
_ "github.com/viletyy/potato/docs"
"github.com/viletyy/potato/global"
"github.com/viletyy/potato/internal/controller/api"
v1 "github.com/viletyy/potato/internal/controller/api/v1"
"github.com/viletyy/potato/internal/middleware"
"github.com/viletyy/potato/pkg/limiter"
)
var (
Engine = gin.Default()
V1RouterGroup = Engine.Group("../api/v1")
Engine = gin.Default()
V1RouterGroup = Engine.Group("../api/v1")
methodLimiters = limiter.NewMethodLimiter().AddBuckets(limiter.LimiterBucketRule{
Key: "/auth",
FillInterval: time.Second,
Capacity: 10,
Quantum: 10,
})
)
func InitRouter() *gin.Engine {
Engine.Use(gin.Logger())
gin.SetMode(global.GO_CONFIG.App.RunMode) // 设置运行模式
Engine.Use(gin.Recovery())
Engine.Use(middleware.Translations())
if global.GO_CONFIG.App.RunMode == "debug" {
Engine.Use(gin.Logger()) // 设置log
Engine.Use(gin.Recovery()) // 设置recovery
} else {
Engine.Use(middleware.AccessLog())
Engine.Use(middleware.Recovery())
}
Engine.Use(middleware.AppInfo()) // 设置app信息
Engine.Use(middleware.RateLimiter(methodLimiters)) // 设置限流控制
Engine.Use(middleware.ContextTimeout(60 * time.Second)) // 设置统一超时控制
Engine.Use(middleware.Translations()) // 设置自定义验证
Engine.Use(middleware.CORS()) // 设置跨域
gin.SetMode(global.GO_CONFIG.App.RunMode)
Engine.Use(middleware.CORS())
Engine.Use(middleware.AccessLog())
Engine.Use(middleware.AppInfo())
Engine.StaticFS("/static", http.Dir(global.GO_CONFIG.App.UploadSavePath))
Engine.POST("/api/v1/auth", v1.GetAuth)
Engine.POST("/api/auth", api.GetAuth)
Engine.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
V1InitModule()

31
pkg/limiter/limiter.go Normal file
View File

@ -0,0 +1,31 @@
/*
* @Date: 2021-06-13 22:06:52
* @LastEditors: viletyy
* @LastEditTime: 2021-06-13 22:13:35
* @FilePath: /potato/pkg/limiter/limiter.go
*/
package limiter
import (
"time"
"github.com/gin-gonic/gin"
"github.com/juju/ratelimit"
)
type LimiterInterface interface {
Key(c *gin.Context) string // 获取对应的限流器的键值对名称。
GetBucket(key string) (*ratelimit.Bucket, bool) // 获取令牌桶
AddBuckets(rules ...LimiterBucketRule) LimiterInterface // 新增多个令牌桶
}
type Limiter struct {
limiterBuckets map[string]*ratelimit.Bucket
}
type LimiterBucketRule struct {
Key string // 自定义键值对名称
FillInterval time.Duration // 间隔多久时间放N个令牌
Capacity int64 // 令牌桶的容量
Quantum int64 // 每次到达间隔时间后所放的具体令牌数量
}

View File

@ -0,0 +1,47 @@
/*
* @Date: 2021-06-13 22:13:53
* @LastEditors: viletyy
* @LastEditTime: 2021-06-13 22:26:09
* @FilePath: /potato/pkg/limiter/method_limiter.go
*/
package limiter
import (
"strings"
"github.com/gin-gonic/gin"
"github.com/juju/ratelimit"
)
type MethodLimiter struct {
*Limiter
}
func NewMethodLimiter() LimiterInterface {
return MethodLimiter{
Limiter: &Limiter{limiterBuckets: make(map[string]*ratelimit.Bucket)},
}
}
func (l MethodLimiter) Key(c *gin.Context) string {
uri := c.Request.RequestURI
index := strings.Index(uri, "?")
if index == -1 {
return uri
}
return uri[:index]
}
func (l MethodLimiter) GetBucket(key string) (*ratelimit.Bucket, bool) {
bucket, ok := l.limiterBuckets[key]
return bucket, ok
}
func (l MethodLimiter) AddBuckets(rules ...LimiterBucketRule) LimiterInterface {
for _, rule := range rules {
if _, ok := l.limiterBuckets[rule.Key]; !ok {
l.limiterBuckets[rule.Key] = ratelimit.NewBucketWithQuantum(rule.FillInterval, rule.Capacity, rule.Quantum)
}
}
return l
}