limit.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. package limit
  2. import (
  3. "gd_auth_check/apis"
  4. "gd_auth_check/common.in/cache"
  5. "gd_auth_check/common.in/config"
  6. "gd_auth_check/consts"
  7. "gd_auth_check/errors"
  8. "strings"
  9. "sync"
  10. "time"
  11. "github.com/astaxie/beego/orm"
  12. "github.com/go-redis/redis"
  13. )
  const (
  	// reloadScript atomically runs SET KEYS[1] 1 EX ARGV[1] NX and returns its
  	// reply: "OK" when the key was created, nil when it already existed.
  	// reload uses it as a per-key distributed lock. KEYS[1] must NOT be quoted:
  	// the quoted form 'KEYS[1]' set the literal key "KEYS[1]", so every reload
  	// message shared a single lock no matter which key the caller passed.
  	reloadScript = `return redis.call('set', KEYS[1], 1, 'ex', ARGV[1], 'NX')`
  	// reloadInterval is the poll period (100 ms) for waiting on an in-progress reload.
  	reloadInterval = time.Millisecond * 100
  	// eventKey is the Redis keyevent channel for expired keys in database 4
  	// (published only when keyspace notifications are enabled on the server).
  	eventKey = "__keyevent@4__:expired"
  )
  19. type rateLimiter struct {
  20. routerLimiter map[string]*TokenLimiter
  21. merchantLimiter map[string]*TokenLimiter
  22. mu sync.Mutex
  23. reload bool
  24. }
  // rateLimit is one row returned by the limit queries in constructLimiter.
  // The router query selects only router and rate_limit, so AppKey stays empty
  // for those rows.
  type rateLimit struct {
  	// AppKey identifies the merchant application (merchant query only).
  	AppKey string `json:"app_key"`
  	// Router is the API route the limit applies to.
  	Router string `json:"router"`
  	// RateLimit is passed to NewTokenLimiter as the limiter's size.
  	RateLimit int `json:"rate_limit"`
  }
  30. var lim rateLimiter
  // getRouterKey returns the lim.routerLimiter key for router.
  // Route-level limiters are keyed by the router string unchanged.
  func getRouterKey(router string) string {
  	key := router
  	return key
  }
  // getMerchantKey returns the lim.merchantLimiter key for a merchant on a
  // route: the app key and the router joined by a colon.
  func getMerchantKey(appKey, router string) string {
  	const sep = ":"
  	return appKey + sep + router
  }
  37. func constructLimiter(init bool, reloadAll bool) {
  38. routerLimiter := make(map[string]*TokenLimiter)
  39. merchantLimiter := make(map[string]*TokenLimiter)
  40. var list []rateLimit
  41. orm.NewOrm().Raw("SELECT router, rate_limit FROM t_gd_api WHERE rate_limit > 0").QueryRows(&list)
  42. for _, v := range list {
  43. key := getRouterKey(v.Router)
  44. if !init && !reloadAll {
  45. if l, ok := lim.routerLimiter[key]; ok {
  46. if v.RateLimit != l.burst {
  47. routerLimiter[key] = NewTokenLimiter(v.RateLimit, redisRouterKeyPre, key, true, cache.Redis)
  48. continue
  49. }
  50. }
  51. }
  52. routerLimiter[key] = NewTokenLimiter(v.RateLimit, redisRouterKeyPre, key, reloadAll, cache.Redis)
  53. }
  54. lim.routerLimiter = routerLimiter
  55. orm.NewOrm().Raw(`SELECT
  56. router,
  57. t1.rate_limit,
  58. app_key
  59. FROM
  60. t_gd_merchant_child_data_api t1
  61. LEFT JOIN t_gd_merchant_data_api t2 ON t1.merchant_data_api_id = t2.id
  62. LEFT JOIN t_gd_merchants t3 ON t2.merchant_id = t3.id
  63. LEFT JOIN t_gd_api t4 ON t1.api_id = t4.id
  64. WHERE
  65. t1.rate_limit > 0`).QueryRows(&list)
  66. for _, v := range list {
  67. key := getMerchantKey(v.AppKey, v.Router)
  68. if !init && !reloadAll {
  69. if l, ok := lim.merchantLimiter[key]; ok {
  70. if l.burst != v.RateLimit {
  71. merchantLimiter[key] = NewTokenLimiter(v.RateLimit, redisMerchantKeyPre, key, true, cache.Redis)
  72. continue
  73. }
  74. }
  75. }
  76. merchantLimiter[key] = NewTokenLimiter(v.RateLimit, redisMerchantKeyPre, key, reloadAll, cache.Redis)
  77. }
  78. lim.merchantLimiter = merchantLimiter
  79. }
  80. // 初始化限流器
  81. func InitLimiter() {
  82. constructLimiter(true, false)
  83. go subscribe()
  84. }
  85. // 重载限流器
  86. func reload(key string) {
  87. lim.mu.Lock()
  88. defer lim.mu.Unlock()
  89. // 重新构建新的限流数据
  90. if key == pingSuccess {
  91. constructLimiter(true, true)
  92. } else {
  93. constructLimiter(false, false)
  94. }
  95. if lim.reload {
  96. // 正在重载
  97. // 使用时间轮,直到上次重载成功以后在继续执行
  98. ticker := time.NewTicker(reloadInterval)
  99. defer func() {
  100. ticker.Stop()
  101. }()
  102. for range ticker.C {
  103. if !lim.reload {
  104. break
  105. }
  106. }
  107. }
  108. // 使用分布式锁抢占当前key重载
  109. resp, err := cache.Redis.Eval(reloadScript, []string{consts.LimiterReload + key}, []string{"10"})
  110. if err == redis.Nil {
  111. // 锁抢占失败
  112. return
  113. }
  114. v, ok := resp.(string)
  115. if !ok || v != "OK" {
  116. // 锁抢占失败
  117. return
  118. }
  119. lim.reload = true
  120. defer func() {
  121. lim.reload = false
  122. }()
  123. // 重载,重载时会清空全部已用的token
  124. for _, v := range lim.routerLimiter {
  125. if v.update {
  126. v.Reload()
  127. }
  128. }
  129. for _, v := range lim.merchantLimiter {
  130. if v.update {
  131. v.Reload()
  132. }
  133. }
  134. }
  135. // 消息订阅
  136. func subscribe() {
  137. pubsub := cache.Redis.Subscribe(config.Conf.RateLimit.RateLimitChannel)
  138. defer pubsub.Close()
  139. for msg := range pubsub.Channel() {
  140. reload(msg.Payload)
  141. }
  142. }
  143. func routeLimitAllow(router string) (bool, bool, string) {
  144. if v, ok := lim.routerLimiter[getRouterKey(router)]; ok {
  145. allow, token := v.Allow()
  146. return ok, allow, token
  147. }
  148. return false, false, ""
  149. }
  150. func merchantLimitAllow(appKey, router string) (bool, bool, string) {
  151. if v, ok := lim.merchantLimiter[getMerchantKey(appKey, router)]; ok {
  152. allow, token := v.Allow()
  153. return ok, allow, token
  154. }
  155. return false, false, ""
  156. }
  157. func Allow(router, appKey string) (error, *apis.RateLimitToken) {
  158. if exist, allow, rToken := routeLimitAllow(router); exist {
  159. if allow {
  160. // 拿到token
  161. if exist, allow, mToken := merchantLimitAllow(appKey, router); exist {
  162. if allow {
  163. // 拿到商户token
  164. return nil, &apis.RateLimitToken{
  165. AppKey: appKey,
  166. Router: router,
  167. RouterToken: rToken,
  168. MerchantToken: mToken,
  169. }
  170. } else {
  171. // 释放路由token
  172. ReleaseRouterLimiterToken(router, rToken)
  173. return errors.RateLimit, nil
  174. }
  175. } else {
  176. return nil, &apis.RateLimitToken{
  177. AppKey: appKey,
  178. Router: router,
  179. RouterToken: rToken,
  180. MerchantToken: mToken,
  181. }
  182. }
  183. } else {
  184. // 报错
  185. return errors.RateLimit, nil
  186. }
  187. } else if exist, allow, mToken := merchantLimitAllow(appKey, router); exist {
  188. // 判断商户token是否存在
  189. if allow {
  190. // 拿到商户token
  191. return nil, &apis.RateLimitToken{
  192. AppKey: appKey,
  193. Router: router,
  194. MerchantToken: mToken,
  195. }
  196. } else {
  197. // throw error
  198. return errors.RateLimit, nil
  199. }
  200. }
  201. return nil, &apis.RateLimitToken{}
  202. }
  203. func ReleaseRouterLimiterToken(router, token string) {
  204. if v, ok := lim.routerLimiter[getRouterKey(router)]; ok {
  205. v.release(token)
  206. }
  207. }
  208. func ReleaseMerchantLimiterToken(router, appKey, token string) {
  209. if v, ok := lim.merchantLimiter[getMerchantKey(appKey, router)]; ok {
  210. v.release(token)
  211. }
  212. }
  213. func expireEvent() {
  214. pubsub := cache.Redis.Subscribe(eventKey)
  215. defer pubsub.Close()
  216. for msg := range pubsub.Channel() {
  217. if msg.Channel == eventKey {
  218. if msg.Payload != "" && strings.Index(msg.Payload, redisUsedTokenKey) == 0 {
  219. key := strings.Replace(msg.Payload, redisUsedTokenKey, "", 1)
  220. keyArr := strings.Split(key, ":")
  221. switch len(keyArr) {
  222. case 2:
  223. ReleaseRouterLimiterToken(keyArr[0], keyArr[1])
  224. case 3:
  225. ReleaseMerchantLimiterToken(keyArr[1], keyArr[0], keyArr[2])
  226. }
  227. }
  228. }
  229. }
  230. }