// packet_handler_map_test.go (11 KB)

  1. package quic
  2. import (
  3. "bytes"
  4. "crypto/rand"
  5. "errors"
  6. "net"
  7. "time"
  8. "github.com/golang/mock/gomock"
  9. "github.com/lucas-clemente/quic-go/internal/protocol"
  10. "github.com/lucas-clemente/quic-go/internal/utils"
  11. "github.com/lucas-clemente/quic-go/internal/wire"
  12. . "github.com/onsi/ginkgo"
  13. . "github.com/onsi/gomega"
  14. )
  15. var _ = Describe("Packet Handler Map", func() {
  16. var (
  17. handler *packetHandlerMap
  18. conn *mockPacketConn
  19. connIDLen int
  20. statelessResetKey []byte
  21. )
  22. getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) []byte {
  23. buf := &bytes.Buffer{}
  24. Expect((&wire.ExtendedHeader{
  25. Header: wire.Header{
  26. IsLongHeader: true,
  27. Type: protocol.PacketTypeHandshake,
  28. DestConnectionID: connID,
  29. Length: length,
  30. Version: protocol.VersionTLS,
  31. },
  32. PacketNumberLen: protocol.PacketNumberLen2,
  33. }).Write(buf, protocol.VersionWhatever)).To(Succeed())
  34. return buf.Bytes()
  35. }
  36. getPacket := func(connID protocol.ConnectionID) []byte {
  37. return getPacketWithLength(connID, 2)
  38. }
  39. BeforeEach(func() {
  40. statelessResetKey = nil
  41. connIDLen = 0
  42. })
  43. JustBeforeEach(func() {
  44. conn = newMockPacketConn()
  45. handler = newPacketHandlerMap(conn, connIDLen, statelessResetKey, utils.DefaultLogger).(*packetHandlerMap)
  46. })
  47. AfterEach(func() {
  48. // delete sessions and the server before closing
  49. // They might be mock implementations, and we'd have to register the expected calls before otherwise.
  50. handler.mutex.Lock()
  51. for connID := range handler.handlers {
  52. delete(handler.handlers, connID)
  53. }
  54. handler.server = nil
  55. handler.mutex.Unlock()
  56. handler.Close()
  57. Eventually(handler.listening).Should(BeClosed())
  58. })
  59. It("closes", func() {
  60. getMultiplexer() // make the sync.Once execute
  61. // replace the clientMuxer. getClientMultiplexer will now return the MockMultiplexer
  62. mockMultiplexer := NewMockMultiplexer(mockCtrl)
  63. origMultiplexer := connMuxer
  64. connMuxer = mockMultiplexer
  65. defer func() {
  66. connMuxer = origMultiplexer
  67. }()
  68. testErr := errors.New("test error ")
  69. sess1 := NewMockPacketHandler(mockCtrl)
  70. sess1.EXPECT().destroy(testErr)
  71. sess2 := NewMockPacketHandler(mockCtrl)
  72. sess2.EXPECT().destroy(testErr)
  73. handler.Add(protocol.ConnectionID{1, 1, 1, 1}, sess1)
  74. handler.Add(protocol.ConnectionID{2, 2, 2, 2}, sess2)
  75. mockMultiplexer.EXPECT().RemoveConn(gomock.Any())
  76. handler.close(testErr)
  77. })
  78. Context("handling packets", func() {
  79. BeforeEach(func() {
  80. connIDLen = 5
  81. })
  82. It("handles packets for different packet handlers on the same packet conn", func() {
  83. connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
  84. connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
  85. packetHandler1 := NewMockPacketHandler(mockCtrl)
  86. packetHandler2 := NewMockPacketHandler(mockCtrl)
  87. handledPacket1 := make(chan struct{})
  88. handledPacket2 := make(chan struct{})
  89. packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
  90. connID, err := wire.ParseConnectionID(p.data, 0)
  91. Expect(err).ToNot(HaveOccurred())
  92. Expect(connID).To(Equal(connID1))
  93. close(handledPacket1)
  94. })
  95. packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
  96. connID, err := wire.ParseConnectionID(p.data, 0)
  97. Expect(err).ToNot(HaveOccurred())
  98. Expect(connID).To(Equal(connID2))
  99. close(handledPacket2)
  100. })
  101. handler.Add(connID1, packetHandler1)
  102. handler.Add(connID2, packetHandler2)
  103. conn.dataToRead <- getPacket(connID1)
  104. conn.dataToRead <- getPacket(connID2)
  105. Eventually(handledPacket1).Should(BeClosed())
  106. Eventually(handledPacket2).Should(BeClosed())
  107. })
  108. It("drops unparseable packets", func() {
  109. handler.handlePacket(nil, nil, []byte{0, 1, 2, 3})
  110. })
  111. It("deletes removed sessions immediately", func() {
  112. handler.deleteRetiredSessionsAfter = time.Hour
  113. connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
  114. handler.Add(connID, NewMockPacketHandler(mockCtrl))
  115. handler.Remove(connID)
  116. handler.handlePacket(nil, nil, getPacket(connID))
  117. // don't EXPECT any calls to handlePacket of the MockPacketHandler
  118. })
  119. It("deletes retired session entries after a wait time", func() {
  120. handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond)
  121. connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
  122. handler.Add(connID, NewMockPacketHandler(mockCtrl))
  123. handler.Retire(connID)
  124. time.Sleep(scaleDuration(30 * time.Millisecond))
  125. handler.handlePacket(nil, nil, getPacket(connID))
  126. // don't EXPECT any calls to handlePacket of the MockPacketHandler
  127. })
  128. It("passes packets arriving late for closed sessions to that session", func() {
  129. handler.deleteRetiredSessionsAfter = time.Hour
  130. connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
  131. packetHandler := NewMockPacketHandler(mockCtrl)
  132. handled := make(chan struct{})
  133. packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
  134. close(handled)
  135. })
  136. handler.Add(connID, packetHandler)
  137. handler.Retire(connID)
  138. handler.handlePacket(nil, nil, getPacket(connID))
  139. Eventually(handled).Should(BeClosed())
  140. })
  141. It("drops packets for unknown receivers", func() {
  142. connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
  143. handler.handlePacket(nil, nil, getPacket(connID))
  144. })
  145. It("closes the packet handlers when reading from the conn fails", func() {
  146. done := make(chan struct{})
  147. packetHandler := NewMockPacketHandler(mockCtrl)
  148. packetHandler.EXPECT().destroy(gomock.Any()).Do(func(e error) {
  149. Expect(e).To(HaveOccurred())
  150. close(done)
  151. })
  152. handler.Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler)
  153. conn.Close()
  154. Eventually(done).Should(BeClosed())
  155. })
  156. })
  157. Context("running a server", func() {
  158. It("adds a server", func() {
  159. connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
  160. p := getPacket(connID)
  161. server := NewMockUnknownPacketHandler(mockCtrl)
  162. server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
  163. cid, err := wire.ParseConnectionID(p.data, 0)
  164. Expect(err).ToNot(HaveOccurred())
  165. Expect(cid).To(Equal(connID))
  166. })
  167. handler.SetServer(server)
  168. handler.handlePacket(nil, nil, p)
  169. })
  170. It("closes all server sessions", func() {
  171. clientSess := NewMockPacketHandler(mockCtrl)
  172. clientSess.EXPECT().getPerspective().Return(protocol.PerspectiveClient)
  173. serverSess := NewMockPacketHandler(mockCtrl)
  174. serverSess.EXPECT().getPerspective().Return(protocol.PerspectiveServer)
  175. serverSess.EXPECT().Close()
  176. handler.Add(protocol.ConnectionID{1, 1, 1, 1}, clientSess)
  177. handler.Add(protocol.ConnectionID{2, 2, 2, 2}, serverSess)
  178. handler.CloseServer()
  179. })
  180. It("stops handling packets with unknown connection IDs after the server is closed", func() {
  181. connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
  182. p := getPacket(connID)
  183. server := NewMockUnknownPacketHandler(mockCtrl)
  184. // don't EXPECT any calls to server.handlePacket
  185. handler.SetServer(server)
  186. handler.CloseServer()
  187. handler.handlePacket(nil, nil, p)
  188. })
  189. })
  190. Context("stateless resets", func() {
  191. BeforeEach(func() {
  192. connIDLen = 5
  193. })
  194. Context("handling", func() {
  195. It("handles stateless resets", func() {
  196. packetHandler := NewMockPacketHandler(mockCtrl)
  197. token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
  198. handler.AddResetToken(token, packetHandler)
  199. packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...)
  200. packet = append(packet, token[:]...)
  201. destroyed := make(chan struct{})
  202. packetHandler.EXPECT().destroy(errors.New("received a stateless reset")).Do(func(error) {
  203. close(destroyed)
  204. })
  205. conn.dataToRead <- packet
  206. Eventually(destroyed).Should(BeClosed())
  207. })
  208. It("handles stateless resets for 0-length connection IDs", func() {
  209. handler.connIDLen = 0
  210. packetHandler := NewMockPacketHandler(mockCtrl)
  211. token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
  212. handler.AddResetToken(token, packetHandler)
  213. packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...)
  214. packet = append(packet, token[:]...)
  215. destroyed := make(chan struct{})
  216. packetHandler.EXPECT().destroy(errors.New("received a stateless reset")).Do(func(error) {
  217. close(destroyed)
  218. })
  219. conn.dataToRead <- packet
  220. Eventually(destroyed).Should(BeClosed())
  221. })
  222. It("deletes reset tokens", func() {
  223. handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond)
  224. connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42}
  225. packetHandler := NewMockPacketHandler(mockCtrl)
  226. handler.Add(connID, packetHandler)
  227. token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
  228. handler.AddResetToken(token, NewMockPacketHandler(mockCtrl))
  229. handler.RemoveResetToken(token)
  230. packetHandler.EXPECT().handlePacket(gomock.Any())
  231. p := append([]byte{0x40} /* short header packet */, connID.Bytes()...)
  232. p = append(p, make([]byte, 50)...)
  233. p = append(p, token[:]...)
  234. handler.handlePacket(nil, nil, p)
  235. // destroy() would be called from a separate go routine
  236. // make sure we give it enough time to be called to cause an error here
  237. time.Sleep(scaleDuration(25 * time.Millisecond))
  238. })
  239. })
  240. Context("generating", func() {
  241. BeforeEach(func() {
  242. key := make([]byte, 32)
  243. rand.Read(key)
  244. statelessResetKey = key
  245. })
  246. It("generates stateless reset tokens", func() {
  247. connID1 := []byte{0xde, 0xad, 0xbe, 0xef}
  248. connID2 := []byte{0xde, 0xca, 0xfb, 0xad}
  249. token1 := handler.GetStatelessResetToken(connID1)
  250. Expect(handler.GetStatelessResetToken(connID1)).To(Equal(token1))
  251. Expect(handler.GetStatelessResetToken(connID2)).ToNot(Equal(token1))
  252. })
  253. It("sends stateless resets", func() {
  254. addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
  255. p := append([]byte{40}, make([]byte, 100)...)
  256. handler.handlePacket(addr, getPacketBuffer(), p)
  257. var reset mockPacketConnWrite
  258. Eventually(conn.dataWritten).Should(Receive(&reset))
  259. Expect(reset.to).To(Equal(addr))
  260. Expect(reset.data[0] & 0x80).To(BeZero()) // short header packet
  261. Expect(reset.data).To(HaveLen(protocol.MinStatelessResetSize))
  262. })
  263. It("doesn't send stateless resets for small packets", func() {
  264. addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
  265. p := append([]byte{40}, make([]byte, protocol.MinStatelessResetSize-2)...)
  266. handler.handlePacket(addr, getPacketBuffer(), p)
  267. Consistently(conn.dataWritten).ShouldNot(Receive())
  268. })
  269. })
  270. Context("if no key is configured", func() {
  271. It("doesn't send stateless resets", func() {
  272. addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
  273. p := append([]byte{40}, make([]byte, 100)...)
  274. handler.handlePacket(addr, getPacketBuffer(), p)
  275. Consistently(conn.dataWritten).ShouldNot(Receive())
  276. })
  277. })
  278. })
  279. })