net.go 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. /*
  2. * Copyright (c) 2013 IBM Corp.
  3. *
  4. * All rights reserved. This program and the accompanying materials
  5. * are made available under the terms of the Eclipse Public License v1.0
  6. * which accompanies this distribution, and is available at
  7. * http://www.eclipse.org/legal/epl-v10.html
  8. *
  9. * Contributors:
  10. * Seth Hoenig
  11. * Allan Stockdill-Mander
  12. * Mike Robertson
  13. */
  14. package mqtt
  15. import (
  16. "crypto/tls"
  17. "errors"
  18. "fmt"
  19. "net"
  20. "net/http"
  21. "net/url"
  22. "os"
  23. "reflect"
  24. "sync/atomic"
  25. "time"
  26. "github.com/eclipse/paho.mqtt.golang/packets"
  27. "golang.org/x/net/proxy"
  28. "golang.org/x/net/websocket"
  29. )
  // signalError offers err on c without ever blocking the caller. If no receiver
  // is ready and c has no free buffer slot, err is simply discarded, so a worker
  // goroutine that fails never hangs while reporting its failure.
  func signalError(c chan<- error, err error) {
  	select {
  	default:
  		// Nobody can take the error right now; drop it rather than block.
  	case c <- err:
  	}
  }
  36. func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, headers http.Header) (net.Conn, error) {
  37. switch uri.Scheme {
  38. case "ws":
  39. config, _ := websocket.NewConfig(uri.String(), fmt.Sprintf("http://%s", uri.Host))
  40. config.Protocol = []string{"mqtt"}
  41. config.Header = headers
  42. config.Dialer = &net.Dialer{Timeout: timeout}
  43. conn, err := websocket.DialConfig(config)
  44. if err != nil {
  45. return nil, err
  46. }
  47. conn.PayloadType = websocket.BinaryFrame
  48. return conn, err
  49. case "wss":
  50. config, _ := websocket.NewConfig(uri.String(), fmt.Sprintf("https://%s", uri.Host))
  51. config.Protocol = []string{"mqtt"}
  52. config.TlsConfig = tlsc
  53. config.Header = headers
  54. config.Dialer = &net.Dialer{Timeout: timeout}
  55. conn, err := websocket.DialConfig(config)
  56. if err != nil {
  57. return nil, err
  58. }
  59. conn.PayloadType = websocket.BinaryFrame
  60. return conn, err
  61. case "tcp":
  62. allProxy := os.Getenv("all_proxy")
  63. if len(allProxy) == 0 {
  64. conn, err := net.DialTimeout("tcp", uri.Host, timeout)
  65. if err != nil {
  66. return nil, err
  67. }
  68. return conn, nil
  69. }
  70. proxyDialer := proxy.FromEnvironment()
  71. conn, err := proxyDialer.Dial("tcp", uri.Host)
  72. if err != nil {
  73. return nil, err
  74. }
  75. return conn, nil
  76. case "unix":
  77. conn, err := net.DialTimeout("unix", uri.Host, timeout)
  78. if err != nil {
  79. return nil, err
  80. }
  81. return conn, nil
  82. case "ssl":
  83. fallthrough
  84. case "tls":
  85. fallthrough
  86. case "tcps":
  87. allProxy := os.Getenv("all_proxy")
  88. if len(allProxy) == 0 {
  89. conn, err := tls.DialWithDialer(&net.Dialer{Timeout: timeout}, "tcp", uri.Host, tlsc)
  90. if err != nil {
  91. return nil, err
  92. }
  93. return conn, nil
  94. }
  95. proxyDialer := proxy.FromEnvironment()
  96. conn, err := proxyDialer.Dial("tcp", uri.Host)
  97. if err != nil {
  98. return nil, err
  99. }
  100. tlsConn := tls.Client(conn, tlsc)
  101. err = tlsConn.Handshake()
  102. if err != nil {
  103. conn.Close()
  104. return nil, err
  105. }
  106. return tlsConn, nil
  107. }
  108. return nil, errors.New("Unknown protocol")
  109. }
  110. // actually read incoming messages off the wire
  111. // send Message object into ibound channel
  112. func incoming(c *client) {
  113. var err error
  114. var cp packets.ControlPacket
  115. defer c.workers.Done()
  116. DEBUG.Println(NET, "incoming started")
  117. for {
  118. if cp, err = packets.ReadPacket(c.conn); err != nil {
  119. break
  120. }
  121. DEBUG.Println(NET, "Received Message")
  122. select {
  123. case c.ibound <- cp:
  124. // Notify keepalive logic that we recently received a packet
  125. if c.options.KeepAlive != 0 {
  126. c.lastReceived.Store(time.Now())
  127. }
  128. case <-c.stop:
  129. // This avoids a deadlock should a message arrive while shutting down.
  130. // In that case the "reader" of c.ibound might already be gone
  131. WARN.Println(NET, "incoming dropped a received message during shutdown")
  132. break
  133. }
  134. }
  135. // We received an error on read.
  136. // If disconnect is in progress, swallow error and return
  137. select {
  138. case <-c.stop:
  139. DEBUG.Println(NET, "incoming stopped")
  140. return
  141. // Not trying to disconnect, send the error to the errors channel
  142. default:
  143. ERROR.Println(NET, "incoming stopped with error", err)
  144. signalError(c.errors, err)
  145. return
  146. }
  147. }
  148. // receive a Message object on obound, and then
  149. // actually send outgoing message to the wire
  150. func outgoing(c *client) {
  151. defer c.workers.Done()
  152. DEBUG.Println(NET, "outgoing started")
  153. for {
  154. DEBUG.Println(NET, "outgoing waiting for an outbound message")
  155. select {
  156. case <-c.stop:
  157. DEBUG.Println(NET, "outgoing stopped")
  158. return
  159. case pub := <-c.obound:
  160. msg := pub.p.(*packets.PublishPacket)
  161. if c.options.WriteTimeout > 0 {
  162. c.conn.SetWriteDeadline(time.Now().Add(c.options.WriteTimeout))
  163. }
  164. if err := msg.Write(c.conn); err != nil {
  165. ERROR.Println(NET, "outgoing stopped with error", err)
  166. pub.t.setError(err)
  167. signalError(c.errors, err)
  168. return
  169. }
  170. if c.options.WriteTimeout > 0 {
  171. // If we successfully wrote, we don't want the timeout to happen during an idle period
  172. // so we reset it to infinite.
  173. c.conn.SetWriteDeadline(time.Time{})
  174. }
  175. if msg.Qos == 0 {
  176. pub.t.flowComplete()
  177. }
  178. DEBUG.Println(NET, "obound wrote msg, id:", msg.MessageID)
  179. case msg := <-c.oboundP:
  180. switch msg.p.(type) {
  181. case *packets.SubscribePacket:
  182. msg.p.(*packets.SubscribePacket).MessageID = c.getID(msg.t)
  183. case *packets.UnsubscribePacket:
  184. msg.p.(*packets.UnsubscribePacket).MessageID = c.getID(msg.t)
  185. }
  186. DEBUG.Println(NET, "obound priority msg to write, type", reflect.TypeOf(msg.p))
  187. if err := msg.p.Write(c.conn); err != nil {
  188. ERROR.Println(NET, "outgoing stopped with error", err)
  189. if msg.t != nil {
  190. msg.t.setError(err)
  191. }
  192. signalError(c.errors, err)
  193. return
  194. }
  195. switch msg.p.(type) {
  196. case *packets.DisconnectPacket:
  197. msg.t.(*DisconnectToken).flowComplete()
  198. DEBUG.Println(NET, "outbound wrote disconnect, stopping")
  199. return
  200. }
  201. }
  202. // Reset ping timer after sending control packet.
  203. if c.options.KeepAlive != 0 {
  204. c.lastSent.Store(time.Now())
  205. }
  206. }
  207. }
  208. // receive Message objects on ibound
  209. // store messages if necessary
  210. // send replies on obound
  211. // delete messages from store if necessary
  212. func alllogic(c *client) {
  213. defer c.workers.Done()
  214. DEBUG.Println(NET, "logic started")
  215. for {
  216. DEBUG.Println(NET, "logic waiting for msg on ibound")
  217. select {
  218. case msg := <-c.ibound:
  219. DEBUG.Println(NET, "logic got msg on ibound")
  220. persistInbound(c.persist, msg)
  221. switch m := msg.(type) {
  222. case *packets.PingrespPacket:
  223. DEBUG.Println(NET, "received pingresp")
  224. atomic.StoreInt32(&c.pingOutstanding, 0)
  225. case *packets.SubackPacket:
  226. DEBUG.Println(NET, "received suback, id:", m.MessageID)
  227. token := c.getToken(m.MessageID)
  228. switch t := token.(type) {
  229. case *SubscribeToken:
  230. DEBUG.Println(NET, "granted qoss", m.ReturnCodes)
  231. for i, qos := range m.ReturnCodes {
  232. t.subResult[t.subs[i]] = qos
  233. }
  234. }
  235. token.flowComplete()
  236. c.freeID(m.MessageID)
  237. case *packets.UnsubackPacket:
  238. DEBUG.Println(NET, "received unsuback, id:", m.MessageID)
  239. c.getToken(m.MessageID).flowComplete()
  240. c.freeID(m.MessageID)
  241. case *packets.PublishPacket:
  242. DEBUG.Println(NET, "received publish, msgId:", m.MessageID)
  243. DEBUG.Println(NET, "putting msg on onPubChan")
  244. switch m.Qos {
  245. case 2:
  246. c.incomingPubChan <- m
  247. DEBUG.Println(NET, "done putting msg on incomingPubChan")
  248. case 1:
  249. c.incomingPubChan <- m
  250. DEBUG.Println(NET, "done putting msg on incomingPubChan")
  251. case 0:
  252. select {
  253. case c.incomingPubChan <- m:
  254. case <-c.stop:
  255. }
  256. DEBUG.Println(NET, "done putting msg on incomingPubChan")
  257. }
  258. case *packets.PubackPacket:
  259. DEBUG.Println(NET, "received puback, id:", m.MessageID)
  260. // c.receipts.get(msg.MsgId()) <- Receipt{}
  261. // c.receipts.end(msg.MsgId())
  262. c.getToken(m.MessageID).flowComplete()
  263. c.freeID(m.MessageID)
  264. case *packets.PubrecPacket:
  265. DEBUG.Println(NET, "received pubrec, id:", m.MessageID)
  266. prel := packets.NewControlPacket(packets.Pubrel).(*packets.PubrelPacket)
  267. prel.MessageID = m.MessageID
  268. select {
  269. case c.oboundP <- &PacketAndToken{p: prel, t: nil}:
  270. case <-c.stop:
  271. }
  272. case *packets.PubrelPacket:
  273. DEBUG.Println(NET, "received pubrel, id:", m.MessageID)
  274. pc := packets.NewControlPacket(packets.Pubcomp).(*packets.PubcompPacket)
  275. pc.MessageID = m.MessageID
  276. persistOutbound(c.persist, pc)
  277. select {
  278. case c.oboundP <- &PacketAndToken{p: pc, t: nil}:
  279. case <-c.stop:
  280. }
  281. case *packets.PubcompPacket:
  282. DEBUG.Println(NET, "received pubcomp, id:", m.MessageID)
  283. c.getToken(m.MessageID).flowComplete()
  284. c.freeID(m.MessageID)
  285. }
  286. case <-c.stop:
  287. WARN.Println(NET, "logic stopped")
  288. return
  289. }
  290. }
  291. }
  292. func (c *client) ackFunc(packet *packets.PublishPacket) func() {
  293. return func() {
  294. switch packet.Qos {
  295. case 2:
  296. pr := packets.NewControlPacket(packets.Pubrec).(*packets.PubrecPacket)
  297. pr.MessageID = packet.MessageID
  298. DEBUG.Println(NET, "putting pubrec msg on obound")
  299. select {
  300. case c.oboundP <- &PacketAndToken{p: pr, t: nil}:
  301. case <-c.stop:
  302. }
  303. DEBUG.Println(NET, "done putting pubrec msg on obound")
  304. case 1:
  305. pa := packets.NewControlPacket(packets.Puback).(*packets.PubackPacket)
  306. pa.MessageID = packet.MessageID
  307. DEBUG.Println(NET, "putting puback msg on obound")
  308. persistOutbound(c.persist, pa)
  309. select {
  310. case c.oboundP <- &PacketAndToken{p: pa, t: nil}:
  311. case <-c.stop:
  312. }
  313. DEBUG.Println(NET, "done putting puback msg on obound")
  314. case 0:
  315. // do nothing, since there is no need to send an ack packet back
  316. }
  317. }
  318. }
  319. func errorWatch(c *client) {
  320. defer c.workers.Done()
  321. select {
  322. case <-c.stop:
  323. WARN.Println(NET, "errorWatch stopped")
  324. return
  325. case err := <-c.errors:
  326. ERROR.Println(NET, "error triggered, stopping")
  327. go c.internalConnLost(err)
  328. return
  329. }
  330. }