session.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. package srtp
  2. import (
  3. "net"
  4. "time"
  5. "github.com/pion/rtcp"
  6. "github.com/pion/rtp"
  7. "github.com/pion/srtp/v2"
  8. )
  9. type Session struct {
  10. Local *Endpoint
  11. Remote *Endpoint
  12. OnReadRTP func(packet *rtp.Packet)
  13. Recv int // bytes recv
  14. Send int // bytes send
  15. conn net.PacketConn // local conn endpoint
  16. PayloadType uint8
  17. RTCPInterval time.Duration
  18. senderRTCP rtcp.SenderReport
  19. senderTime time.Time
  20. }
  21. type Endpoint struct {
  22. Addr string
  23. Port uint16
  24. MasterKey []byte
  25. MasterSalt []byte
  26. SSRC uint32
  27. addr net.Addr
  28. srtp *srtp.Context
  29. }
  30. func (e *Endpoint) init() (err error) {
  31. e.addr = &net.UDPAddr{IP: net.ParseIP(e.Addr), Port: int(e.Port)}
  32. e.srtp, err = srtp.CreateContext(e.MasterKey, e.MasterSalt, profile(e.MasterKey))
  33. return
  34. }
  35. func profile(key []byte) srtp.ProtectionProfile {
  36. switch len(key) {
  37. case 16:
  38. return srtp.ProtectionProfileAes128CmHmacSha1_80
  39. //case 32:
  40. // return srtp.ProtectionProfileAes256CmHmacSha1_80
  41. }
  42. return 0
  43. }
  44. func (s *Session) init() error {
  45. if err := s.Local.init(); err != nil {
  46. return err
  47. }
  48. if err := s.Remote.init(); err != nil {
  49. return err
  50. }
  51. s.senderRTCP.SSRC = s.Local.SSRC
  52. s.senderTime = time.Now().Add(s.RTCPInterval)
  53. return nil
  54. }
  55. func (s *Session) WriteRTP(packet *rtp.Packet) (int, error) {
  56. if s.Local.srtp == nil {
  57. return 0, nil // before init call
  58. }
  59. if now := time.Now(); now.After(s.senderTime) {
  60. s.senderRTCP.NTPTime = uint64(now.UnixNano())
  61. s.senderTime = now.Add(s.RTCPInterval)
  62. _, _ = s.WriteRTCP(&s.senderRTCP)
  63. }
  64. clone := rtp.Packet{
  65. Header: rtp.Header{
  66. Version: 2,
  67. Marker: packet.Marker,
  68. PayloadType: s.PayloadType,
  69. SequenceNumber: packet.SequenceNumber,
  70. Timestamp: packet.Timestamp,
  71. SSRC: s.Local.SSRC,
  72. },
  73. Payload: packet.Payload,
  74. }
  75. b, err := clone.Marshal()
  76. if err != nil {
  77. return 0, err
  78. }
  79. s.senderRTCP.PacketCount++
  80. s.senderRTCP.RTPTime = clone.Timestamp
  81. s.senderRTCP.OctetCount += uint32(len(clone.Payload))
  82. if b, err = s.Local.srtp.EncryptRTP(nil, b, nil); err != nil {
  83. return 0, err
  84. }
  85. return s.conn.WriteTo(b, s.Remote.addr)
  86. }
  87. func (s *Session) WriteRTCP(packet rtcp.Packet) (int, error) {
  88. b, err := packet.Marshal()
  89. if err != nil {
  90. return 0, err
  91. }
  92. b, err = s.Local.srtp.EncryptRTCP(nil, b, nil)
  93. if err != nil {
  94. return 0, err
  95. }
  96. return s.conn.WriteTo(b, s.Remote.addr)
  97. }
  98. func (s *Session) ReadRTP(b []byte) {
  99. packet := &rtp.Packet{}
  100. b, err := s.Remote.srtp.DecryptRTP(nil, b, &packet.Header)
  101. if err != nil {
  102. return
  103. }
  104. if err = packet.Unmarshal(b); err != nil {
  105. return
  106. }
  107. if s.OnReadRTP != nil {
  108. s.OnReadRTP(packet)
  109. }
  110. }
  111. func (s *Session) ReadRTCP(b []byte) {
  112. header := rtcp.Header{}
  113. b, err := s.Remote.srtp.DecryptRTCP(nil, b, &header)
  114. if err != nil {
  115. return
  116. }
  117. //packets, err := rtcp.Unmarshal(b)
  118. //if err != nil {
  119. // return
  120. //}
  121. //if report, ok := packets[0].(*rtcp.SenderReport); ok {
  122. // log.Printf("[srtp] rtcp type=%d report=%v", header.Type, report)
  123. //}
  124. if header.Type != rtcp.TypeSenderReport {
  125. return
  126. }
  127. receiverRTCP := rtcp.ReceiverReport{SSRC: s.Local.SSRC}
  128. _, _ = s.WriteRTCP(&receiverRTCP)
  129. }