packet_packer.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
  1. package quic
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "net"
  7. "time"
  8. "github.com/lucas-clemente/quic-go/internal/ackhandler"
  9. "github.com/lucas-clemente/quic-go/internal/handshake"
  10. "github.com/lucas-clemente/quic-go/internal/protocol"
  11. "github.com/lucas-clemente/quic-go/internal/utils"
  12. "github.com/lucas-clemente/quic-go/internal/wire"
  13. )
  14. type packer interface {
  15. PackPacket() (*packedPacket, error)
  16. MaybePackAckPacket() (*packedPacket, error)
  17. PackRetransmission(packet *ackhandler.Packet) ([]*packedPacket, error)
  18. PackConnectionClose(*wire.ConnectionCloseFrame) (*packedPacket, error)
  19. HandleTransportParameters(*handshake.TransportParameters)
  20. SetToken([]byte)
  21. ChangeDestConnectionID(protocol.ConnectionID)
  22. }
// packedPacket is a fully serialized packet as produced by
// packetPacker.writeAndSealPacket, together with the information needed to
// hand it to the ack handler.
type packedPacket struct {
	// header is the header that was serialized at the start of raw.
	header *wire.ExtendedHeader
	// raw holds the serialized packet: header, payload and sealer overhead.
	raw []byte
	// frames are the frames that were written into the payload.
	frames []wire.Frame
	// buffer is the packet buffer obtained from getPacketBuffer; raw was
	// written starting at buffer.Slice[:0].
	buffer *packetBuffer
}
  29. func (p *packedPacket) EncryptionLevel() protocol.EncryptionLevel {
  30. if !p.header.IsLongHeader {
  31. return protocol.Encryption1RTT
  32. }
  33. switch p.header.Type {
  34. case protocol.PacketTypeInitial:
  35. return protocol.EncryptionInitial
  36. case protocol.PacketTypeHandshake:
  37. return protocol.EncryptionHandshake
  38. default:
  39. return protocol.EncryptionUnspecified
  40. }
  41. }
// IsRetransmittable reports whether the packet's frames include at least one
// retransmittable frame, as decided by ackhandler.HasRetransmittableFrames.
func (p *packedPacket) IsRetransmittable() bool {
	return ackhandler.HasRetransmittableFrames(p.frames)
}
  45. func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet {
  46. return &ackhandler.Packet{
  47. PacketNumber: p.header.PacketNumber,
  48. PacketType: p.header.Type,
  49. Frames: p.frames,
  50. Length: protocol.ByteCount(len(p.raw)),
  51. EncryptionLevel: p.EncryptionLevel(),
  52. SendTime: time.Now(),
  53. }
  54. }
  55. func getMaxPacketSize(addr net.Addr) protocol.ByteCount {
  56. maxSize := protocol.ByteCount(protocol.MinInitialPacketSize)
  57. // If this is not a UDP address, we don't know anything about the MTU.
  58. // Use the minimum size of an Initial packet as the max packet size.
  59. if udpAddr, ok := addr.(*net.UDPAddr); ok {
  60. // If ip is not an IPv4 address, To4 returns nil.
  61. // Note that there might be some corner cases, where this is not correct.
  62. // See https://stackoverflow.com/questions/22751035/golang-distinguish-ipv4-ipv6.
  63. if udpAddr.IP.To4() == nil {
  64. maxSize = protocol.MaxPacketSizeIPv6
  65. } else {
  66. maxSize = protocol.MaxPacketSizeIPv4
  67. }
  68. }
  69. return maxSize
  70. }
  71. type packetNumberManager interface {
  72. PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen)
  73. PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber
  74. }
  75. type sealingManager interface {
  76. GetSealer() (protocol.EncryptionLevel, handshake.Sealer)
  77. GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (handshake.Sealer, error)
  78. }
  79. type frameSource interface {
  80. AppendStreamFrames([]wire.Frame, protocol.ByteCount) []wire.Frame
  81. AppendControlFrames([]wire.Frame, protocol.ByteCount) ([]wire.Frame, protocol.ByteCount)
  82. }
  83. type ackFrameSource interface {
  84. GetAckFrame(protocol.EncryptionLevel) *wire.AckFrame
  85. }
// packetPacker assembles frames into packets, writes the header, and seals
// the result using the sealers provided by cryptoSetup.
type packetPacker struct {
	// destConnID is written into every header; see ChangeDestConnectionID.
	destConnID protocol.ConnectionID
	// srcConnID is only written into long headers.
	srcConnID protocol.ConnectionID
	// perspective decides whether Initial packets get a token and padding
	// (client only).
	perspective protocol.Perspective
	version     protocol.VersionNumber
	// cryptoSetup provides the sealers.
	cryptoSetup sealingManager
	// initialStream and handshakeStream are the sources of CRYPTO frames for
	// Initial and Handshake packets, respectively.
	initialStream   cryptoStream
	handshakeStream cryptoStream
	// token is put into the header of Initial packets sent by a client.
	token []byte
	// pnManager hands out packet numbers.
	pnManager packetNumberManager
	// framer supplies control and stream frames for 1-RTT packets.
	framer frameSource
	// acks supplies ACK frames for all encryption levels.
	acks ackFrameSource
	// maxPacketSize is the upper bound for the size of a sealed packet.
	maxPacketSize protocol.ByteCount
	// numNonRetransmittableAcks counts the packets without retransmittable
	// frames packed by PackPacket since the last retransmittable one.
	numNonRetransmittableAcks int
}

// Compile-time check that packetPacker implements packer.
var _ packer = &packetPacker{}
  102. func newPacketPacker(
  103. destConnID protocol.ConnectionID,
  104. srcConnID protocol.ConnectionID,
  105. initialStream cryptoStream,
  106. handshakeStream cryptoStream,
  107. packetNumberManager packetNumberManager,
  108. remoteAddr net.Addr, // only used for determining the max packet size
  109. cryptoSetup sealingManager,
  110. framer frameSource,
  111. acks ackFrameSource,
  112. perspective protocol.Perspective,
  113. version protocol.VersionNumber,
  114. ) *packetPacker {
  115. return &packetPacker{
  116. cryptoSetup: cryptoSetup,
  117. destConnID: destConnID,
  118. srcConnID: srcConnID,
  119. initialStream: initialStream,
  120. handshakeStream: handshakeStream,
  121. perspective: perspective,
  122. version: version,
  123. framer: framer,
  124. acks: acks,
  125. pnManager: packetNumberManager,
  126. maxPacketSize: getMaxPacketSize(remoteAddr),
  127. }
  128. }
  129. // PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame
  130. func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*packedPacket, error) {
  131. frames := []wire.Frame{ccf}
  132. encLevel, sealer := p.cryptoSetup.GetSealer()
  133. header := p.getHeader(encLevel)
  134. return p.writeAndSealPacket(header, frames, encLevel, sealer)
  135. }
  136. func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) {
  137. ack := p.acks.GetAckFrame(protocol.Encryption1RTT)
  138. if ack == nil {
  139. return nil, nil
  140. }
  141. // TODO(#1534): only pack ACKs with the right encryption level
  142. encLevel, sealer := p.cryptoSetup.GetSealer()
  143. header := p.getHeader(encLevel)
  144. frames := []wire.Frame{ack}
  145. return p.writeAndSealPacket(header, frames, encLevel, sealer)
  146. }
  147. // PackRetransmission packs a retransmission
  148. // For packets sent after completion of the handshake, it might happen that 2 packets have to be sent.
  149. // This can happen e.g. when a longer packet number is used in the header.
  150. func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedPacket, error) {
  151. var controlFrames []wire.Frame
  152. var streamFrames []*wire.StreamFrame
  153. for _, f := range packet.Frames {
  154. // CRYPTO frames are treated as control frames here.
  155. // Since we're making sure that the header can never be larger for a retransmission,
  156. // we never have to split CRYPTO frames.
  157. if sf, ok := f.(*wire.StreamFrame); ok {
  158. sf.DataLenPresent = true
  159. streamFrames = append(streamFrames, sf)
  160. } else {
  161. controlFrames = append(controlFrames, f)
  162. }
  163. }
  164. var packets []*packedPacket
  165. encLevel := packet.EncryptionLevel
  166. sealer, err := p.cryptoSetup.GetSealerWithEncryptionLevel(encLevel)
  167. if err != nil {
  168. return nil, err
  169. }
  170. for len(controlFrames) > 0 || len(streamFrames) > 0 {
  171. var frames []wire.Frame
  172. var length protocol.ByteCount
  173. header := p.getHeader(encLevel)
  174. headerLen := header.GetLength(p.version)
  175. maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLen
  176. for len(controlFrames) > 0 {
  177. frame := controlFrames[0]
  178. frameLen := frame.Length(p.version)
  179. if length+frameLen > maxSize {
  180. break
  181. }
  182. length += frameLen
  183. frames = append(frames, frame)
  184. controlFrames = controlFrames[1:]
  185. }
  186. for len(streamFrames) > 0 && length+protocol.MinStreamFrameSize < maxSize {
  187. frame := streamFrames[0]
  188. frame.DataLenPresent = false
  189. frameToAdd := frame
  190. sf, err := frame.MaybeSplitOffFrame(maxSize-length, p.version)
  191. if err != nil {
  192. return nil, err
  193. }
  194. if sf != nil {
  195. frameToAdd = sf
  196. } else {
  197. streamFrames = streamFrames[1:]
  198. }
  199. frame.DataLenPresent = true
  200. length += frameToAdd.Length(p.version)
  201. frames = append(frames, frameToAdd)
  202. }
  203. if sf, ok := frames[len(frames)-1].(*wire.StreamFrame); ok {
  204. sf.DataLenPresent = false
  205. }
  206. p, err := p.writeAndSealPacket(header, frames, encLevel, sealer)
  207. if err != nil {
  208. return nil, err
  209. }
  210. packets = append(packets, p)
  211. }
  212. return packets, nil
  213. }
  214. // PackPacket packs a new packet
  215. // the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise
  216. func (p *packetPacker) PackPacket() (*packedPacket, error) {
  217. packet, err := p.maybePackCryptoPacket()
  218. if err != nil {
  219. return nil, err
  220. }
  221. if packet != nil {
  222. return packet, nil
  223. }
  224. encLevel, sealer := p.cryptoSetup.GetSealer()
  225. header := p.getHeader(encLevel)
  226. headerLen := header.GetLength(p.version)
  227. if err != nil {
  228. return nil, err
  229. }
  230. maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLen
  231. frames, err := p.composeNextPacket(maxSize)
  232. if err != nil {
  233. return nil, err
  234. }
  235. // Check if we have enough frames to send
  236. if len(frames) == 0 {
  237. return nil, nil
  238. }
  239. // check if this packet only contains an ACK
  240. if !ackhandler.HasRetransmittableFrames(frames) {
  241. if p.numNonRetransmittableAcks >= protocol.MaxNonRetransmittableAcks {
  242. frames = append(frames, &wire.PingFrame{})
  243. p.numNonRetransmittableAcks = 0
  244. } else {
  245. p.numNonRetransmittableAcks++
  246. }
  247. } else {
  248. p.numNonRetransmittableAcks = 0
  249. }
  250. return p.writeAndSealPacket(header, frames, encLevel, sealer)
  251. }
  252. func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) {
  253. var s cryptoStream
  254. var encLevel protocol.EncryptionLevel
  255. hasData := p.initialStream.HasData()
  256. ack := p.acks.GetAckFrame(protocol.EncryptionInitial)
  257. if hasData || ack != nil {
  258. s = p.initialStream
  259. encLevel = protocol.EncryptionInitial
  260. } else {
  261. hasData = p.handshakeStream.HasData()
  262. ack = p.acks.GetAckFrame(protocol.EncryptionHandshake)
  263. if hasData || ack != nil {
  264. s = p.handshakeStream
  265. encLevel = protocol.EncryptionHandshake
  266. }
  267. }
  268. if s == nil {
  269. return nil, nil
  270. }
  271. sealer, err := p.cryptoSetup.GetSealerWithEncryptionLevel(encLevel)
  272. if err != nil {
  273. // The sealer
  274. return nil, err
  275. }
  276. hdr := p.getHeader(encLevel)
  277. hdrLen := hdr.GetLength(p.version)
  278. var length protocol.ByteCount
  279. frames := make([]wire.Frame, 0, 2)
  280. if ack != nil {
  281. frames = append(frames, ack)
  282. length += ack.Length(p.version)
  283. }
  284. if hasData {
  285. cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - length)
  286. frames = append(frames, cf)
  287. }
  288. return p.writeAndSealPacket(hdr, frames, encLevel, sealer)
  289. }
  290. func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) ([]wire.Frame, error) {
  291. var length protocol.ByteCount
  292. var frames []wire.Frame
  293. // ACKs need to go first, so that the sentPacketHandler will recognize them
  294. if ack := p.acks.GetAckFrame(protocol.Encryption1RTT); ack != nil {
  295. frames = append(frames, ack)
  296. length += ack.Length(p.version)
  297. }
  298. var lengthAdded protocol.ByteCount
  299. frames, lengthAdded = p.framer.AppendControlFrames(frames, maxFrameSize-length)
  300. length += lengthAdded
  301. // temporarily increase the maxFrameSize by the (minimum) length of the DataLen field
  302. // this leads to a properly sized packet in all cases, since we do all the packet length calculations with STREAM frames that have the DataLen set
  303. // however, for the last STREAM frame in the packet, we can omit the DataLen, thus yielding a packet of exactly the correct size
  304. // the length is encoded to either 1 or 2 bytes
  305. maxFrameSize++
  306. frames = p.framer.AppendStreamFrames(frames, maxFrameSize-length)
  307. if len(frames) > 0 {
  308. lastFrame := frames[len(frames)-1]
  309. if sf, ok := lastFrame.(*wire.StreamFrame); ok {
  310. sf.DataLenPresent = false
  311. }
  312. }
  313. return frames, nil
  314. }
  315. func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader {
  316. pn, pnLen := p.pnManager.PeekPacketNumber(encLevel)
  317. header := &wire.ExtendedHeader{}
  318. header.PacketNumber = pn
  319. header.PacketNumberLen = pnLen
  320. header.Version = p.version
  321. header.DestConnectionID = p.destConnID
  322. if encLevel != protocol.Encryption1RTT {
  323. header.IsLongHeader = true
  324. // Always send Initial and Handshake packets with the maximum packet number length.
  325. // This simplifies retransmissions: Since the header can't get any larger,
  326. // we don't need to split CRYPTO frames.
  327. header.PacketNumberLen = protocol.PacketNumberLen4
  328. header.SrcConnectionID = p.srcConnID
  329. // Set the length to the maximum packet size.
  330. // Since it is encoded as a varint, this guarantees us that the header will end up at most as big as GetLength() returns.
  331. header.Length = p.maxPacketSize
  332. switch encLevel {
  333. case protocol.EncryptionInitial:
  334. header.Type = protocol.PacketTypeInitial
  335. case protocol.EncryptionHandshake:
  336. header.Type = protocol.PacketTypeHandshake
  337. }
  338. }
  339. return header
  340. }
// writeAndSealPacket serializes header and frames into a packet buffer
// obtained from getPacketBuffer, seals the payload, encrypts the header and
// finally pops the packet number for encLevel.
//
// frames must contain at least one frame: the last frame is written
// separately from the others, and the slice expressions used for that panic
// on an empty slice.
//
// For long header packets, header.Length (and for Initial packets sent by a
// client also header.Token) is set here before the header is written.
//
// It returns an error if writing the header or a frame fails, if the sealed
// packet would be larger than p.maxPacketSize, or if the popped packet number
// doesn't match header.PacketNumber.
//
// NOTE(review): packetBuffer is not handed back anywhere on the error paths
// of this function; confirm whether it has to be released there.
func (p *packetPacker) writeAndSealPacket(
	header *wire.ExtendedHeader,
	frames []wire.Frame,
	encLevel protocol.EncryptionLevel,
	sealer handshake.Sealer,
) (*packedPacket, error) {
	packetBuffer := getPacketBuffer()
	buffer := bytes.NewBuffer(packetBuffer.Slice[:0])

	// Only Initial packets sent by a client are padded up to
	// protocol.MinInitialPacketSize.
	addPaddingForInitial := p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial

	if header.IsLongHeader {
		// The token has to be set before GetLength is called below.
		if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial {
			header.Token = p.token
		}
		if addPaddingForInitial {
			// The packet is padded to a fixed size (see below), so the Length
			// field follows from that size and the header length.
			headerLen := header.GetLength(p.version)
			header.Length = protocol.ByteCount(header.PacketNumberLen) + protocol.MinInitialPacketSize - headerLen
		} else {
			// long header packets always use 4 byte packet number, so we never need to pad short payloads
			length := protocol.ByteCount(sealer.Overhead()) + protocol.ByteCount(header.PacketNumberLen)
			for _, frame := range frames {
				length += frame.Length(p.version)
			}
			header.Length = length
		}
	}

	if err := header.Write(buffer, p.version); err != nil {
		return nil, err
	}
	// Everything written after this offset is payload.
	payloadOffset := buffer.Len()

	// write all frames but the last one
	for _, frame := range frames[:len(frames)-1] {
		if err := frame.Write(buffer, p.version); err != nil {
			return nil, err
		}
	}
	lastFrame := frames[len(frames)-1]
	if addPaddingForInitial {
		// when appending padding, we need to make sure that the last STREAM frames has the data length set
		if sf, ok := lastFrame.(*wire.StreamFrame); ok {
			sf.DataLenPresent = true
		}
	} else {
		// This padding is written BEFORE the last frame.
		payloadLen := buffer.Len() - payloadOffset + int(lastFrame.Length(p.version))
		if paddingLen := 4 - int(header.PacketNumberLen) - payloadLen; paddingLen > 0 {
			// Pad the packet such that packet number length + payload length is 4 bytes.
			// This is needed to enable the peer to get a 16 byte sample for header protection.
			buffer.Write(bytes.Repeat([]byte{0}, paddingLen))
		}
	}
	if err := lastFrame.Write(buffer, p.version); err != nil {
		return nil, err
	}

	if addPaddingForInitial {
		// Pad after the last frame, such that the packet reaches
		// protocol.MinInitialPacketSize once the sealer overhead is added.
		paddingLen := protocol.MinInitialPacketSize - sealer.Overhead() - buffer.Len()
		if paddingLen > 0 {
			buffer.Write(bytes.Repeat([]byte{0}, paddingLen))
		}
	}

	if size := protocol.ByteCount(buffer.Len() + sealer.Overhead()); size > p.maxPacketSize {
		return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize)
	}

	raw := buffer.Bytes()
	// The destination is a zero-length slice at the start of the payload and
	// the header bytes are passed as the last argument. Presumably Seal
	// encrypts the payload in place with the header as associated data;
	// confirm against handshake.Sealer. The result is deliberately discarded.
	_ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[:payloadOffset])
	// Extend raw to cover the bytes added by the sealer.
	raw = raw[0 : buffer.Len()+sealer.Overhead()]

	// The packet number occupies the last PacketNumberLen bytes of the header.
	pnOffset := payloadOffset - int(header.PacketNumberLen)
	// EncryptHeader gets the 16 bytes starting 4 bytes after the packet
	// number offset (the header protection sample mentioned above), the first
	// byte of the packet, and the packet number bytes.
	sealer.EncryptHeader(
		raw[pnOffset+4:pnOffset+4+16],
		&raw[0],
		raw[pnOffset:payloadOffset],
	)

	// The header was built from a peeked packet number. Pop it now, and make
	// sure it's the same one.
	num := p.pnManager.PopPacketNumber(encLevel)
	if num != header.PacketNumber {
		return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
	}
	return &packedPacket{
		header: header,
		raw:    raw,
		frames: frames,
		buffer: packetBuffer,
	}, nil
}
// ChangeDestConnectionID sets the destination connection ID that getHeader
// writes into the headers of subsequently packed packets.
func (p *packetPacker) ChangeDestConnectionID(connID protocol.ConnectionID) {
	p.destConnID = connID
}
// SetToken sets the token that writeAndSealPacket puts into the header of
// Initial packets when packing as a client. The slice is stored as is, not
// copied.
func (p *packetPacker) SetToken(token []byte) {
	p.token = token
}
  428. func (p *packetPacker) HandleTransportParameters(params *handshake.TransportParameters) {
  429. if params.MaxPacketSize != 0 {
  430. p.maxPacketSize = utils.MinByteCount(p.maxPacketSize, params.MaxPacketSize)
  431. }
  432. }