server.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461
  1. package chat
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "net/http"
  6. "sync"
  7. "time"
  8. log "github.com/sirupsen/logrus"
  9. "github.com/gorilla/websocket"
  10. "github.com/owncast/owncast/config"
  11. "github.com/owncast/owncast/core/chat/events"
  12. "github.com/owncast/owncast/core/data"
  13. "github.com/owncast/owncast/core/webhooks"
  14. "github.com/owncast/owncast/models"
  15. "github.com/owncast/owncast/persistence/userrepository"
  16. "github.com/owncast/owncast/services/geoip"
  17. "github.com/owncast/owncast/utils"
  18. )
  19. var _server *Server
  20. // Server represents an instance of the chat server.
  21. type Server struct {
  22. clients map[uint]*Client
  23. // send outbound message payload to all clients
  24. outbound chan []byte
  25. // receive inbound message payload from all clients
  26. inbound chan chatClientEvent
  27. // unregister requests from clients.
  28. unregister chan uint // the ChatClient id
  29. geoipClient *geoip.Client
  30. // a map of user IDs and timers that fire for chat part messages.
  31. userPartedTimers map[string]*time.Ticker
  32. seq uint
  33. maxSocketConnectionLimit uint64
  34. mu sync.RWMutex
  35. }
  36. // NewChat will return a new instance of the chat server.
  37. func NewChat() *Server {
  38. maximumConcurrentConnectionLimit := getMaximumConcurrentConnectionLimit()
  39. setSystemConcurrentConnectionLimit(maximumConcurrentConnectionLimit)
  40. server := &Server{
  41. clients: map[uint]*Client{},
  42. outbound: make(chan []byte),
  43. inbound: make(chan chatClientEvent),
  44. unregister: make(chan uint),
  45. maxSocketConnectionLimit: maximumConcurrentConnectionLimit,
  46. geoipClient: geoip.NewClient(),
  47. userPartedTimers: map[string]*time.Ticker{},
  48. }
  49. return server
  50. }
  51. // Run will start the chat server.
  52. func (s *Server) Run() {
  53. for {
  54. select {
  55. case clientID := <-s.unregister:
  56. if client, ok := s.clients[clientID]; ok {
  57. s.handleClientDisconnected(client)
  58. s.mu.Lock()
  59. delete(s.clients, clientID)
  60. s.mu.Unlock()
  61. }
  62. case message := <-s.inbound:
  63. s.eventReceived(message)
  64. }
  65. }
  66. }
  67. // Addclient registers new connection as a User.
  68. func (s *Server) Addclient(conn *websocket.Conn, user *models.User, accessToken string, userAgent string, ipAddress string) *Client {
  69. client := &Client{
  70. server: s,
  71. conn: conn,
  72. User: user,
  73. IPAddress: ipAddress,
  74. accessToken: accessToken,
  75. send: make(chan []byte, 256),
  76. UserAgent: userAgent,
  77. ConnectedAt: time.Now(),
  78. }
  79. shouldSendJoinedMessages := data.GetChatJoinPartMessagesEnabled()
  80. // If there are existing clients connected for this user do not send
  81. // a user joined message. Do not put this under a mutex, as
  82. // GetClientsForUser already has a lock.
  83. if existingConnectedClients, _ := GetClientsForUser(user.ID); len(existingConnectedClients) > 0 {
  84. shouldSendJoinedMessages = false
  85. }
  86. s.mu.Lock()
  87. {
  88. // If there is a pending disconnect timer then clear it.
  89. // Do not send user joined message if enough time hasn't passed where the
  90. // user chat part message hasn't been sent yet.
  91. if ticker, ok := s.userPartedTimers[user.ID]; ok {
  92. ticker.Stop()
  93. delete(s.userPartedTimers, user.ID)
  94. shouldSendJoinedMessages = false
  95. }
  96. client.Id = s.seq
  97. s.clients[client.Id] = client
  98. s.seq++
  99. }
  100. s.mu.Unlock()
  101. log.Traceln("Adding client", client.Id, "total count:", len(s.clients))
  102. go client.writePump()
  103. go client.readPump()
  104. client.sendConnectedClientInfo()
  105. if getStatus().Online {
  106. if shouldSendJoinedMessages {
  107. s.sendUserJoinedMessage(client)
  108. }
  109. s.sendWelcomeMessageToClient(client)
  110. }
  111. // Asynchronously, optionally, fetch GeoIP data.
  112. go func(client *Client) {
  113. client.Geo = s.geoipClient.GetGeoFromIP(ipAddress)
  114. }(client)
  115. return client
  116. }
  117. func (s *Server) sendUserJoinedMessage(c *Client) {
  118. userJoinedEvent := events.UserJoinedEvent{}
  119. userJoinedEvent.SetDefaults()
  120. userJoinedEvent.User = c.User
  121. userJoinedEvent.ClientID = c.Id
  122. if err := s.Broadcast(userJoinedEvent.GetBroadcastPayload()); err != nil {
  123. log.Errorln("error adding client to chat server", err)
  124. }
  125. // Send chat user joined webhook
  126. webhooks.SendChatEventUserJoined(userJoinedEvent)
  127. }
  128. func (s *Server) handleClientDisconnected(c *Client) {
  129. if _, ok := s.clients[c.Id]; ok {
  130. log.Debugln("Deleting", c.Id)
  131. delete(s.clients, c.Id)
  132. }
  133. additionalClientCheck, _ := GetClientsForUser(c.User.ID)
  134. if len(additionalClientCheck) > 0 {
  135. // This user is still connected to chat with another client.
  136. return
  137. }
  138. s.userPartedTimers[c.User.ID] = time.NewTicker(10 * time.Second)
  139. go func() {
  140. <-s.userPartedTimers[c.User.ID].C
  141. s.sendUserPartedMessage(c)
  142. }()
  143. }
  144. func (s *Server) sendUserPartedMessage(c *Client) {
  145. s.userPartedTimers[c.User.ID].Stop()
  146. delete(s.userPartedTimers, c.User.ID)
  147. userPartEvent := events.UserPartEvent{}
  148. userPartEvent.SetDefaults()
  149. userPartEvent.User = c.User
  150. userPartEvent.ClientID = c.Id
  151. // If part messages are disabled.
  152. if data.GetChatJoinPartMessagesEnabled() {
  153. if err := s.Broadcast(userPartEvent.GetBroadcastPayload()); err != nil {
  154. log.Errorln("error sending chat part message", err)
  155. }
  156. }
  157. // Send chat user joined webhook
  158. webhooks.SendChatEventUserParted(userPartEvent)
  159. }
  160. // HandleClientConnection is fired when a single client connects to the websocket.
  161. func (s *Server) HandleClientConnection(w http.ResponseWriter, r *http.Request) {
  162. if data.GetChatDisabled() {
  163. _, _ = w.Write([]byte(events.ChatDisabled))
  164. return
  165. }
  166. ipAddress := utils.GetIPAddressFromRequest(r)
  167. // Check if this client's IP address is banned. If so send a rejection.
  168. if blocked, err := data.IsIPAddressBanned(ipAddress); blocked {
  169. log.Debugln("Client ip address has been blocked. Rejecting.")
  170. w.WriteHeader(http.StatusForbidden)
  171. return
  172. } else if err != nil {
  173. log.Errorln("error determining if IP address is blocked: ", err)
  174. }
  175. // Limit concurrent chat connections
  176. if uint64(len(s.clients)) >= s.maxSocketConnectionLimit {
  177. log.Warnln("rejecting incoming client connection as it exceeds the max client count of", s.maxSocketConnectionLimit)
  178. _, _ = w.Write([]byte(events.ErrorMaxConnectionsExceeded))
  179. return
  180. }
  181. // To allow dev web environments to connect.
  182. upgrader.CheckOrigin = func(r *http.Request) bool {
  183. return true
  184. }
  185. conn, err := upgrader.Upgrade(w, r, nil)
  186. if err != nil {
  187. log.Debugln(err)
  188. return
  189. }
  190. accessToken := r.URL.Query().Get("accessToken")
  191. if accessToken == "" {
  192. log.Errorln("Access token is required")
  193. // Return HTTP status code
  194. _ = conn.Close()
  195. return
  196. }
  197. userRepository := userrepository.Get()
  198. // A user is required to use the websocket
  199. user := userRepository.GetUserByToken(accessToken)
  200. if user == nil {
  201. // Send error that registration is required
  202. _ = conn.WriteJSON(events.EventPayload{
  203. "type": events.ErrorNeedsRegistration,
  204. })
  205. _ = conn.Close()
  206. return
  207. }
  208. // User is disabled therefore we should disconnect.
  209. if user.DisabledAt != nil {
  210. log.Traceln("Disabled user", user.ID, user.DisplayName, "rejected")
  211. _ = conn.WriteJSON(events.EventPayload{
  212. "type": events.ErrorUserDisabled,
  213. })
  214. _ = conn.Close()
  215. return
  216. }
  217. userAgent := r.UserAgent()
  218. s.Addclient(conn, user, accessToken, userAgent, ipAddress)
  219. }
  220. // Broadcast sends message to all connected clients.
  221. func (s *Server) Broadcast(payload events.EventPayload) error {
  222. data, err := json.Marshal(payload)
  223. if err != nil {
  224. return err
  225. }
  226. s.mu.RLock()
  227. defer s.mu.RUnlock()
  228. for _, client := range s.clients {
  229. if client == nil {
  230. continue
  231. }
  232. select {
  233. case client.send <- data:
  234. default:
  235. go client.close()
  236. }
  237. }
  238. return nil
  239. }
  240. // Send will send a single payload to a single connected client.
  241. func (s *Server) Send(payload events.EventPayload, client *Client) {
  242. data, err := json.Marshal(payload)
  243. if err != nil {
  244. log.Errorln(err)
  245. return
  246. }
  247. client.send <- data
  248. }
  249. // DisconnectClients will forcefully disconnect all clients belonging to a user by ID.
  250. func (s *Server) DisconnectClients(clients []*Client) {
  251. for _, client := range clients {
  252. log.Traceln("Disconnecting client", client.User.ID, "owned by", client.User.DisplayName)
  253. go func(client *Client) {
  254. event := events.UserDisabledEvent{}
  255. event.SetDefaults()
  256. // Send this disabled event specifically to this single connected client
  257. // to let them know they've been banned.
  258. _server.Send(event.GetBroadcastPayload(), client)
  259. // Give the socket time to send out the above message.
  260. // Unfortunately I don't know of any way to get a real callback to know when
  261. // the message was successfully sent, so give it a couple seconds.
  262. time.Sleep(2 * time.Second)
  263. // Forcefully disconnect if still valid.
  264. if client != nil {
  265. client.close()
  266. }
  267. }(client)
  268. }
  269. }
  270. // SendConnectedClientInfoToUser will find all the connected clients assigned to a user
  271. // and re-send each the connected client info.
  272. func SendConnectedClientInfoToUser(userID string) error {
  273. clients, err := GetClientsForUser(userID)
  274. if err != nil {
  275. return err
  276. }
  277. userRepository := userrepository.Get()
  278. // Get an updated reference to the user.
  279. user := userRepository.GetUserByID(userID)
  280. if user == nil {
  281. return fmt.Errorf("user not found")
  282. }
  283. if err != nil {
  284. return err
  285. }
  286. for _, client := range clients {
  287. // Update the client's reference to its user.
  288. client.User = user
  289. // Send the update to the client.
  290. client.sendConnectedClientInfo()
  291. }
  292. return nil
  293. }
  294. // SendActionToUser will send system action text to all connected clients
  295. // assigned to a user ID.
  296. func SendActionToUser(userID string, text string) error {
  297. clients, err := GetClientsForUser(userID)
  298. if err != nil {
  299. return err
  300. }
  301. for _, client := range clients {
  302. _server.sendActionToClient(client, text)
  303. }
  304. return nil
  305. }
  306. func (s *Server) eventReceived(event chatClientEvent) {
  307. c := event.client
  308. u := c.User
  309. // If established chat user only mode is enabled and the user is not old
  310. // enough then reject this event and send them an informative message.
  311. if u != nil && data.GetChatEstbalishedUsersOnlyMode() && time.Since(event.client.User.CreatedAt) < config.GetDefaults().ChatEstablishedUserModeTimeDuration && !u.IsModerator() {
  312. s.sendActionToClient(c, "You have not been an established chat participant long enough to take part in chat. Please enjoy the stream and try again later.")
  313. return
  314. }
  315. var typecheck map[string]interface{}
  316. if err := json.Unmarshal(event.data, &typecheck); err != nil {
  317. log.Debugln(err)
  318. }
  319. eventType := typecheck["type"]
  320. switch eventType {
  321. case events.MessageSent:
  322. s.userMessageSent(event)
  323. case events.UserNameChanged:
  324. s.userNameChanged(event)
  325. case events.UserColorChanged:
  326. s.userColorChanged(event)
  327. default:
  328. log.Debugln(logSanitize(fmt.Sprint(eventType)), "event not found:", logSanitize(fmt.Sprint(typecheck)))
  329. }
  330. }
  331. func (s *Server) sendWelcomeMessageToClient(c *Client) {
  332. // Add an artificial delay so people notice this message come in.
  333. time.Sleep(7 * time.Second)
  334. welcomeMessage := utils.RenderSimpleMarkdown(data.GetServerWelcomeMessage())
  335. if welcomeMessage != "" {
  336. s.sendSystemMessageToClient(c, welcomeMessage)
  337. }
  338. }
  339. func (s *Server) sendAllWelcomeMessage() {
  340. welcomeMessage := utils.RenderSimpleMarkdown(data.GetServerWelcomeMessage())
  341. if welcomeMessage != "" {
  342. clientMessage := events.SystemMessageEvent{
  343. Event: events.Event{},
  344. MessageEvent: events.MessageEvent{
  345. Body: welcomeMessage,
  346. },
  347. }
  348. clientMessage.SetDefaults()
  349. _ = s.Broadcast(clientMessage.GetBroadcastPayload())
  350. }
  351. }
  352. func (s *Server) sendSystemMessageToClient(c *Client, message string) {
  353. clientMessage := events.SystemMessageEvent{
  354. Event: events.Event{},
  355. MessageEvent: events.MessageEvent{
  356. Body: message,
  357. },
  358. }
  359. clientMessage.SetDefaults()
  360. clientMessage.RenderBody()
  361. s.Send(clientMessage.GetBroadcastPayload(), c)
  362. }
  363. func (s *Server) sendActionToClient(c *Client, message string) {
  364. clientMessage := events.ActionEvent{
  365. MessageEvent: events.MessageEvent{
  366. Body: message,
  367. },
  368. Event: events.Event{
  369. Type: events.ChatActionSent,
  370. },
  371. }
  372. clientMessage.SetDefaults()
  373. clientMessage.RenderBody()
  374. s.Send(clientMessage.GetBroadcastPayload(), c)
  375. }