添加内置限速中间件

This commit is contained in:
zhucheer 2020-02-09 10:57:59 +08:00
parent d98761b49c
commit 3995255a8a
5 changed files with 229 additions and 0 deletions

1
go.mod
View File

@ -8,6 +8,7 @@ require (
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0
github.com/gomodule/redigo v2.0.0+incompatible
github.com/jinzhu/gorm v1.9.11
github.com/juju/ratelimit v1.0.1 // indirect
github.com/zhuCheer/pool v0.2.1
golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8
)

2
go.sum
View File

@ -60,6 +60,8 @@ github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkr
github.com/jinzhu/now v1.0.1 h1:HjfetcXq097iXP0uoPCdnM4Efp5/9MsM0/M+XOTeR3M=
github.com/jinzhu/now v1.0.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
github.com/juju/ratelimit v1.0.1 h1:+7AIFJVQ0EQgq/K9+0Krm7m530Du7tIz0METWzN0RgY=
github.com/juju/ratelimit v1.0.1/go.mod h1:qapgC/Gy+xNh9UxzV13HGGl/6UXNN+ct+vwSgWNm/qk=
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=

View File

@ -0,0 +1,119 @@
package throttle
import (
"errors"
"fmt"
"gitee.com/zhucheer/orange/app"
"gitee.com/zhucheer/orange/utils"
"github.com/juju/ratelimit"
"net/http"
"sync"
"time"
)
const gcTime = 5 * time.Second
// 限速统计最小时间单元
const rateUnite = 500 * time.Millisecond
type Throttle struct {
MaxQps int64
IpSplit bool
BreakTime time.Duration
requestMaps map[string]*limitItem
mutex sync.Mutex
}
type limitItem struct {
UserTag string
BucketHandler *ratelimit.Bucket
DelaySecond *time.Timer
BreakExpireAt time.Time
}
// NewThrottle 实例化限速中间件 maxRateSecond每秒最大请求数 breakTime限制时长
func NewThrottle(maxQps int64, breakTime time.Duration, ipSplit bool) *Throttle {
return &Throttle{
MaxQps: maxQps, IpSplit: ipSplit, BreakTime: breakTime, requestMaps: make(map[string]*limitItem),
}
}
// Func implements Middleware interface.
func (w Throttle) Func() app.MiddlewareFunc {
return func(next app.HandlerFunc) app.HandlerFunc {
return func(c *app.Context) error {
limitItemInfo := w.getLimter(w.IpSplit, c.OrangeInput.IP())
if limitItemInfo.BreakExpireAt.After(time.Now()) {
return showBreakErr(c)
}
limitItemInfo.DelaySecond.Reset(gcTime + w.BreakTime)
go func(userTag string, delay *time.Timer) {
<-delay.C
w.clearUserTag(userTag)
}(limitItemInfo.UserTag, limitItemInfo.DelaySecond)
takeCount := limitItemInfo.BucketHandler.TakeAvailable(1)
if takeCount < 1 {
limitItemInfo.BreakExpireAt = time.Now().Add(w.BreakTime)
return showBreakErr(c)
}
return next(c)
}
}
}
// showBreakErr
func showBreakErr(c *app.Context) error {
c.HttpError(http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
return errors.New(http.StatusText(http.StatusTooManyRequests))
}
// getLimter 获取一个限速 Bucket 对象
func (w *Throttle) getLimter(ipSplit bool, ipAddr string) *limitItem {
userTag := "orangeThrottle"
if ipSplit == true {
userTag = utils.ShortTag(fmt.Sprintf(userTag+"_%s", ipAddr), 1)
}
limiter, exists := w.requestMaps[userTag]
if !exists {
return w.addUserTag(userTag)
}
return limiter
}
// addUserTag 添加一个用户访问标记
func (w *Throttle) addUserTag(userTag string) *limitItem {
w.mutex.Lock()
defer w.mutex.Unlock()
if w.requestMaps == nil {
w.requestMaps = make(map[string]*limitItem)
}
rateCount := int64(time.Second / rateUnite)
quantumUnite := w.MaxQps / rateCount
bucket := ratelimit.NewBucketWithQuantum(rateUnite, w.MaxQps, quantumUnite)
item := &limitItem{
UserTag: userTag,
BucketHandler: bucket,
DelaySecond: time.NewTimer(gcTime + w.BreakTime),
BreakExpireAt: time.Now(),
}
w.requestMaps[userTag] = item
return item
}
// clearUserTag 清理一个用户访问限速对象
func (w *Throttle) clearUserTag(userTag string) {
w.mutex.Lock()
defer w.mutex.Unlock()
delete(w.requestMaps, userTag)
if len(w.requestMaps) == 0 {
w.requestMaps = nil
}
}

View File

@ -0,0 +1,100 @@
package throttle
import (
"context"
"gitee.com/zhucheer/orange/app"
orangerequest "gitee.com/zhucheer/orange/request"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestNewThrottle(t *testing.T) {
throttle := NewThrottle(5, 3*time.Second, true)
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("REMOTE_ADDR") != "" {
r.RemoteAddr = r.Header.Get("REMOTE_ADDR")
}
ctx := app.NewCtx(context.Background(), w, r)
ctx.OrangeInput = orangerequest.NewInput(r, 20480)
middleWare := throttle.Func()
action := middleWare(func(ctx *app.Context) error {
return nil
})
action(ctx)
}))
defer httpServer.Close()
var defaultClient = &http.Client{}
req, _ := http.NewRequest("GET", httpServer.URL, nil)
defaultClient.Do(req)
req.Header.Add("REMOTE_ADDR", "192.168.1.100")
defaultClient.Do(req)
if len(throttle.requestMaps) != 2 {
t.Error("NewThrottle have error #1")
}
defaultClient.Do(req)
// "192.168.1.100" IP请求限速对象
if throttle.requestMaps["XjL0eP"].BucketHandler.Available() != 3 {
t.Error("NewThrottle have error #2")
}
time.Sleep(time.Second)
defaultClient.Do(req)
defaultClient.Do(req)
defaultClient.Do(req)
defaultClient.Do(req)
defaultClient.Do(req)
if throttle.requestMaps["XjL0eP"].BucketHandler.Available() > 0 {
t.Error("NewThrottle have error #3")
}
time.Sleep(500 * time.Millisecond)
if throttle.requestMaps["XjL0eP"].BucketHandler.Available() < 2 {
t.Error("NewThrottle have error #4")
}
time.Sleep(9 * time.Second)
if len(throttle.requestMaps) > 0 {
t.Error("NewThrottle have error #5")
}
}
func TestNewThrottleConcurency(t *testing.T) {
throttle := NewThrottle(2, time.Minute, true)
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := app.NewCtx(context.Background(), w, r)
ctx.OrangeInput = orangerequest.NewInput(r, 20480)
middleWare := throttle.Func()
action := middleWare(func(ctx *app.Context) error {
return nil
})
action(ctx)
}))
defer httpServer.Close()
var defaultClient = &http.Client{}
req, _ := http.NewRequest("GET", httpServer.URL, nil)
defaultClient.Do(req)
defaultClient.Do(req)
defaultClient.Do(req)
time.Sleep(10 * time.Second)
resp, err := defaultClient.Do(req)
if err != nil {
t.Error("ThrottleConcurency have an error #1")
}
if resp.StatusCode != 429 {
t.Error("ThrottleConcurency break have an error #1")
}
}

View File

@ -2,8 +2,10 @@ package http
import (
"gitee.com/zhucheer/orange/app"
"gitee.com/zhucheer/orange/middlewares/throttle"
"gitee.com/zhucheer/orange/project/http/controller"
"gitee.com/zhucheer/orange/project/http/middleware"
"time"
)
type Route struct {
@ -29,6 +31,11 @@ func (s *Route) ServeMux() {
commonGp.GET("/selectRedis", controller.SelectRedis)
}
rateGp := commonGp.GroupRouter("/rate", throttle.NewThrottle(5, time.Minute, true))
{
rateGp.ALL("/welcome", controller.Welcome)
}
authGp := commonGp.GroupRouter("/auth", middleware.NewAuth())
{
authGp.ALL("/info", controller.AuthCheck)