// Package limit applies per-router and per-merchant rate limits using
// TokenLimiter instances created on cache.Redis.
package limit

import (
	"strings"
	"sync"
	"time"

	"gd_auth_check/apis"
	"gd_auth_check/common.in/cache"
	"gd_auth_check/common.in/config"
	"gd_auth_check/consts"
	"gd_auth_check/errors"

	"github.com/astaxie/beego/orm"
	"github.com/go-redis/redis"
)

const (
	// reloadScript runs SET KEYS[1] 1 EX ARGV[1] NX as a short-lived lock.
	// KEYS[1] must stay unquoted: a quoted 'KEYS[1]' is a Lua string literal,
	// which would make every reload contend on one fixed key named "KEYS[1]"
	// instead of the key passed by the caller.
	reloadScript = `return redis.call('set', KEYS[1], 1, 'ex', ARGV[1], 'NX')`
	// reloadInterval is not referenced in this file.
	reloadInterval = time.Millisecond * 100
	// eventKey is the Redis channel carrying expired-key events for database 4.
	eventKey = "__keyevent@4__:expired"
)

const (
	// routerLimitSQL loads the rate limit configured for each router.
	routerLimitSQL = `SELECT router, rate_limit FROM t_gd_api WHERE rate_limit > 0`
	// merchantLimitSQL loads the per-merchant (app_key) rate limit for each router.
	merchantLimitSQL = `SELECT router, t1.rate_limit, app_key
FROM t_gd_merchant_child_data_api t1
LEFT JOIN t_gd_merchant_data_api t2 ON t1.merchant_data_api_id = t2.id
LEFT JOIN t_gd_merchants t3 ON t2.merchant_id = t3.id
LEFT JOIN t_gd_api t4 ON t1.api_id = t4.id
WHERE t1.rate_limit > 0`
)

// rateLimiter holds the currently active limiters.
type rateLimiter struct {
	// routerLimiter maps getRouterKey(router) to that router's limiter.
	routerLimiter map[string]*TokenLimiter
	// merchantLimiter maps getMerchantKey(appKey, router) to that merchant's limiter.
	merchantLimiter map[string]*TokenLimiter
	// mu serializes calls to reload.
	mu sync.Mutex
	// reload is true while reload is resetting the Redis token state.
	reload bool
	// mapMu guards the routerLimiter and merchantLimiter fields. The maps are
	// replaced wholesale by constructLimiter and never modified afterwards, so
	// a reader may keep using a map it fetched after releasing the lock.
	mapMu sync.RWMutex
}

// rateLimit is one row returned by routerLimitSQL or merchantLimitSQL.
type rateLimit struct {
	AppKey    string `json:"app_key"`
	Router    string `json:"router"`
	RateLimit int    `json:"rate_limit"`
}

var lim rateLimiter

// getRouterKey returns the routerLimiter map key for router.
func getRouterKey(router string) string {
	return router
}

// getMerchantKey returns the merchantLimiter map key, formatted "appKey:router".
func getMerchantKey(appKey, router string) string {
	return appKey + ":" + router
}

// currentLimiters returns the limiter maps currently in use. Callers must
// treat the returned maps as read-only.
func currentLimiters() (routers, merchants map[string]*TokenLimiter) {
	lim.mapMu.RLock()
	defer lim.mapMu.RUnlock()
	return lim.routerLimiter, lim.merchantLimiter
}

// queryRateLimits runs one of the rate-limit queries into a fresh slice.
// orm.ErrNoRows is treated as an empty result so that removing every limit
// still clears the corresponding limiters.
func queryRateLimits(query string) ([]rateLimit, error) {
	var rows []rateLimit
	if _, err := orm.NewOrm().Raw(query).QueryRows(&rows); err != nil && err != orm.ErrNoRows {
		return nil, err
	}
	return rows, nil
}

// buildLimiters creates one TokenLimiter per row, keyed by keyOf(row); a later
// row with the same key replaces an earlier one. The update flag passed to
// NewTokenLimiter is reloadAll, except that when neither init nor reloadAll is
// set it is true for keys whose rate differs from the limiter in previous.
func buildLimiters(rows []rateLimit, keyOf func(rateLimit) string, keyPrefix string,
	previous map[string]*TokenLimiter, init, reloadAll bool) map[string]*TokenLimiter {
	limiters := make(map[string]*TokenLimiter, len(rows))
	for _, row := range rows {
		key := keyOf(row)
		update := reloadAll
		if !init && !reloadAll {
			if old, ok := previous[key]; ok && old.burst != row.RateLimit {
				update = true
			}
		}
		limiters[key] = NewTokenLimiter(row.RateLimit, keyPrefix, key, update, cache.Redis)
	}
	return limiters
}

// constructLimiter rebuilds the router and merchant limiters from the
// database and publishes them.
//
// The update flag given to each new limiter is:
//   - reloadAll:      true for every limiter;
//   - init only:      false for every limiter;
//   - neither:        true only where the rate changed since the previous build.
//
// If either query fails, the previously published limiters stay in place and
// the error is returned.
func constructLimiter(init bool, reloadAll bool) error {
	routerRows, err := queryRateLimits(routerLimitSQL)
	if err != nil {
		return err
	}
	merchantRows, err := queryRateLimits(merchantLimitSQL)
	if err != nil {
		return err
	}

	prevRouters, prevMerchants := currentLimiters()
	routers := buildLimiters(routerRows,
		func(r rateLimit) string { return getRouterKey(r.Router) },
		redisRouterKeyPre, prevRouters, init, reloadAll)
	merchants := buildLimiters(merchantRows,
		func(r rateLimit) string { return getMerchantKey(r.AppKey, r.Router) },
		redisMerchantKeyPre, prevMerchants, init, reloadAll)

	lim.mapMu.Lock()
	lim.routerLimiter = routers
	lim.merchantLimiter = merchants
	lim.mapMu.Unlock()
	return nil
}

// InitLimiter loads the limiters from the database and starts the goroutine
// that rebuilds them on messages from config.Conf.RateLimit.RateLimitChannel.
func InitLimiter() {
	// Best effort: if the initial load fails no limiter is configured, so
	// Allow lets every request through until the next reload message.
	_ = constructLimiter(true, false)
	go subscribe()
}

// reload rebuilds the limiters for one reload message and, if this process
// wins the Redis lock for that message key, calls Reload on every limiter
// that was built with update=true.
//
// A payload equal to pingSuccess rebuilds every limiter with update=true;
// any other payload flags only limiters whose rate changed.
func reload(key string) {
	lim.mu.Lock()
	defer lim.mu.Unlock()

	full := key == pingSuccess
	if err := constructLimiter(full, full); err != nil {
		// The previous limiters remain in place; leave the Redis token state
		// alone rather than reloading limiters that were not rebuilt.
		return
	}

	// Only one process should reset the Redis token state for a given
	// message, so first take a distributed lock (10 s expiry) on its key.
	resp, err := cache.Redis.Eval(reloadScript, []string{consts.LimiterReload + key}, []string{"10"})
	if err == redis.Nil {
		// Another process already holds the lock.
		return
	}
	if v, ok := resp.(string); err != nil || !ok || v != "OK" {
		// Lock not acquired (or the call failed).
		return
	}

	lim.reload = true
	defer func() { lim.reload = false }()

	// Reload the flagged limiters; reloading clears all tokens already in use.
	routers, merchants := currentLimiters()
	for _, l := range routers {
		if l.update {
			l.Reload()
		}
	}
	for _, l := range merchants {
		if l.update {
			l.Reload()
		}
	}
}

// subscribe calls reload with the payload of every message published on
// config.Conf.RateLimit.RateLimitChannel, until the subscription channel closes.
func subscribe() {
	pubsub := cache.Redis.Subscribe(config.Conf.RateLimit.RateLimitChannel)
	defer pubsub.Close()
	for msg := range pubsub.Channel() {
		reload(msg.Payload)
	}
}

// routeLimitAllow reports whether a limiter exists for router and, if so,
// whether it granted a token and which token it returned.
func routeLimitAllow(router string) (bool, bool, string) {
	routers, _ := currentLimiters()
	v, ok := routers[getRouterKey(router)]
	if !ok {
		return false, false, ""
	}
	allow, token := v.Allow()
	return true, allow, token
}

// merchantLimitAllow reports whether a limiter exists for appKey on router
// and, if so, whether it granted a token and which token it returned.
func merchantLimitAllow(appKey, router string) (bool, bool, string) {
	_, merchants := currentLimiters()
	v, ok := merchants[getMerchantKey(appKey, router)]
	if !ok {
		return false, false, ""
	}
	allow, token := v.Allow()
	return true, allow, token
}

// Allow takes a token from the router limiter and then from the merchant
// limiter configured for router/appKey. It returns errors.RateLimit if either
// configured limiter refuses; a router token already taken is released when
// the merchant limiter refuses. When no limiter is configured for either, it
// returns an empty token.
func Allow(router, appKey string) (error, *apis.RateLimitToken) {
	routerLimited, routerAllowed, rToken := routeLimitAllow(router)
	if routerLimited && !routerAllowed {
		return errors.RateLimit, nil
	}

	merchantLimited, merchantAllowed, mToken := merchantLimitAllow(appKey, router)
	if merchantLimited && !merchantAllowed {
		if routerLimited {
			// Give back the router token taken above.
			ReleaseRouterLimiterToken(router, rToken)
		}
		return errors.RateLimit, nil
	}

	if !routerLimited && !merchantLimited {
		return nil, &apis.RateLimitToken{}
	}
	return nil, &apis.RateLimitToken{
		AppKey:        appKey,
		Router:        router,
		RouterToken:   rToken,
		MerchantToken: mToken,
	}
}

// ReleaseRouterLimiterToken returns token to the router limiter for router,
// if one is configured.
func ReleaseRouterLimiterToken(router, token string) {
	routers, _ := currentLimiters()
	if v, ok := routers[getRouterKey(router)]; ok {
		v.release(token)
	}
}

// ReleaseMerchantLimiterToken returns token to the merchant limiter for
// appKey on router, if one is configured.
func ReleaseMerchantLimiterToken(router, appKey, token string) {
	_, merchants := currentLimiters()
	if v, ok := merchants[getMerchantKey(appKey, router)]; ok {
		v.release(token)
	}
}

// expireEvent listens on eventKey and, for each expired key that starts with
// redisUsedTokenKey, releases the token back to its limiter. The remainder of
// the key is split on ":": two parts are "router:token", three parts are
// "appKey:router:token". Redis only publishes these events when
// notify-keyspace-events is enabled for expirations on the server.
func expireEvent() {
	pubsub := cache.Redis.Subscribe(eventKey)
	defer pubsub.Close()
	for msg := range pubsub.Channel() {
		if msg.Channel != eventKey || msg.Payload == "" || !strings.HasPrefix(msg.Payload, redisUsedTokenKey) {
			continue
		}
		parts := strings.Split(strings.TrimPrefix(msg.Payload, redisUsedTokenKey), ":")
		switch len(parts) {
		case 2:
			ReleaseRouterLimiterToken(parts[0], parts[1])
		case 3:
			ReleaseMerchantLimiterToken(parts[1], parts[0], parts[2])
		}
	}
}