pubsub.go 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473
  1. package redis
  2. import (
  3. "errors"
  4. "fmt"
  5. "sync"
  6. "time"
  7. "github.com/go-redis/redis/internal"
  8. "github.com/go-redis/redis/internal/pool"
  9. "github.com/go-redis/redis/internal/proto"
  10. )
  11. var errPingTimeout = errors.New("redis: ping timeout")
  12. // PubSub implements Pub/Sub commands bas described in
  13. // http://redis.io/topics/pubsub. Message receiving is NOT safe
  14. // for concurrent use by multiple goroutines.
  15. //
  16. // PubSub automatically reconnects to Redis Server and resubscribes
  17. // to the channels in case of network errors.
  18. type PubSub struct {
  19. opt *Options
  20. newConn func([]string) (*pool.Conn, error)
  21. closeConn func(*pool.Conn) error
  22. mu sync.Mutex
  23. cn *pool.Conn
  24. channels map[string]struct{}
  25. patterns map[string]struct{}
  26. closed bool
  27. exit chan struct{}
  28. cmd *Cmd
  29. chOnce sync.Once
  30. ch chan *Message
  31. ping chan struct{}
  32. }
  33. func (c *PubSub) init() {
  34. c.exit = make(chan struct{})
  35. }
  36. func (c *PubSub) conn() (*pool.Conn, error) {
  37. c.mu.Lock()
  38. cn, err := c._conn(nil)
  39. c.mu.Unlock()
  40. return cn, err
  41. }
  42. func (c *PubSub) _conn(newChannels []string) (*pool.Conn, error) {
  43. if c.closed {
  44. return nil, pool.ErrClosed
  45. }
  46. if c.cn != nil {
  47. return c.cn, nil
  48. }
  49. channels := mapKeys(c.channels)
  50. channels = append(channels, newChannels...)
  51. cn, err := c.newConn(channels)
  52. if err != nil {
  53. return nil, err
  54. }
  55. if err := c.resubscribe(cn); err != nil {
  56. _ = c.closeConn(cn)
  57. return nil, err
  58. }
  59. c.cn = cn
  60. return cn, nil
  61. }
  62. func (c *PubSub) writeCmd(cn *pool.Conn, cmd Cmder) error {
  63. return cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error {
  64. return writeCmd(wr, cmd)
  65. })
  66. }
  67. func (c *PubSub) resubscribe(cn *pool.Conn) error {
  68. var firstErr error
  69. if len(c.channels) > 0 {
  70. err := c._subscribe(cn, "subscribe", mapKeys(c.channels))
  71. if err != nil && firstErr == nil {
  72. firstErr = err
  73. }
  74. }
  75. if len(c.patterns) > 0 {
  76. err := c._subscribe(cn, "psubscribe", mapKeys(c.patterns))
  77. if err != nil && firstErr == nil {
  78. firstErr = err
  79. }
  80. }
  81. return firstErr
  82. }
  83. func mapKeys(m map[string]struct{}) []string {
  84. s := make([]string, len(m))
  85. i := 0
  86. for k := range m {
  87. s[i] = k
  88. i++
  89. }
  90. return s
  91. }
  92. func (c *PubSub) _subscribe(
  93. cn *pool.Conn, redisCmd string, channels []string,
  94. ) error {
  95. args := make([]interface{}, 0, 1+len(channels))
  96. args = append(args, redisCmd)
  97. for _, channel := range channels {
  98. args = append(args, channel)
  99. }
  100. cmd := NewSliceCmd(args...)
  101. return c.writeCmd(cn, cmd)
  102. }
  103. func (c *PubSub) releaseConn(cn *pool.Conn, err error, allowTimeout bool) {
  104. c.mu.Lock()
  105. c._releaseConn(cn, err, allowTimeout)
  106. c.mu.Unlock()
  107. }
  108. func (c *PubSub) _releaseConn(cn *pool.Conn, err error, allowTimeout bool) {
  109. if c.cn != cn {
  110. return
  111. }
  112. if internal.IsBadConn(err, allowTimeout) {
  113. c._reconnect(err)
  114. }
  115. }
  116. func (c *PubSub) _reconnect(reason error) {
  117. _ = c._closeTheCn(reason)
  118. _, _ = c._conn(nil)
  119. }
  120. func (c *PubSub) _closeTheCn(reason error) error {
  121. if c.cn == nil {
  122. return nil
  123. }
  124. if !c.closed {
  125. internal.Logf("redis: discarding bad PubSub connection: %s", reason)
  126. }
  127. err := c.closeConn(c.cn)
  128. c.cn = nil
  129. return err
  130. }
  131. func (c *PubSub) Close() error {
  132. c.mu.Lock()
  133. defer c.mu.Unlock()
  134. if c.closed {
  135. return pool.ErrClosed
  136. }
  137. c.closed = true
  138. close(c.exit)
  139. err := c._closeTheCn(pool.ErrClosed)
  140. return err
  141. }
  142. // Subscribe the client to the specified channels. It returns
  143. // empty subscription if there are no channels.
  144. func (c *PubSub) Subscribe(channels ...string) error {
  145. c.mu.Lock()
  146. defer c.mu.Unlock()
  147. err := c.subscribe("subscribe", channels...)
  148. if c.channels == nil {
  149. c.channels = make(map[string]struct{})
  150. }
  151. for _, s := range channels {
  152. c.channels[s] = struct{}{}
  153. }
  154. return err
  155. }
  156. // PSubscribe the client to the given patterns. It returns
  157. // empty subscription if there are no patterns.
  158. func (c *PubSub) PSubscribe(patterns ...string) error {
  159. c.mu.Lock()
  160. defer c.mu.Unlock()
  161. err := c.subscribe("psubscribe", patterns...)
  162. if c.patterns == nil {
  163. c.patterns = make(map[string]struct{})
  164. }
  165. for _, s := range patterns {
  166. c.patterns[s] = struct{}{}
  167. }
  168. return err
  169. }
  170. // Unsubscribe the client from the given channels, or from all of
  171. // them if none is given.
  172. func (c *PubSub) Unsubscribe(channels ...string) error {
  173. c.mu.Lock()
  174. defer c.mu.Unlock()
  175. for _, channel := range channels {
  176. delete(c.channels, channel)
  177. }
  178. err := c.subscribe("unsubscribe", channels...)
  179. return err
  180. }
  181. // PUnsubscribe the client from the given patterns, or from all of
  182. // them if none is given.
  183. func (c *PubSub) PUnsubscribe(patterns ...string) error {
  184. c.mu.Lock()
  185. defer c.mu.Unlock()
  186. for _, pattern := range patterns {
  187. delete(c.patterns, pattern)
  188. }
  189. err := c.subscribe("punsubscribe", patterns...)
  190. return err
  191. }
  192. func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
  193. cn, err := c._conn(channels)
  194. if err != nil {
  195. return err
  196. }
  197. err = c._subscribe(cn, redisCmd, channels)
  198. c._releaseConn(cn, err, false)
  199. return err
  200. }
  201. func (c *PubSub) Ping(payload ...string) error {
  202. args := []interface{}{"ping"}
  203. if len(payload) == 1 {
  204. args = append(args, payload[0])
  205. }
  206. cmd := NewCmd(args...)
  207. cn, err := c.conn()
  208. if err != nil {
  209. return err
  210. }
  211. err = c.writeCmd(cn, cmd)
  212. c.releaseConn(cn, err, false)
  213. return err
  214. }
  215. // Subscription received after a successful subscription to channel.
  216. type Subscription struct {
  217. // Can be "subscribe", "unsubscribe", "psubscribe" or "punsubscribe".
  218. Kind string
  219. // Channel name we have subscribed to.
  220. Channel string
  221. // Number of channels we are currently subscribed to.
  222. Count int
  223. }
  224. func (m *Subscription) String() string {
  225. return fmt.Sprintf("%s: %s", m.Kind, m.Channel)
  226. }
  227. // Message received as result of a PUBLISH command issued by another client.
  228. type Message struct {
  229. Channel string
  230. Pattern string
  231. Payload string
  232. }
  233. func (m *Message) String() string {
  234. return fmt.Sprintf("Message<%s: %s>", m.Channel, m.Payload)
  235. }
  236. // Pong received as result of a PING command issued by another client.
  237. type Pong struct {
  238. Payload string
  239. }
  240. func (p *Pong) String() string {
  241. if p.Payload != "" {
  242. return fmt.Sprintf("Pong<%s>", p.Payload)
  243. }
  244. return "Pong"
  245. }
  246. func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
  247. switch reply := reply.(type) {
  248. case string:
  249. return &Pong{
  250. Payload: reply,
  251. }, nil
  252. case []interface{}:
  253. switch kind := reply[0].(string); kind {
  254. case "subscribe", "unsubscribe", "psubscribe", "punsubscribe":
  255. return &Subscription{
  256. Kind: kind,
  257. Channel: reply[1].(string),
  258. Count: int(reply[2].(int64)),
  259. }, nil
  260. case "message":
  261. return &Message{
  262. Channel: reply[1].(string),
  263. Payload: reply[2].(string),
  264. }, nil
  265. case "pmessage":
  266. return &Message{
  267. Pattern: reply[1].(string),
  268. Channel: reply[2].(string),
  269. Payload: reply[3].(string),
  270. }, nil
  271. case "pong":
  272. return &Pong{
  273. Payload: reply[1].(string),
  274. }, nil
  275. default:
  276. return nil, fmt.Errorf("redis: unsupported pubsub message: %q", kind)
  277. }
  278. default:
  279. return nil, fmt.Errorf("redis: unsupported pubsub message: %#v", reply)
  280. }
  281. }
  282. // ReceiveTimeout acts like Receive but returns an error if message
  283. // is not received in time. This is low-level API and in most cases
  284. // Channel should be used instead.
  285. func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
  286. if c.cmd == nil {
  287. c.cmd = NewCmd()
  288. }
  289. cn, err := c.conn()
  290. if err != nil {
  291. return nil, err
  292. }
  293. err = cn.WithReader(timeout, func(rd *proto.Reader) error {
  294. return c.cmd.readReply(rd)
  295. })
  296. c.releaseConn(cn, err, timeout > 0)
  297. if err != nil {
  298. return nil, err
  299. }
  300. return c.newMessage(c.cmd.Val())
  301. }
  302. // Receive returns a message as a Subscription, Message, Pong or error.
  303. // See PubSub example for details. This is low-level API and in most cases
  304. // Channel should be used instead.
  305. func (c *PubSub) Receive() (interface{}, error) {
  306. return c.ReceiveTimeout(0)
  307. }
  308. // ReceiveMessage returns a Message or error ignoring Subscription and Pong
  309. // messages. This is low-level API and in most cases Channel should be used
  310. // instead.
  311. func (c *PubSub) ReceiveMessage() (*Message, error) {
  312. for {
  313. msg, err := c.Receive()
  314. if err != nil {
  315. return nil, err
  316. }
  317. switch msg := msg.(type) {
  318. case *Subscription:
  319. // Ignore.
  320. case *Pong:
  321. // Ignore.
  322. case *Message:
  323. return msg, nil
  324. default:
  325. err := fmt.Errorf("redis: unknown message: %T", msg)
  326. return nil, err
  327. }
  328. }
  329. }
  330. // Channel returns a Go channel for concurrently receiving messages.
  331. // It periodically sends Ping messages to test connection health.
  332. // The channel is closed with PubSub. Receive* APIs can not be used
  333. // after channel is created.
  334. func (c *PubSub) Channel() <-chan *Message {
  335. c.chOnce.Do(c.initChannel)
  336. return c.ch
  337. }
  338. func (c *PubSub) initChannel() {
  339. c.ch = make(chan *Message, 100)
  340. c.ping = make(chan struct{}, 10)
  341. go func() {
  342. var errCount int
  343. for {
  344. msg, err := c.Receive()
  345. if err != nil {
  346. if err == pool.ErrClosed {
  347. close(c.ch)
  348. return
  349. }
  350. if errCount > 0 {
  351. time.Sleep(c.retryBackoff(errCount))
  352. }
  353. errCount++
  354. continue
  355. }
  356. errCount = 0
  357. // Any message is as good as a ping.
  358. select {
  359. case c.ping <- struct{}{}:
  360. default:
  361. }
  362. switch msg := msg.(type) {
  363. case *Subscription:
  364. // Ignore.
  365. case *Pong:
  366. // Ignore.
  367. case *Message:
  368. c.ch <- msg
  369. default:
  370. internal.Logf("redis: unknown message: %T", msg)
  371. }
  372. }
  373. }()
  374. go func() {
  375. const timeout = 5 * time.Second
  376. timer := time.NewTimer(timeout)
  377. timer.Stop()
  378. healthy := true
  379. for {
  380. timer.Reset(timeout)
  381. select {
  382. case <-c.ping:
  383. healthy = true
  384. if !timer.Stop() {
  385. <-timer.C
  386. }
  387. case <-timer.C:
  388. pingErr := c.Ping()
  389. if healthy {
  390. healthy = false
  391. } else {
  392. if pingErr == nil {
  393. pingErr = errPingTimeout
  394. }
  395. c.mu.Lock()
  396. c._reconnect(pingErr)
  397. c.mu.Unlock()
  398. }
  399. case <-c.exit:
  400. return
  401. }
  402. }
  403. }()
  404. }
  405. func (c *PubSub) retryBackoff(attempt int) time.Duration {
  406. return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
  407. }