123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160 |
- package srtp
- import (
- "net"
- "time"
- "github.com/pion/rtcp"
- "github.com/pion/rtp"
- "github.com/pion/srtp/v2"
- )
- type Session struct {
- Local *Endpoint
- Remote *Endpoint
- OnReadRTP func(packet *rtp.Packet)
- Recv int // bytes recv
- Send int // bytes send
- conn net.PacketConn // local conn endpoint
- PayloadType uint8
- RTCPInterval time.Duration
- senderRTCP rtcp.SenderReport
- senderTime time.Time
- }
- type Endpoint struct {
- Addr string
- Port uint16
- MasterKey []byte
- MasterSalt []byte
- SSRC uint32
- addr net.Addr
- srtp *srtp.Context
- }
- func (e *Endpoint) init() (err error) {
- e.addr = &net.UDPAddr{IP: net.ParseIP(e.Addr), Port: int(e.Port)}
- e.srtp, err = srtp.CreateContext(e.MasterKey, e.MasterSalt, profile(e.MasterKey))
- return
- }
- func profile(key []byte) srtp.ProtectionProfile {
- switch len(key) {
- case 16:
- return srtp.ProtectionProfileAes128CmHmacSha1_80
- //case 32:
- // return srtp.ProtectionProfileAes256CmHmacSha1_80
- }
- return 0
- }
- func (s *Session) init() error {
- if err := s.Local.init(); err != nil {
- return err
- }
- if err := s.Remote.init(); err != nil {
- return err
- }
- s.senderRTCP.SSRC = s.Local.SSRC
- s.senderTime = time.Now().Add(s.RTCPInterval)
- return nil
- }
- func (s *Session) WriteRTP(packet *rtp.Packet) (int, error) {
- if s.Local.srtp == nil {
- return 0, nil // before init call
- }
- if now := time.Now(); now.After(s.senderTime) {
- s.senderRTCP.NTPTime = uint64(now.UnixNano())
- s.senderTime = now.Add(s.RTCPInterval)
- _, _ = s.WriteRTCP(&s.senderRTCP)
- }
- clone := rtp.Packet{
- Header: rtp.Header{
- Version: 2,
- Marker: packet.Marker,
- PayloadType: s.PayloadType,
- SequenceNumber: packet.SequenceNumber,
- Timestamp: packet.Timestamp,
- SSRC: s.Local.SSRC,
- },
- Payload: packet.Payload,
- }
- b, err := clone.Marshal()
- if err != nil {
- return 0, err
- }
- s.senderRTCP.PacketCount++
- s.senderRTCP.RTPTime = clone.Timestamp
- s.senderRTCP.OctetCount += uint32(len(clone.Payload))
- if b, err = s.Local.srtp.EncryptRTP(nil, b, nil); err != nil {
- return 0, err
- }
- return s.conn.WriteTo(b, s.Remote.addr)
- }
- func (s *Session) WriteRTCP(packet rtcp.Packet) (int, error) {
- b, err := packet.Marshal()
- if err != nil {
- return 0, err
- }
- b, err = s.Local.srtp.EncryptRTCP(nil, b, nil)
- if err != nil {
- return 0, err
- }
- return s.conn.WriteTo(b, s.Remote.addr)
- }
- func (s *Session) ReadRTP(b []byte) {
- packet := &rtp.Packet{}
- b, err := s.Remote.srtp.DecryptRTP(nil, b, &packet.Header)
- if err != nil {
- return
- }
- if err = packet.Unmarshal(b); err != nil {
- return
- }
- if s.OnReadRTP != nil {
- s.OnReadRTP(packet)
- }
- }
- func (s *Session) ReadRTCP(b []byte) {
- header := rtcp.Header{}
- b, err := s.Remote.srtp.DecryptRTCP(nil, b, &header)
- if err != nil {
- return
- }
- //packets, err := rtcp.Unmarshal(b)
- //if err != nil {
- // return
- //}
- //if report, ok := packets[0].(*rtcp.SenderReport); ok {
- // log.Printf("[srtp] rtcp type=%d report=%v", header.Type, report)
- //}
- if header.Type != rtcp.TypeSenderReport {
- return
- }
- receiverRTCP := rtcp.ReceiverReport{SSRC: s.Local.SSRC}
- _, _ = s.WriteRTCP(&receiverRTCP)
- }
|