添加内置限速中间件
This commit is contained in:
parent
d98761b49c
commit
3995255a8a
1
go.mod
1
go.mod
|
@ -8,6 +8,7 @@ require (
|
|||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0
|
||||
github.com/gomodule/redigo v2.0.0+incompatible
|
||||
github.com/jinzhu/gorm v1.9.11
|
||||
github.com/juju/ratelimit v1.0.1 // indirect
|
||||
github.com/zhuCheer/pool v0.2.1
|
||||
golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8
|
||||
)
|
||||
|
|
2
go.sum
2
go.sum
|
@ -60,6 +60,8 @@ github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkr
|
|||
github.com/jinzhu/now v1.0.1 h1:HjfetcXq097iXP0uoPCdnM4Efp5/9MsM0/M+XOTeR3M=
|
||||
github.com/jinzhu/now v1.0.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
|
||||
github.com/juju/ratelimit v1.0.1 h1:+7AIFJVQ0EQgq/K9+0Krm7m530Du7tIz0METWzN0RgY=
|
||||
github.com/juju/ratelimit v1.0.1/go.mod h1:qapgC/Gy+xNh9UxzV13HGGl/6UXNN+ct+vwSgWNm/qk=
|
||||
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
|
|
|
@ -0,0 +1,119 @@
|
|||
package throttle
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"gitee.com/zhucheer/orange/app"
|
||||
"gitee.com/zhucheer/orange/utils"
|
||||
"github.com/juju/ratelimit"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const gcTime = 5 * time.Second
|
||||
|
||||
// 限速统计最小时间单元
|
||||
const rateUnite = 500 * time.Millisecond
|
||||
|
||||
type Throttle struct {
|
||||
MaxQps int64
|
||||
IpSplit bool
|
||||
BreakTime time.Duration
|
||||
|
||||
requestMaps map[string]*limitItem
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
type limitItem struct {
|
||||
UserTag string
|
||||
BucketHandler *ratelimit.Bucket
|
||||
DelaySecond *time.Timer
|
||||
BreakExpireAt time.Time
|
||||
}
|
||||
|
||||
// NewThrottle 实例化限速中间件 maxRateSecond:每秒最大请求数 breakTime:限制时长
|
||||
func NewThrottle(maxQps int64, breakTime time.Duration, ipSplit bool) *Throttle {
|
||||
return &Throttle{
|
||||
MaxQps: maxQps, IpSplit: ipSplit, BreakTime: breakTime, requestMaps: make(map[string]*limitItem),
|
||||
}
|
||||
}
|
||||
|
||||
// Func implements Middleware interface.
|
||||
func (w Throttle) Func() app.MiddlewareFunc {
|
||||
return func(next app.HandlerFunc) app.HandlerFunc {
|
||||
return func(c *app.Context) error {
|
||||
limitItemInfo := w.getLimter(w.IpSplit, c.OrangeInput.IP())
|
||||
if limitItemInfo.BreakExpireAt.After(time.Now()) {
|
||||
return showBreakErr(c)
|
||||
}
|
||||
|
||||
limitItemInfo.DelaySecond.Reset(gcTime + w.BreakTime)
|
||||
go func(userTag string, delay *time.Timer) {
|
||||
<-delay.C
|
||||
w.clearUserTag(userTag)
|
||||
}(limitItemInfo.UserTag, limitItemInfo.DelaySecond)
|
||||
|
||||
takeCount := limitItemInfo.BucketHandler.TakeAvailable(1)
|
||||
if takeCount < 1 {
|
||||
limitItemInfo.BreakExpireAt = time.Now().Add(w.BreakTime)
|
||||
return showBreakErr(c)
|
||||
}
|
||||
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// showBreakErr
|
||||
func showBreakErr(c *app.Context) error {
|
||||
c.HttpError(http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
|
||||
return errors.New(http.StatusText(http.StatusTooManyRequests))
|
||||
}
|
||||
|
||||
// getLimter 获取一个限速 Bucket 对象
|
||||
func (w *Throttle) getLimter(ipSplit bool, ipAddr string) *limitItem {
|
||||
userTag := "orangeThrottle"
|
||||
if ipSplit == true {
|
||||
userTag = utils.ShortTag(fmt.Sprintf(userTag+"_%s", ipAddr), 1)
|
||||
}
|
||||
|
||||
limiter, exists := w.requestMaps[userTag]
|
||||
if !exists {
|
||||
return w.addUserTag(userTag)
|
||||
}
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// addUserTag 添加一个用户访问标记
|
||||
func (w *Throttle) addUserTag(userTag string) *limitItem {
|
||||
w.mutex.Lock()
|
||||
defer w.mutex.Unlock()
|
||||
if w.requestMaps == nil {
|
||||
w.requestMaps = make(map[string]*limitItem)
|
||||
}
|
||||
|
||||
rateCount := int64(time.Second / rateUnite)
|
||||
quantumUnite := w.MaxQps / rateCount
|
||||
bucket := ratelimit.NewBucketWithQuantum(rateUnite, w.MaxQps, quantumUnite)
|
||||
|
||||
item := &limitItem{
|
||||
UserTag: userTag,
|
||||
BucketHandler: bucket,
|
||||
DelaySecond: time.NewTimer(gcTime + w.BreakTime),
|
||||
BreakExpireAt: time.Now(),
|
||||
}
|
||||
w.requestMaps[userTag] = item
|
||||
return item
|
||||
}
|
||||
|
||||
// clearUserTag 清理一个用户访问限速对象
|
||||
func (w *Throttle) clearUserTag(userTag string) {
|
||||
w.mutex.Lock()
|
||||
defer w.mutex.Unlock()
|
||||
delete(w.requestMaps, userTag)
|
||||
if len(w.requestMaps) == 0 {
|
||||
w.requestMaps = nil
|
||||
}
|
||||
}
|
|
@ -0,0 +1,100 @@
|
|||
package throttle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"gitee.com/zhucheer/orange/app"
|
||||
orangerequest "gitee.com/zhucheer/orange/request"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewThrottle(t *testing.T) {
|
||||
throttle := NewThrottle(5, 3*time.Second, true)
|
||||
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("REMOTE_ADDR") != "" {
|
||||
r.RemoteAddr = r.Header.Get("REMOTE_ADDR")
|
||||
}
|
||||
ctx := app.NewCtx(context.Background(), w, r)
|
||||
ctx.OrangeInput = orangerequest.NewInput(r, 20480)
|
||||
|
||||
middleWare := throttle.Func()
|
||||
action := middleWare(func(ctx *app.Context) error {
|
||||
return nil
|
||||
})
|
||||
action(ctx)
|
||||
|
||||
}))
|
||||
defer httpServer.Close()
|
||||
|
||||
var defaultClient = &http.Client{}
|
||||
req, _ := http.NewRequest("GET", httpServer.URL, nil)
|
||||
defaultClient.Do(req)
|
||||
|
||||
req.Header.Add("REMOTE_ADDR", "192.168.1.100")
|
||||
defaultClient.Do(req)
|
||||
|
||||
if len(throttle.requestMaps) != 2 {
|
||||
t.Error("NewThrottle have error #1")
|
||||
}
|
||||
|
||||
defaultClient.Do(req)
|
||||
// "192.168.1.100" IP请求限速对象
|
||||
if throttle.requestMaps["XjL0eP"].BucketHandler.Available() != 3 {
|
||||
t.Error("NewThrottle have error #2")
|
||||
}
|
||||
|
||||
time.Sleep(time.Second)
|
||||
defaultClient.Do(req)
|
||||
defaultClient.Do(req)
|
||||
defaultClient.Do(req)
|
||||
defaultClient.Do(req)
|
||||
defaultClient.Do(req)
|
||||
if throttle.requestMaps["XjL0eP"].BucketHandler.Available() > 0 {
|
||||
t.Error("NewThrottle have error #3")
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
if throttle.requestMaps["XjL0eP"].BucketHandler.Available() < 2 {
|
||||
t.Error("NewThrottle have error #4")
|
||||
}
|
||||
time.Sleep(9 * time.Second)
|
||||
|
||||
if len(throttle.requestMaps) > 0 {
|
||||
t.Error("NewThrottle have error #5")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewThrottleConcurency(t *testing.T) {
|
||||
throttle := NewThrottle(2, time.Minute, true)
|
||||
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := app.NewCtx(context.Background(), w, r)
|
||||
ctx.OrangeInput = orangerequest.NewInput(r, 20480)
|
||||
|
||||
middleWare := throttle.Func()
|
||||
action := middleWare(func(ctx *app.Context) error {
|
||||
return nil
|
||||
})
|
||||
action(ctx)
|
||||
}))
|
||||
|
||||
defer httpServer.Close()
|
||||
|
||||
var defaultClient = &http.Client{}
|
||||
req, _ := http.NewRequest("GET", httpServer.URL, nil)
|
||||
defaultClient.Do(req)
|
||||
defaultClient.Do(req)
|
||||
defaultClient.Do(req)
|
||||
|
||||
time.Sleep(10 * time.Second)
|
||||
resp, err := defaultClient.Do(req)
|
||||
|
||||
if err != nil {
|
||||
t.Error("ThrottleConcurency have an error #1")
|
||||
}
|
||||
|
||||
if resp.StatusCode != 429 {
|
||||
t.Error("ThrottleConcurency break have an error #1")
|
||||
}
|
||||
}
|
|
@ -2,8 +2,10 @@ package http
|
|||
|
||||
import (
|
||||
"gitee.com/zhucheer/orange/app"
|
||||
"gitee.com/zhucheer/orange/middlewares/throttle"
|
||||
"gitee.com/zhucheer/orange/project/http/controller"
|
||||
"gitee.com/zhucheer/orange/project/http/middleware"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Route struct {
|
||||
|
@ -29,6 +31,11 @@ func (s *Route) ServeMux() {
|
|||
commonGp.GET("/selectRedis", controller.SelectRedis)
|
||||
}
|
||||
|
||||
rateGp := commonGp.GroupRouter("/rate", throttle.NewThrottle(5, time.Minute, true))
|
||||
{
|
||||
rateGp.ALL("/welcome", controller.Welcome)
|
||||
}
|
||||
|
||||
authGp := commonGp.GroupRouter("/auth", middleware.NewAuth())
|
||||
{
|
||||
authGp.ALL("/info", controller.AuthCheck)
|
||||
|
|
Loading…
Reference in New Issue