packet_handler_map.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. package quic
  2. import (
  3. "crypto/hmac"
  4. "crypto/rand"
  5. "crypto/sha256"
  6. "errors"
  7. "hash"
  8. "net"
  9. "sync"
  10. "time"
  11. "github.com/lucas-clemente/quic-go/internal/protocol"
  12. "github.com/lucas-clemente/quic-go/internal/utils"
  13. "github.com/lucas-clemente/quic-go/internal/wire"
  14. )
  15. // The packetHandlerMap stores packetHandlers, identified by connection ID.
  16. // It is used:
  17. // * by the server to store sessions
  18. // * when multiplexing outgoing connections to store clients
  19. type packetHandlerMap struct {
  20. mutex sync.RWMutex
  21. conn net.PacketConn
  22. connIDLen int
  23. handlers map[string] /* string(ConnectionID)*/ packetHandler
  24. resetTokens map[[16]byte] /* stateless reset token */ packetHandler
  25. server unknownPacketHandler
  26. listening chan struct{} // is closed when listen returns
  27. closed bool
  28. deleteRetiredSessionsAfter time.Duration
  29. statelessResetEnabled bool
  30. statelessResetHasher hash.Hash
  31. logger utils.Logger
  32. }
  33. var _ packetHandlerManager = &packetHandlerMap{}
  34. func newPacketHandlerMap(
  35. conn net.PacketConn,
  36. connIDLen int,
  37. statelessResetKey []byte,
  38. logger utils.Logger,
  39. ) packetHandlerManager {
  40. m := &packetHandlerMap{
  41. conn: conn,
  42. connIDLen: connIDLen,
  43. listening: make(chan struct{}),
  44. handlers: make(map[string]packetHandler),
  45. resetTokens: make(map[[16]byte]packetHandler),
  46. deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout,
  47. statelessResetEnabled: len(statelessResetKey) > 0,
  48. statelessResetHasher: hmac.New(sha256.New, statelessResetKey),
  49. logger: logger,
  50. }
  51. go m.listen()
  52. return m
  53. }
  54. func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) {
  55. h.mutex.Lock()
  56. h.handlers[string(id)] = handler
  57. h.mutex.Unlock()
  58. }
  59. func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
  60. h.removeByConnectionIDAsString(string(id))
  61. }
  62. func (h *packetHandlerMap) removeByConnectionIDAsString(id string) {
  63. h.mutex.Lock()
  64. delete(h.handlers, id)
  65. h.mutex.Unlock()
  66. }
  67. func (h *packetHandlerMap) Retire(id protocol.ConnectionID) {
  68. h.retireByConnectionIDAsString(string(id))
  69. }
  70. func (h *packetHandlerMap) retireByConnectionIDAsString(id string) {
  71. time.AfterFunc(h.deleteRetiredSessionsAfter, func() {
  72. h.removeByConnectionIDAsString(id)
  73. })
  74. }
  75. func (h *packetHandlerMap) AddResetToken(token [16]byte, handler packetHandler) {
  76. h.mutex.Lock()
  77. h.resetTokens[token] = handler
  78. h.mutex.Unlock()
  79. }
  80. func (h *packetHandlerMap) RemoveResetToken(token [16]byte) {
  81. h.mutex.Lock()
  82. delete(h.resetTokens, token)
  83. h.mutex.Unlock()
  84. }
  85. func (h *packetHandlerMap) SetServer(s unknownPacketHandler) {
  86. h.mutex.Lock()
  87. h.server = s
  88. h.mutex.Unlock()
  89. }
  90. func (h *packetHandlerMap) CloseServer() {
  91. h.mutex.Lock()
  92. h.server = nil
  93. var wg sync.WaitGroup
  94. for id, handler := range h.handlers {
  95. if handler.getPerspective() == protocol.PerspectiveServer {
  96. wg.Add(1)
  97. go func(id string, handler packetHandler) {
  98. // session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
  99. _ = handler.Close()
  100. h.retireByConnectionIDAsString(id)
  101. wg.Done()
  102. }(id, handler)
  103. }
  104. }
  105. h.mutex.Unlock()
  106. wg.Wait()
  107. }
  108. // Close the underlying connection and wait until listen() has returned.
  109. func (h *packetHandlerMap) Close() error {
  110. if err := h.conn.Close(); err != nil {
  111. return err
  112. }
  113. <-h.listening // wait until listening returns
  114. return nil
  115. }
  116. func (h *packetHandlerMap) close(e error) error {
  117. h.mutex.Lock()
  118. if h.closed {
  119. h.mutex.Unlock()
  120. return nil
  121. }
  122. h.closed = true
  123. var wg sync.WaitGroup
  124. for _, handler := range h.handlers {
  125. wg.Add(1)
  126. go func(handler packetHandler) {
  127. handler.destroy(e)
  128. wg.Done()
  129. }(handler)
  130. }
  131. if h.server != nil {
  132. h.server.closeWithError(e)
  133. }
  134. h.mutex.Unlock()
  135. wg.Wait()
  136. return getMultiplexer().RemoveConn(h.conn)
  137. }
  138. func (h *packetHandlerMap) listen() {
  139. defer close(h.listening)
  140. for {
  141. buffer := getPacketBuffer()
  142. data := buffer.Slice
  143. // The packet size should not exceed protocol.MaxReceivePacketSize bytes
  144. // If it does, we only read a truncated packet, which will then end up undecryptable
  145. n, addr, err := h.conn.ReadFrom(data)
  146. if err != nil {
  147. h.close(err)
  148. return
  149. }
  150. h.handlePacket(addr, buffer, data[:n])
  151. }
  152. }
  153. func (h *packetHandlerMap) handlePacket(
  154. addr net.Addr,
  155. buffer *packetBuffer,
  156. data []byte,
  157. ) {
  158. connID, err := wire.ParseConnectionID(data, h.connIDLen)
  159. if err != nil {
  160. h.logger.Debugf("error parsing connection ID on packet from %s: %s", addr, err)
  161. return
  162. }
  163. rcvTime := time.Now()
  164. h.mutex.RLock()
  165. defer h.mutex.RUnlock()
  166. if isStatelessReset := h.maybeHandleStatelessReset(data); isStatelessReset {
  167. return
  168. }
  169. handler, handlerFound := h.handlers[string(connID)]
  170. p := &receivedPacket{
  171. remoteAddr: addr,
  172. rcvTime: rcvTime,
  173. buffer: buffer,
  174. data: data,
  175. }
  176. if handlerFound { // existing session
  177. handler.handlePacket(p)
  178. return
  179. }
  180. if data[0]&0x80 == 0 {
  181. go h.maybeSendStatelessReset(p, connID)
  182. return
  183. }
  184. if h.server == nil { // no server set
  185. h.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
  186. return
  187. }
  188. h.server.handlePacket(p)
  189. }
  190. func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool {
  191. // stateless resets are always short header packets
  192. if data[0]&0x80 != 0 {
  193. return false
  194. }
  195. if len(data) < protocol.MinStatelessResetSize {
  196. return false
  197. }
  198. var token [16]byte
  199. copy(token[:], data[len(data)-16:])
  200. if sess, ok := h.resetTokens[token]; ok {
  201. h.logger.Debugf("Received a stateless retry with token %#x. Closing session.", token)
  202. go sess.destroy(errors.New("received a stateless reset"))
  203. return true
  204. }
  205. return false
  206. }
  207. func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) [16]byte {
  208. var token [16]byte
  209. if !h.statelessResetEnabled {
  210. // Return a random stateless reset token.
  211. // This token will be sent in the server's transport parameters.
  212. // By using a random token, an off-path attacker won't be able to disrupt the connection.
  213. rand.Read(token[:])
  214. return token
  215. }
  216. h.statelessResetHasher.Write(connID.Bytes())
  217. copy(token[:], h.statelessResetHasher.Sum(nil))
  218. h.statelessResetHasher.Reset()
  219. return token
  220. }
  221. func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) {
  222. defer p.buffer.Release()
  223. if !h.statelessResetEnabled {
  224. return
  225. }
  226. // Don't send a stateless reset in response to very small packets.
  227. // This includes packets that could be stateless resets.
  228. if len(p.data) <= protocol.MinStatelessResetSize {
  229. return
  230. }
  231. token := h.GetStatelessResetToken(connID)
  232. h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
  233. data := make([]byte, 23)
  234. rand.Read(data)
  235. data[0] = (data[0] & 0x7f) | 0x40
  236. data = append(data, token[:]...)
  237. if _, err := h.conn.WriteTo(data, p.remoteAddr); err != nil {
  238. h.logger.Debugf("Error sending Stateless Reset: %s", err)
  239. }
  240. }