netConnection.go 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. package rtmp
  2. import (
  3. "bufio"
  4. "encoding/binary"
  5. "errors"
  6. "io"
  7. "net"
  8. "runtime"
  9. "sync/atomic"
  10. "go.uber.org/zap"
  11. "m7s.live/engine/v4/util"
  12. )
  13. const (
  14. SEND_CHUNK_SIZE_MESSAGE = "Send Chunk Size Message"
  15. SEND_ACK_MESSAGE = "Send Acknowledgement Message"
  16. SEND_ACK_WINDOW_SIZE_MESSAGE = "Send Window Acknowledgement Size Message"
  17. SEND_SET_PEER_BANDWIDTH_MESSAGE = "Send Set Peer Bandwidth Message"
  18. SEND_STREAM_BEGIN_MESSAGE = "Send Stream Begin Message"
  19. SEND_SET_BUFFER_LENGTH_MESSAGE = "Send Set Buffer Lengh Message"
  20. SEND_STREAM_IS_RECORDED_MESSAGE = "Send Stream Is Recorded Message"
  21. SEND_PING_REQUEST_MESSAGE = "Send Ping Request Message"
  22. SEND_PING_RESPONSE_MESSAGE = "Send Ping Response Message"
  23. SEND_CONNECT_MESSAGE = "Send Connect Message"
  24. SEND_CONNECT_RESPONSE_MESSAGE = "Send Connect Response Message"
  25. SEND_CREATE_STREAM_MESSAGE = "Send Create Stream Message"
  26. SEND_PLAY_MESSAGE = "Send Play Message"
  27. SEND_PLAY_RESPONSE_MESSAGE = "Send Play Response Message"
  28. SEND_PUBLISH_RESPONSE_MESSAGE = "Send Publish Response Message"
  29. SEND_PUBLISH_START_MESSAGE = "Send Publish Start Message"
  30. SEND_UNPUBLISH_RESPONSE_MESSAGE = "Send Unpublish Response Message"
  31. SEND_AUDIO_MESSAGE = "Send Audio Message"
  32. SEND_FULL_AUDIO_MESSAGE = "Send Full Audio Message"
  33. SEND_VIDEO_MESSAGE = "Send Video Message"
  34. SEND_FULL_VDIEO_MESSAGE = "Send Full Video Message"
  35. )
  36. type NetConnection struct {
  37. *bufio.Reader `json:"-" yaml:"-"`
  38. net.Conn `json:"-" yaml:"-"`
  39. bandwidth uint32
  40. readSeqNum uint32 // 当前读的字节
  41. writeSeqNum uint32 // 当前写的字节
  42. totalWrite uint32 // 总共写了多少字节
  43. totalRead uint32 // 总共读了多少字节
  44. writeChunkSize int
  45. readChunkSize int
  46. incommingChunks map[uint32]*Chunk
  47. objectEncoding float64
  48. appName string
  49. tmpBuf util.Buffer //用来接收/发送小数据,复用内存
  50. chunkHeader util.Buffer
  51. bytePool util.BytesPool
  52. writing atomic.Bool // false 可写,true 不可写
  53. }
  54. func NewNetConnection(conn net.Conn) *NetConnection {
  55. return &NetConnection{
  56. Conn: conn,
  57. Reader: bufio.NewReader(conn),
  58. writeChunkSize: RTMP_DEFAULT_CHUNK_SIZE,
  59. readChunkSize: RTMP_DEFAULT_CHUNK_SIZE,
  60. incommingChunks: make(map[uint32]*Chunk),
  61. bandwidth: RTMP_MAX_CHUNK_SIZE << 3,
  62. tmpBuf: make(util.Buffer, 4),
  63. chunkHeader: make(util.Buffer, 0, 16),
  64. bytePool: make(util.BytesPool, 17),
  65. }
  66. }
  67. func (conn *NetConnection) ReadFull(buf []byte) (n int, err error) {
  68. n, err = io.ReadFull(conn.Reader, buf)
  69. if err == nil {
  70. conn.readSeqNum += uint32(n)
  71. }
  72. return
  73. }
  74. func (conn *NetConnection) SendStreamID(eventType uint16, streamID uint32) (err error) {
  75. return conn.SendMessage(RTMP_MSG_USER_CONTROL, &StreamIDMessage{UserControlMessage{EventType: eventType}, streamID})
  76. }
  77. func (conn *NetConnection) SendUserControl(eventType uint16) error {
  78. return conn.SendMessage(RTMP_MSG_USER_CONTROL, &UserControlMessage{EventType: eventType})
  79. }
  80. func (conn *NetConnection) ResponseCreateStream(tid uint64, streamID uint32) error {
  81. m := &ResponseCreateStreamMessage{}
  82. m.CommandName = Response_Result
  83. m.TransactionId = tid
  84. m.StreamId = streamID
  85. return conn.SendMessage(RTMP_MSG_AMF0_COMMAND, m)
  86. }
  87. // func (conn *NetConnection) SendCommand(message string, args any) error {
  88. // switch message {
  89. // // case SEND_SET_BUFFER_LENGTH_MESSAGE:
  90. // // if args != nil {
  91. // // return errors.New(SEND_SET_BUFFER_LENGTH_MESSAGE + ", The parameter is nil")
  92. // // }
  93. // // m := new(SetBufferMessage)
  94. // // m.EventType = RTMP_USER_SET_BUFFLEN
  95. // // m.Millisecond = 100
  96. // // m.StreamID = conn.streamID
  97. // // return conn.writeMessage(RTMP_MSG_USER_CONTROL, m)
  98. // }
  99. // return errors.New("send message no exist")
  100. // }
  101. func (conn *NetConnection) readChunk() (msg *Chunk, err error) {
  102. head, err := conn.ReadByte()
  103. if err != nil {
  104. return nil, err
  105. }
  106. conn.readSeqNum++
  107. ChunkStreamID := uint32(head & 0x3f) // 0011 1111
  108. ChunkType := head >> 6 // 1100 0000
  109. // 如果块流ID为0,1的话,就需要计算.
  110. ChunkStreamID, err = conn.readChunkStreamID(ChunkStreamID)
  111. if err != nil {
  112. return nil, errors.New("get chunk stream id error :" + err.Error())
  113. }
  114. // println("ChunkStreamID:", ChunkStreamID, "ChunkType:", ChunkType)
  115. chunk, ok := conn.incommingChunks[ChunkStreamID]
  116. if ChunkType != 3 && ok && chunk.AVData.Length > 0 {
  117. // 如果块类型不为3,那么这个rtmp的body应该为空.
  118. return nil, errors.New("incompleteRtmpBody error")
  119. }
  120. if !ok {
  121. chunk = &Chunk{}
  122. conn.incommingChunks[ChunkStreamID] = chunk
  123. }
  124. if err = conn.readChunkType(&chunk.ChunkHeader, ChunkType); err != nil {
  125. return nil, errors.New("get chunk type error :" + err.Error())
  126. }
  127. msgLen := int(chunk.MessageLength)
  128. needRead := conn.readChunkSize
  129. if unRead := msgLen - chunk.AVData.ByteLength; unRead < needRead {
  130. needRead = unRead
  131. }
  132. mem := conn.bytePool.Get(needRead)
  133. if n, err := conn.ReadFull(mem.Value); err != nil {
  134. mem.Recycle()
  135. return nil, err
  136. } else {
  137. conn.readSeqNum += uint32(n)
  138. }
  139. if chunk.AVData.Push(mem); chunk.AVData.ByteLength == msgLen {
  140. chunk.ChunkHeader.ExtendTimestamp += chunk.ChunkHeader.Timestamp
  141. msg = chunk
  142. switch chunk.MessageTypeID {
  143. case RTMP_MSG_AUDIO, RTMP_MSG_VIDEO:
  144. default:
  145. err = GetRtmpMessage(msg, msg.AVData.ToBytes())
  146. msg.AVData.Recycle()
  147. }
  148. conn.incommingChunks[ChunkStreamID] = &Chunk{
  149. ChunkHeader: chunk.ChunkHeader,
  150. }
  151. }
  152. return
  153. }
  154. func (conn *NetConnection) readChunkStreamID(csid uint32) (chunkStreamID uint32, err error) {
  155. chunkStreamID = csid
  156. switch csid {
  157. case 0:
  158. {
  159. u8, err := conn.ReadByte()
  160. conn.readSeqNum++
  161. if err != nil {
  162. return 0, err
  163. }
  164. chunkStreamID = 64 + uint32(u8)
  165. }
  166. case 1:
  167. {
  168. u16_0, err1 := conn.ReadByte()
  169. if err1 != nil {
  170. return 0, err1
  171. }
  172. u16_1, err1 := conn.ReadByte()
  173. if err1 != nil {
  174. return 0, err1
  175. }
  176. conn.readSeqNum += 2
  177. chunkStreamID = 64 + uint32(u16_0) + (uint32(u16_1) << 8)
  178. }
  179. }
  180. return chunkStreamID, nil
  181. }
  182. func (conn *NetConnection) readChunkType(h *ChunkHeader, chunkType byte) (err error) {
  183. conn.tmpBuf.Reset()
  184. b4 := conn.tmpBuf.Malloc(4)
  185. b3 := b4[:3]
  186. if chunkType == 3 {
  187. // 3个字节的时间戳
  188. } else {
  189. // Timestamp 3 bytes
  190. if _, err = conn.ReadFull(b3); err != nil {
  191. return err
  192. }
  193. util.GetBE(b3, &h.Timestamp)
  194. if chunkType != 2 {
  195. if _, err = conn.ReadFull(b3); err != nil {
  196. return err
  197. }
  198. util.GetBE(b3, &h.MessageLength)
  199. // Message Type ID 1 bytes
  200. if h.MessageTypeID, err = conn.ReadByte(); err != nil {
  201. return err
  202. }
  203. conn.readSeqNum++
  204. if chunkType == 0 {
  205. // Message Stream ID 4bytes
  206. if _, err = conn.ReadFull(b4); err != nil { // 读取Message Stream ID
  207. return err
  208. }
  209. h.MessageStreamID = binary.LittleEndian.Uint32(b4)
  210. }
  211. }
  212. }
  213. // ExtendTimestamp 4 bytes
  214. if h.Timestamp == 0xffffff { // 对于type 0的chunk,绝对时间戳在这里表示,如果时间戳值大于等于0xffffff(16777215),该值必须是0xffffff,且时间戳扩展字段必须发送,其他情况没有要求
  215. if _, err = conn.ReadFull(b4); err != nil {
  216. return err
  217. }
  218. util.GetBE(b4, &h.Timestamp)
  219. }
  220. if chunkType == 0 {
  221. h.ExtendTimestamp = h.Timestamp
  222. h.Timestamp = 0
  223. }
  224. return nil
  225. }
  226. func (conn *NetConnection) RecvMessage() (msg *Chunk, err error) {
  227. if conn.readSeqNum >= conn.bandwidth {
  228. conn.totalRead += conn.readSeqNum
  229. conn.readSeqNum = 0
  230. err = conn.SendMessage(RTMP_MSG_ACK, Uint32Message(conn.totalRead))
  231. }
  232. for msg == nil && err == nil {
  233. if msg, err = conn.readChunk(); msg != nil && err == nil {
  234. switch msg.MessageTypeID {
  235. case RTMP_MSG_CHUNK_SIZE:
  236. conn.readChunkSize = int(msg.MsgData.(Uint32Message))
  237. RTMPPlugin.Info("msg read chunk size", zap.Int("readChunkSize", conn.readChunkSize))
  238. case RTMP_MSG_ABORT:
  239. delete(conn.incommingChunks, uint32(msg.MsgData.(Uint32Message)))
  240. case RTMP_MSG_ACK, RTMP_MSG_EDGE:
  241. case RTMP_MSG_USER_CONTROL:
  242. if _, ok := msg.MsgData.(*PingRequestMessage); ok {
  243. conn.SendUserControl(RTMP_USER_PING_RESPONSE)
  244. }
  245. case RTMP_MSG_ACK_SIZE:
  246. conn.bandwidth = uint32(msg.MsgData.(Uint32Message))
  247. case RTMP_MSG_BANDWIDTH:
  248. conn.bandwidth = msg.MsgData.(*SetPeerBandwidthMessage).AcknowledgementWindowsize
  249. case RTMP_MSG_AMF0_COMMAND, RTMP_MSG_AUDIO, RTMP_MSG_VIDEO:
  250. return msg, err
  251. }
  252. }
  253. }
  254. return
  255. }
  256. func (conn *NetConnection) SendMessage(t byte, msg RtmpMessage) (err error) {
  257. if conn == nil {
  258. return errors.New("connection is nil")
  259. }
  260. if conn.writeSeqNum > conn.bandwidth {
  261. conn.totalWrite += conn.writeSeqNum
  262. conn.writeSeqNum = 0
  263. err = conn.SendMessage(RTMP_MSG_ACK, Uint32Message(conn.totalWrite))
  264. err = conn.SendStreamID(RTMP_USER_PING_REQUEST, 0)
  265. }
  266. for !conn.writing.CompareAndSwap(false, true) {
  267. runtime.Gosched()
  268. }
  269. defer conn.writing.Store(false)
  270. conn.tmpBuf.Reset()
  271. amf := util.AMF{conn.tmpBuf}
  272. if conn.objectEncoding == 0 {
  273. msg.Encode(&amf)
  274. } else {
  275. amf := util.AMF3{AMF: amf}
  276. msg.Encode(&amf)
  277. }
  278. conn.tmpBuf = amf.Buffer
  279. head := newChunkHeader(t)
  280. head.MessageLength = uint32(conn.tmpBuf.Len())
  281. if sid, ok := msg.(HaveStreamID); ok {
  282. head.MessageStreamID = sid.GetStreamID()
  283. }
  284. head.WriteTo(RTMP_CHUNK_HEAD_12, &conn.chunkHeader)
  285. for _, chunk := range conn.tmpBuf.Split(conn.writeChunkSize) {
  286. conn.sendChunk(chunk)
  287. }
  288. return nil
  289. }
  290. func (conn *NetConnection) sendChunk(writeBuffer ...[]byte) error {
  291. if n, err := conn.Write(conn.chunkHeader); err != nil {
  292. return err
  293. } else {
  294. conn.writeSeqNum += uint32(n)
  295. }
  296. buf := net.Buffers(writeBuffer)
  297. n, err := buf.WriteTo(conn)
  298. conn.writeSeqNum += uint32(n)
  299. return err
  300. }