server.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. package chat
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "sync"
  8. "time"
  9. "github.com/prometheus/client_golang/prometheus"
  10. "github.com/prometheus/client_golang/prometheus/promauto"
  11. log "github.com/sirupsen/logrus"
  12. "github.com/gorilla/websocket"
  13. "github.com/owncast/owncast/models"
  14. "github.com/owncast/owncast/services/config"
  15. "github.com/owncast/owncast/services/geoip"
  16. "github.com/owncast/owncast/services/status"
  17. "github.com/owncast/owncast/services/webhooks"
  18. "github.com/owncast/owncast/storage/chatrepository"
  19. "github.com/owncast/owncast/storage/configrepository"
  20. "github.com/owncast/owncast/storage/userrepository"
  21. "github.com/owncast/owncast/utils"
  22. )
// _server is a package-level chat server instance. It is not read or written
// in this part of the file.
// NOTE(review): presumably assigned by package setup code elsewhere — confirm.
var _server *Server

// Server represents an instance of the chat server.
type Server struct {
	// clients maps the connection id (taken from seq in Addclient) to the
	// connected client.
	clients map[uint]*Client
	// send outbound message payload to all clients
	// NOTE(review): only created in NewChat; not used elsewhere in this file.
	outbound chan []byte
	// receive inbound message payload from all clients; Run drains this and
	// hands each event to eventReceived.
	inbound chan chatClientEvent
	// unregister requests from clients; Run removes the client with this id.
	unregister chan uint // the ChatClient id
	// geoipClient resolves a client's IP address to geo details in Addclient.
	geoipClient *geoip.Client
	// a map of user IDs and timers that fire for chat part messages.
	// Entries are created with a 10 second period when a user's last client
	// disconnects, and are stopped and removed by Addclient if the user
	// reconnects first.
	userPartedTimers map[string]*time.Ticker
	// seq is the id handed to the next client added; incremented per client.
	seq uint
	// maxSocketConnectionLimit is the connected-client count at which
	// HandleClientConnection rejects new connections.
	maxSocketConnectionLimit int64
	// chatMessagesSentCounter is the "total_chat_message_count" gauge
	// registered in NewChat.
	chatMessagesSentCounter prometheus.Gauge
	// a map of user IDs and when they last were active.
	lastSeenCache map[string]time.Time
	// mu guards the clients and userPartedTimers maps.
	mu sync.RWMutex
	// config, configRepository and chatRepository are the service handles
	// captured in NewChat.
	config *config.Config
	configRepository *configrepository.SqlConfigRepository
	chatRepository *chatrepository.ChatRepository
}
  46. // NewChat will return a new instance of the chat server.
  47. func NewChat() *Server {
  48. server := &Server{
  49. clients: map[uint]*Client{},
  50. outbound: make(chan []byte),
  51. inbound: make(chan chatClientEvent),
  52. unregister: make(chan uint),
  53. maxSocketConnectionLimit: 100, // TODO: Set this properly!
  54. lastSeenCache: map[string]time.Time{},
  55. geoipClient: geoip.NewClient(),
  56. userPartedTimers: map[string]*time.Ticker{},
  57. config: config.Get(),
  58. configRepository: configrepository.Get(),
  59. chatRepository: chatrepository.Get(),
  60. }
  61. server.chatMessagesSentCounter = promauto.NewGauge(prometheus.GaugeOpts{
  62. Name: "total_chat_message_count",
  63. Help: "The number of chat messages incremented over time.",
  64. ConstLabels: map[string]string{
  65. "version": server.config.VersionNumber,
  66. "host": server.configRepository.GetServerURL(),
  67. },
  68. })
  69. return server
  70. }
  71. // Run will start the chat server.
  72. func (s *Server) Run() {
  73. for {
  74. select {
  75. case clientID := <-s.unregister:
  76. if client, ok := s.clients[clientID]; ok {
  77. s.handleClientDisconnected(client)
  78. s.mu.Lock()
  79. delete(s.clients, clientID)
  80. s.mu.Unlock()
  81. }
  82. case message := <-s.inbound:
  83. s.eventReceived(message)
  84. }
  85. }
  86. }
  87. // Addclient registers new connection as a User.
  88. func (s *Server) Addclient(conn *websocket.Conn, user *models.User, accessToken string, userAgent string, ipAddress string) *Client {
  89. client := &Client{
  90. server: s,
  91. conn: conn,
  92. User: user,
  93. IPAddress: ipAddress,
  94. accessToken: accessToken,
  95. send: make(chan []byte, 256),
  96. UserAgent: userAgent,
  97. ConnectedAt: time.Now(),
  98. }
  99. shouldSendJoinedMessages := s.configRepository.GetChatJoinPartMessagesEnabled()
  100. s.mu.Lock()
  101. {
  102. // If there is a pending disconnect timer then clear it.
  103. // Do not send user joined message if enough time hasn't passed where the
  104. // user chat part message hasn't been sent yet.
  105. if ticker, ok := s.userPartedTimers[user.ID]; ok {
  106. ticker.Stop()
  107. delete(s.userPartedTimers, user.ID)
  108. }
  109. client.Id = s.seq
  110. s.clients[client.Id] = client
  111. s.seq++
  112. }
  113. s.mu.Unlock()
  114. log.Traceln("Adding client", client.Id, "total count:", len(s.clients))
  115. go client.writePump()
  116. go client.readPump()
  117. client.sendConnectedClientInfo()
  118. st := status.Get()
  119. if st.Online {
  120. if shouldSendJoinedMessages {
  121. s.sendUserJoinedMessage(client)
  122. }
  123. s.sendWelcomeMessageToClient(client)
  124. }
  125. // Asynchronously, optionally, fetch GeoIP configRepository.
  126. go func(client *Client) {
  127. client.Geo = s.geoipClient.GetGeoFromIP(ipAddress)
  128. }(client)
  129. return client
  130. }
  131. func (s *Server) sendUserJoinedMessage(c *Client) {
  132. userJoinedEvent := models.UserJoinedEvent{}
  133. userJoinedEvent.SetDefaults()
  134. userJoinedEvent.User = c.User
  135. userJoinedEvent.ClientID = c.Id
  136. if err := s.Broadcast(userJoinedEvent.GetBroadcastPayload()); err != nil {
  137. log.Errorln("error adding client to chat server", err)
  138. }
  139. // Send chat user joined webhook
  140. webhookManager := webhooks.Get()
  141. webhookManager.SendChatEventUserJoined(userJoinedEvent)
  142. }
  143. // getClientsForUser will return chat connections that are owned by a specific user.
  144. func (s *Server) GetClientsForUser(userID string) ([]*Client, error) {
  145. s.mu.Lock()
  146. defer s.mu.Unlock()
  147. clients := map[string][]*Client{}
  148. for _, client := range s.clients {
  149. clients[client.User.ID] = append(clients[client.User.ID], client)
  150. }
  151. if _, exists := clients[userID]; !exists {
  152. return nil, errors.New("no connections for user found")
  153. }
  154. return clients[userID], nil
  155. }
  156. func (s *Server) handleClientDisconnected(c *Client) {
  157. if _, ok := s.clients[c.Id]; ok {
  158. log.Debugln("Deleting", c.Id)
  159. delete(s.clients, c.Id)
  160. }
  161. additionalClientCheck, _ := s.GetClientsForUser(c.User.ID)
  162. if len(additionalClientCheck) > 0 {
  163. // This user is still connected to chat with another client.
  164. return
  165. }
  166. s.userPartedTimers[c.User.ID] = time.NewTicker(10 * time.Second)
  167. go func() {
  168. <-s.userPartedTimers[c.User.ID].C
  169. s.sendUserPartedMessage(c)
  170. }()
  171. }
  172. func (s *Server) sendUserPartedMessage(c *Client) {
  173. s.userPartedTimers[c.User.ID].Stop()
  174. delete(s.userPartedTimers, c.User.ID)
  175. userPartEvent := UserPartEvent{}
  176. userPartEvent.SetDefaults()
  177. userPartEvent.User = c.User
  178. userPartEvent.ClientID = c.Id
  179. // If part messages are disabled.
  180. if s.configRepository.GetChatJoinPartMessagesEnabled() {
  181. if err := s.Broadcast(userPartEvent.GetBroadcastPayload()); err != nil {
  182. log.Errorln("error sending chat part message", err)
  183. }
  184. }
  185. // Send chat user joined webhook
  186. webhooks.SendChatEventUserParted(userPartEvent)
  187. }
  188. // HandleClientConnection is fired when a single client connects to the websocket.
  189. func (s *Server) HandleClientConnection(w http.ResponseWriter, r *http.Request) {
  190. cr := configrepository.Get()
  191. chatRepository := chatrepository.Get()
  192. if cr.GetChatDisabled() {
  193. _, _ = w.Write([]byte(models.ChatDisabled))
  194. return
  195. }
  196. ipAddress := utils.GetIPAddressFromRequest(r)
  197. // Check if this client's IP address is banned. If so send a rejection.
  198. if blocked, err := chatRepository.IsIPAddressBanned(ipAddress); blocked {
  199. log.Debugln("Client ip address has been blocked. Rejecting.")
  200. w.WriteHeader(http.StatusForbidden)
  201. return
  202. } else if err != nil {
  203. log.Errorln("error determining if IP address is blocked: ", err)
  204. }
  205. // Limit concurrent chat connections
  206. if int64(len(s.clients)) >= s.maxSocketConnectionLimit {
  207. log.Warnln("rejecting incoming client connection as it exceeds the max client count of", s.maxSocketConnectionLimit)
  208. _, _ = w.Write([]byte(models.ErrorMaxConnectionsExceeded))
  209. return
  210. }
  211. // To allow dev web environments to connect.
  212. upgrader.CheckOrigin = func(r *http.Request) bool {
  213. return true
  214. }
  215. conn, err := upgrader.Upgrade(w, r, nil)
  216. if err != nil {
  217. log.Debugln(err)
  218. return
  219. }
  220. accessToken := r.URL.Query().Get("accessToken")
  221. if accessToken == "" {
  222. log.Errorln("Access token is required")
  223. // Return HTTP status code
  224. _ = conn.Close()
  225. return
  226. }
  227. userRepository := userrepository.Get()
  228. // A user is required to use the websocket
  229. user := userRepository.GetUserByToken(accessToken)
  230. if user == nil {
  231. // Send error that registration is required
  232. _ = conn.WriteJSON(models.EventPayload{
  233. "type": models.ErrorNeedsRegistration,
  234. })
  235. _ = conn.Close()
  236. return
  237. }
  238. // User is disabled therefore we should disconnect.
  239. if user.DisabledAt != nil {
  240. log.Traceln("Disabled user", user.ID, user.DisplayName, "rejected")
  241. _ = conn.WriteJSON(models.EventPayload{
  242. "type": models.ErrorUserDisabled,
  243. })
  244. _ = conn.Close()
  245. return
  246. }
  247. userAgent := r.UserAgent()
  248. s.Addclient(conn, user, accessToken, userAgent, ipAddress)
  249. }
  250. // Broadcast sends message to all connected clients.
  251. func (s *Server) Broadcast(payload models.EventPayload) error {
  252. data, err := json.Marshal(payload)
  253. if err != nil {
  254. return err
  255. }
  256. s.mu.RLock()
  257. defer s.mu.RUnlock()
  258. for _, client := range s.clients {
  259. if client == nil {
  260. continue
  261. }
  262. select {
  263. case client.send <- data:
  264. default:
  265. go client.close()
  266. }
  267. }
  268. return nil
  269. }
  270. // Send will send a single payload to a single connected client.
  271. func (s *Server) Send(payload models.EventPayload, client *Client) {
  272. data, err := json.Marshal(payload)
  273. if err != nil {
  274. log.Errorln(err)
  275. return
  276. }
  277. client.send <- data
  278. }
  279. // DisconnectClients will forcefully disconnect all clients belonging to a user by ID.
  280. func (s *Server) DisconnectClients(clients []*Client) {
  281. for _, client := range clients {
  282. log.Traceln("Disconnecting client", client.User.ID, "owned by", client.User.DisplayName)
  283. go func(client *Client) {
  284. event := models.UserDisabledEvent{}
  285. event.SetDefaults()
  286. // Send this disabled event specifically to this single connected client
  287. // to let them know they've been banned.
  288. s.Send(event.GetBroadcastPayload(), client)
  289. // Give the socket time to send out the above message.
  290. // Unfortunately I don't know of any way to get a real callback to know when
  291. // the message was successfully sent, so give it a couple seconds.
  292. time.Sleep(2 * time.Second)
  293. // Forcefully disconnect if still valid.
  294. if client != nil {
  295. client.close()
  296. }
  297. }(client)
  298. }
  299. }
  300. func (s *Server) eventReceived(event chatClientEvent) {
  301. c := event.client
  302. u := c.User
  303. cr := configrepository.Get()
  304. // If established chat user only mode is enabled and the user is not old
  305. // enough then reject this event and send them an informative message.
  306. if u != nil && cr.GetChatEstbalishedUsersOnlyMode() && time.Since(event.client.User.CreatedAt) < config.GetDefaults().ChatEstablishedUserModeTimeDuration && !u.IsModerator() {
  307. s.sendActionToClient(c, "You have not been an established chat participant long enough to take part in chat. Please enjoy the stream and try again later.")
  308. return
  309. }
  310. var typecheck map[string]interface{}
  311. if err := json.Unmarshal(event.data, &typecheck); err != nil {
  312. log.Debugln(err)
  313. }
  314. eventType := typecheck["type"]
  315. switch eventType {
  316. case models.MessageSent:
  317. s.userMessageSent(event)
  318. case models.UserNameChanged:
  319. s.userNameChanged(event)
  320. case models.UserColorChanged:
  321. s.userColorChanged(event)
  322. default:
  323. log.Debugln(logSanitize(fmt.Sprint(eventType)), "event not found:", logSanitize(fmt.Sprint(typecheck)))
  324. }
  325. }
  326. func (s *Server) sendWelcomeMessageToClient(c *Client) {
  327. // Add an artificial delay so people notice this message come in.
  328. time.Sleep(7 * time.Second)
  329. cr := configrepository.Get()
  330. welcomeMessage := utils.RenderSimpleMarkdown(cr.GetServerWelcomeMessage())
  331. if welcomeMessage != "" {
  332. s.sendSystemMessageToClient(c, welcomeMessage)
  333. }
  334. }
  335. func (s *Server) sendAllWelcomeMessage() {
  336. cr := configrepository.Get()
  337. welcomeMessage := utils.RenderSimpleMarkdown(cr.GetServerWelcomeMessage())
  338. if welcomeMessage != "" {
  339. clientMessage := SystemMessageEvent{
  340. Event: models.Event{},
  341. MessageEvent: MessageEvent{
  342. Body: welcomeMessage,
  343. },
  344. }
  345. clientMessage.SetDefaults()
  346. clientMessage.DisplayName = s.configRepository.GetServerName()
  347. _ = s.Broadcast(clientMessage.GetBroadcastPayload())
  348. }
  349. }
  350. func (s *Server) sendSystemMessageToClient(c *Client, message string) {
  351. clientMessage := SystemMessageEvent{
  352. Event: models.Event{},
  353. MessageEvent: MessageEvent{
  354. Body: message,
  355. },
  356. }
  357. clientMessage.SetDefaults()
  358. clientMessage.RenderBody()
  359. clientMessage.DisplayName = s.configRepository.GetServerName()
  360. s.Send(clientMessage.GetBroadcastPayload(), c)
  361. }
  362. func (s *Server) sendActionToClient(c *Client, message string) {
  363. clientMessage := ActionEvent{
  364. MessageEvent: MessageEvent{
  365. Body: message,
  366. },
  367. Event: models.Event{
  368. Type: models.ChatActionSent,
  369. },
  370. }
  371. clientMessage.SetDefaults()
  372. clientMessage.RenderBody()
  373. s.Send(clientMessage.GetBroadcastPayload(), c)
  374. }