webtransport.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. package webtransport
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/tls"
  6. "fmt"
  7. "log"
  8. "net/http"
  9. "net/url"
  10. "github.com/marten-seemann/qpack"
  11. "github.com/quic-go/quic-go"
  12. "github.com/quic-go/quic-go/http3"
  13. "github.com/quic-go/quic-go/quicvarint"
  14. h3 "m7s.live/plugin/webtransport/v4/internal"
  15. )
  // receiveMessageResult carries the outcome of one asynchronous datagram read
  // from the reader goroutine back to the caller: either the raw datagram or
  // the error that ended the read.
  type receiveMessageResult struct {
  	// msg is the datagram exactly as the QUIC connection returned it.
  	msg []byte
  	// err is non-nil when the read failed.
  	err error
  }
  // CertFile describes one piece of TLS material (a certificate or its private
  // key). It is supplied either as a file location in Path or as the file
  // contents themselves in Data.
  type CertFile struct {
  	// Path is the location of the certificate or key file on disk.
  	Path string
  	// Data is the certificate or key held in memory.
  	Data []byte
  }
  25. // Wrapper for quic.Config
  26. type QuicConfig quic.Config
  27. // A Server defines parameters for running a WebTransport server. Use http.HandleFunc to register HTTP/3 endpoints for handling WebTransport requests.
  28. type Server struct {
  29. http.Handler
  30. // ListenAddr sets an address to bind server to, e.g. ":4433"
  31. ListenAddr string
  32. // TLSCert defines a path to, or byte array containing, a certificate (CRT file)
  33. TLSCert CertFile
  34. // TLSKey defines a path to, or byte array containing, the certificate's private key (KEY file)
  35. TLSKey CertFile
  36. // AllowedOrigins represents list of allowed origins to connect from
  37. AllowedOrigins []string
  38. // Additional configuration parameters to pass onto QUIC listener
  39. QuicConfig *QuicConfig
  40. }
  41. // Starts a WebTransport server and blocks while it's running. Cancel the supplied Context to stop the server.
  42. func (s *Server) Run(ctx context.Context) error {
  43. if s.Handler == nil {
  44. s.Handler = http.DefaultServeMux
  45. }
  46. if s.QuicConfig == nil {
  47. s.QuicConfig = &QuicConfig{}
  48. }
  49. s.QuicConfig.EnableDatagrams = true
  50. listener, err := quic.ListenAddr(s.ListenAddr, s.generateTLSConfig(), (*quic.Config)(s.QuicConfig))
  51. if err != nil {
  52. return err
  53. }
  54. go func() {
  55. <-ctx.Done()
  56. listener.Close()
  57. }()
  58. for {
  59. sess, err := listener.Accept(ctx)
  60. if err != nil {
  61. return err
  62. }
  63. go s.handleSession(ctx, sess)
  64. }
  65. }
  66. func (s *Server) handleSession(ctx context.Context, sess quic.Connection) {
  67. serverControlStream, err := sess.OpenUniStream()
  68. if err != nil {
  69. return
  70. }
  71. // Write server settings
  72. streamHeader := h3.StreamHeader{Type: h3.STREAM_CONTROL}
  73. streamHeader.Write(serverControlStream)
  74. settingsFrame := (h3.SettingsMap{h3.H3_DATAGRAM_05: 1, h3.ENABLE_WEBTRANSPORT: 1}).ToFrame()
  75. settingsFrame.Write(serverControlStream)
  76. // Accept control stream - client settings will appear here
  77. clientControlStream, err := sess.AcceptUniStream(context.Background())
  78. if err != nil {
  79. log.Println(err)
  80. return
  81. }
  82. // log.Printf("Read settings from control stream id: %d\n", stream.StreamID())
  83. clientSettingsReader := quicvarint.NewReader(clientControlStream)
  84. quicvarint.Read(clientSettingsReader)
  85. clientSettingsFrame := h3.Frame{}
  86. if clientSettingsFrame.Read(clientControlStream); err != nil || clientSettingsFrame.Type != h3.FRAME_SETTINGS {
  87. // log.Println("control stream read error, or not a settings frame")
  88. return
  89. }
  90. // Accept request stream
  91. requestStream, err := sess.AcceptStream(ctx)
  92. if err != nil {
  93. // log.Printf("request stream err: %v", err)
  94. return
  95. }
  96. // log.Printf("request stream accepted: %d", requestStream.StreamID())
  97. ctx, cancelFunction := context.WithCancel(requestStream.Context())
  98. ctx = context.WithValue(ctx, http3.ServerContextKey, s)
  99. ctx = context.WithValue(ctx, http.LocalAddrContextKey, sess.LocalAddr())
  100. // log.Println(streamType, settingsFrame)
  101. headersFrame := h3.Frame{}
  102. err = headersFrame.Read(requestStream)
  103. if err != nil {
  104. // log.Printf("request stream ParseNextFrame err: %v", err)
  105. cancelFunction()
  106. requestStream.Close()
  107. return
  108. }
  109. if headersFrame.Type != h3.FRAME_HEADERS {
  110. // log.Println("request stream got not HeadersFrame")
  111. cancelFunction()
  112. requestStream.Close()
  113. return
  114. }
  115. decoder := qpack.NewDecoder(nil)
  116. hfs, err := decoder.DecodeFull(headersFrame.Data)
  117. if err != nil {
  118. // log.Printf("request stream decoder err: %v", err)
  119. cancelFunction()
  120. requestStream.Close()
  121. return
  122. }
  123. req, protocol, err := h3.RequestFromHeaders(hfs)
  124. if err != nil {
  125. cancelFunction()
  126. requestStream.Close()
  127. return
  128. }
  129. req.RemoteAddr = sess.RemoteAddr().String()
  130. req = req.WithContext(ctx)
  131. rw := h3.NewResponseWriter(requestStream)
  132. rw.Header().Add("sec-webtransport-http3-draft", "draft02")
  133. req.Body = &Session{Stream: requestStream, Session: sess, ClientControlStream: clientControlStream, ServerControlStream: serverControlStream, responseWriter: rw, context: ctx, cancel: cancelFunction}
  134. if protocol != "webtransport" || !s.validateOrigin(req.Header.Get("origin")) {
  135. req.Body.(*Session).RejectSession(http.StatusBadRequest)
  136. return
  137. }
  138. // Drain request stream - this is so that we can catch the EOF and shut down cleanly when the client closes the transport
  139. go func() {
  140. for {
  141. buf := make([]byte, 1024)
  142. _, err := requestStream.Read(buf)
  143. if err != nil {
  144. cancelFunction()
  145. requestStream.Close()
  146. break
  147. }
  148. }
  149. }()
  150. s.ServeHTTP(rw, req)
  151. }
  152. func (s *Server) generateTLSConfig() *tls.Config {
  153. var cert tls.Certificate
  154. var err error
  155. if s.TLSCert.Path != "" && s.TLSKey.Path != "" {
  156. cert, err = tls.LoadX509KeyPair(s.TLSCert.Path, s.TLSKey.Path)
  157. } else {
  158. cert, err = tls.X509KeyPair(s.TLSCert.Data, s.TLSKey.Data)
  159. }
  160. if err != nil {
  161. log.Fatal(err)
  162. }
  163. return &tls.Config{
  164. Certificates: []tls.Certificate{cert},
  165. NextProtos: []string{"h3", "h3-32", "h3-31", "h3-30", "h3-29"},
  166. }
  167. }
  168. func (s *Server) validateOrigin(origin string) bool {
  169. // No origin specified - everything is allowed
  170. if s.AllowedOrigins == nil {
  171. return true
  172. }
  173. // Enforce allowed origins
  174. u, err := url.Parse(origin)
  175. if err != nil {
  176. return false
  177. }
  178. for _, b := range s.AllowedOrigins {
  179. if b == u.Host {
  180. return true
  181. }
  182. }
  183. return false
  184. }
  185. // ReceiveStream wraps a quic.ReceiveStream providing a unidirectional WebTransport client->server stream, including a Read function.
  186. type ReceiveStream struct {
  187. quic.ReceiveStream
  188. readHeaderBeforeData bool
  189. headerRead bool
  190. requestSessionID uint64
  191. }
  192. // SendStream wraps a quic.SendStream providing a unidirectional WebTransport server->client stream, including a Write function.
  193. type SendStream struct {
  194. quic.SendStream
  195. writeHeaderBeforeData bool
  196. headerWritten bool
  197. requestSessionID uint64
  198. }
  199. // Stream wraps a quic.Stream providing a bidirectional server<->client stream, including Read and Write functions.
  200. type WtStream quic.Stream
  201. // Read reads up to len(p) bytes from a WebTransport unidirectional stream, returning the actual number of bytes read.
  202. func (s *ReceiveStream) Read(p []byte) (int, error) {
  203. if s.readHeaderBeforeData && !s.headerRead {
  204. // Unidirectional stream - so we need to read stream header before first data read
  205. streamHeader := h3.StreamHeader{}
  206. if err := streamHeader.Read(s.ReceiveStream); err != nil {
  207. return 0, err
  208. }
  209. if streamHeader.Type != h3.STREAM_WEBTRANSPORT_UNI_STREAM {
  210. return 0, fmt.Errorf("unidirectional stream received with the wrong stream type")
  211. }
  212. s.requestSessionID = streamHeader.ID
  213. s.headerRead = true
  214. }
  215. return s.ReceiveStream.Read(p)
  216. }
  217. // Write writes up to len(p) bytes to a WebTransport unidirectional stream, returning the actual number of bytes written.
  218. func (s *SendStream) Write(p []byte) (int, error) {
  219. if s.writeHeaderBeforeData && !s.headerWritten {
  220. // Unidirectional stream - so we need to write stream header before first data write
  221. buf := quicvarint.Append(nil, h3.STREAM_WEBTRANSPORT_UNI_STREAM)
  222. buf = quicvarint.Append(buf, s.requestSessionID)
  223. if _, err := s.SendStream.Write(buf); err != nil {
  224. s.Close()
  225. return 0, err
  226. }
  227. s.headerWritten = true
  228. }
  229. return s.SendStream.Write(p)
  230. }
  231. // Session is a WebTransport session (and the Body of a WebTransport http.Request) wrapping the request stream (a quic.Stream), the two control streams and a quic.Session.
  232. type Session struct {
  233. quic.Stream
  234. Session quic.Connection
  235. ClientControlStream quic.ReceiveStream
  236. ServerControlStream quic.SendStream
  237. responseWriter *h3.ResponseWriter
  238. context context.Context
  239. cancel context.CancelFunc
  240. }
  241. // Context returns the context for the WebTransport session.
  242. func (s *Session) Context() context.Context {
  243. return s.context
  244. }
  245. // AcceptSession accepts an incoming WebTransport session. Call it in your http.HandleFunc.
  246. func (s *Session) AcceptSession() {
  247. r := s.responseWriter
  248. r.WriteHeader(http.StatusOK)
  249. r.Flush()
  250. }
  251. // AcceptSession rejects an incoming WebTransport session, returning the supplied HTML error code to the client. Call it in your http.HandleFunc.
  252. func (s *Session) RejectSession(errorCode int) {
  253. r := s.responseWriter
  254. r.WriteHeader(errorCode)
  255. r.Flush()
  256. s.CloseSession()
  257. }
  258. // ReceiveMessage returns a datagram received from a WebTransport session, blocking if necessary until one is available. Supply your own context, or use the WebTransport
  259. // session's Context() so that ending the WebTransport session automatically cancels this call. Note that datagrams are unreliable - depending on network conditions,
  260. // datagrams sent by the client may never be received by the server.
  261. func (s *Session) ReceiveMessage(ctx context.Context) ([]byte, error) {
  262. resultChannel := make(chan receiveMessageResult)
  263. go func() {
  264. msg, err := s.Session.ReceiveMessage(ctx)
  265. resultChannel <- receiveMessageResult{msg: msg, err: err}
  266. }()
  267. select {
  268. case result := <-resultChannel:
  269. if result.err != nil {
  270. return nil, result.err
  271. }
  272. datastream := bytes.NewReader(result.msg)
  273. quarterStreamId, err := quicvarint.Read(datastream)
  274. if err != nil {
  275. return nil, err
  276. }
  277. return result.msg[quicvarint.Len(quarterStreamId):], nil
  278. case <-ctx.Done():
  279. return nil, fmt.Errorf("WebTransport stream closed")
  280. }
  281. }
  282. // SendMessage sends a datagram over a WebTransport session. Supply your own context, or use the WebTransport
  283. // session's Context() so that ending the WebTransport session automatically cancels this call. Note that datagrams are unreliable - depending on network conditions,
  284. // datagrams sent by the server may never be received by the client.
  285. func (s *Session) SendMessage(msg []byte) error {
  286. // "Quarter Stream ID" (!) of associated request stream, as per https://datatracker.ietf.org/doc/html/draft-ietf-masque-h3-datagram
  287. buf := quicvarint.Append(nil, uint64(s.StreamID()/4))
  288. return s.Session.SendMessage(append(buf, msg...))
  289. }
  290. // AcceptStream accepts an incoming (that is, client-initated) bidirectional stream, blocking if necessary until one is available. Supply your own context, or use the WebTransport
  291. // session's Context() so that ending the WebTransport session automatically cancels this call.
  292. func (s *Session) AcceptStream() (WtStream, error) {
  293. stream, err := s.Session.AcceptStream(s.context)
  294. if err != nil {
  295. return stream, err
  296. }
  297. streamFrame := h3.Frame{}
  298. err = streamFrame.Read(stream)
  299. return stream, err
  300. }
  301. // AcceptStream accepts an incoming (that is, client-initated) unidirectional stream, blocking if necessary until one is available. Supply your own context, or use the WebTransport
  302. // session's Context() so that ending the WebTransport session automatically cancels this call.
  303. func (s *Session) AcceptUniStream(ctx context.Context) (ReceiveStream, error) {
  304. stream, err := s.Session.AcceptUniStream(ctx)
  305. return ReceiveStream{ReceiveStream: stream, readHeaderBeforeData: true, headerRead: false}, err
  306. }
  307. func (s *Session) internalOpenStream(ctx *context.Context, sync bool) (WtStream, error) {
  308. var stream quic.Stream
  309. var err error
  310. if sync {
  311. stream, err = s.Session.OpenStreamSync(*ctx)
  312. } else {
  313. stream, err = s.Session.OpenStream()
  314. }
  315. if err == nil {
  316. // Write frame header
  317. buf := quicvarint.Append(nil, h3.FRAME_WEBTRANSPORT_STREAM)
  318. buf = quicvarint.Append(buf, uint64(s.StreamID()))
  319. if _, err := stream.Write(buf); err != nil {
  320. stream.Close()
  321. }
  322. }
  323. return stream, err
  324. }
  325. func (s *Session) internalOpenUniStream(ctx *context.Context, sync bool) (SendStream, error) {
  326. var stream quic.SendStream
  327. var err error
  328. if sync {
  329. stream, err = s.Session.OpenUniStreamSync(*ctx)
  330. } else {
  331. stream, err = s.Session.OpenUniStream()
  332. }
  333. return SendStream{SendStream: stream, writeHeaderBeforeData: true, headerWritten: false, requestSessionID: uint64(s.StreamID())}, err
  334. }
  335. // OpenStream creates an outgoing (that is, server-initiated) bidirectional stream. It returns immediately.
  336. func (s *Session) OpenStream() (WtStream, error) {
  337. return s.internalOpenStream(nil, false)
  338. }
  339. // OpenStream creates an outgoing (that is, server-initiated) bidirectional stream. It generally returns immediately, but if the session's maximum number of streams
  340. // has been exceeded, it will block until a slot is available. Supply your own context, or use the WebTransport
  341. // session's Context() so that ending the WebTransport session automatically cancels this call.
  342. func (s *Session) OpenStreamSync(ctx context.Context) (WtStream, error) {
  343. return s.internalOpenStream(&ctx, true)
  344. }
  345. // OpenUniStream creates an outgoing (that is, server-initiated) bidirectional stream. It returns immediately.
  346. func (s *Session) OpenUniStream() (SendStream, error) {
  347. return s.internalOpenUniStream(nil, false)
  348. }
  349. // OpenUniStreamSync creates an outgoing (that is, server-initiated) unidirectional stream. It generally returns immediately, but if the session's maximum number of streams
  350. // has been exceeded, it will block until a slot is available. Supply your own context, or use the WebTransport
  351. // session's Context() so that ending the WebTransport session automatically cancels this call.
  352. func (s *Session) OpenUniStreamSync(ctx context.Context) (SendStream, error) {
  353. return s.internalOpenUniStream(&ctx, true)
  354. }
  355. // CloseSession cleanly closes a WebTransport session. All active streams are cancelled before terminating the session.
  356. func (s *Session) CloseSession() {
  357. s.cancel()
  358. s.Close()
  359. }
  360. // CloseWithError closes a WebTransport session with a supplied error code and string.
  361. func (s *Session) CloseWithError(code quic.ApplicationErrorCode, str string) {
  362. s.Session.CloseWithError(code, str)
  363. }