rest.go 52 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501
  1. // Copyright 2021 gorse Project Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package master
  15. import (
  16. "bufio"
  17. "context"
  18. "encoding/json"
  19. "fmt"
  20. "io"
  21. "net/http"
  22. "os"
  23. "reflect"
  24. "strconv"
  25. "strings"
  26. "time"
  27. "github.com/araddon/dateparse"
  28. mapset "github.com/deckarep/golang-set/v2"
  29. restfulspec "github.com/emicklei/go-restful-openapi/v2"
  30. "github.com/emicklei/go-restful/v3"
  31. "github.com/gorilla/securecookie"
  32. _ "github.com/gorse-io/dashboard"
  33. "github.com/juju/errors"
  34. "github.com/mitchellh/mapstructure"
  35. "github.com/rakyll/statik/fs"
  36. "github.com/samber/lo"
  37. "github.com/zhenghaoz/gorse/base"
  38. "github.com/zhenghaoz/gorse/base/encoding"
  39. "github.com/zhenghaoz/gorse/base/log"
  40. "github.com/zhenghaoz/gorse/base/progress"
  41. "github.com/zhenghaoz/gorse/cmd/version"
  42. "github.com/zhenghaoz/gorse/config"
  43. "github.com/zhenghaoz/gorse/model/click"
  44. "github.com/zhenghaoz/gorse/model/ranking"
  45. "github.com/zhenghaoz/gorse/server"
  46. "github.com/zhenghaoz/gorse/storage/cache"
  47. "github.com/zhenghaoz/gorse/storage/data"
  48. "go.uber.org/zap"
  49. )
  50. func (m *Master) CreateWebService() {
  51. ws := m.WebService
  52. ws.Consumes(restful.MIME_JSON).Produces(restful.MIME_JSON)
  53. ws.Path("/api/")
  54. ws.Filter(m.LoginFilter)
  55. ws.Route(ws.GET("/dashboard/cluster").To(m.getCluster).
  56. Doc("Get nodes in the cluster.").
  57. Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}).
  58. Returns(http.StatusOK, "OK", []Node{}).
  59. Writes([]Node{}))
  60. ws.Route(ws.GET("/dashboard/categories").To(m.getCategories).
  61. Doc("Get categories of items.").
  62. Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}).
  63. Returns(http.StatusOK, "OK", []string{}).
  64. Writes([]string{}))
  65. ws.Route(ws.GET("/dashboard/config").To(m.getConfig).
  66. Doc("Get config.").
  67. Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}).
  68. Returns(http.StatusOK, "OK", config.Config{}).
  69. Writes(config.Config{}))
  70. ws.Route(ws.GET("/dashboard/stats").To(m.getStats).
  71. Doc("Get global status.").
  72. Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}).
  73. Param(ws.HeaderParameter("X-API-Key", "secret key for RESTful API")).
  74. Returns(http.StatusOK, "OK", Status{}).
  75. Writes(Status{}))
  76. ws.Route(ws.GET("/dashboard/tasks").To(m.getTasks).
  77. Doc("Get tasks.").
  78. Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}).
  79. Param(ws.HeaderParameter("X-API-Key", "secret key for RESTful API")).
  80. Returns(http.StatusOK, "OK", []progress.Progress{}).
  81. Writes([]progress.Progress{}))
  82. ws.Route(ws.GET("/dashboard/rates").To(m.getRates).
  83. Doc("Get positive feedback rates.").
  84. Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}).
  85. Param(ws.HeaderParameter("X-API-Key", "secret key for RESTful API")).
  86. Returns(http.StatusOK, "OK", map[string][]cache.TimeSeriesPoint{}).
  87. Writes(map[string][]cache.TimeSeriesPoint{}))
  88. // Get a user
  89. ws.Route(ws.GET("/dashboard/user/{user-id}").To(m.getUser).
  90. Doc("Get a user.").
  91. Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}).
  92. Param(ws.PathParameter("user-id", "identifier of the user").DataType("string")).
  93. Returns(http.StatusOK, "OK", User{}).
  94. Writes(User{}))
  95. // Get a user feedback
  96. ws.Route(ws.GET("/dashboard/user/{user-id}/feedback/{feedback-type}").To(m.getTypedFeedbackByUser).
  97. Doc("Get feedback by user id with feedback type.").
  98. Metadata(restfulspec.KeyOpenAPITags, []string{"feedback"}).
  99. Param(ws.HeaderParameter("X-API-Key", "secret key for RESTful API")).
  100. Param(ws.PathParameter("user-id", "identifier of the user").DataType("string")).
  101. Param(ws.PathParameter("feedback-type", "feedback type").DataType("string")).
  102. Returns(http.StatusOK, "OK", []Feedback{}).
  103. Writes([]Feedback{}))
  104. // Get users
  105. ws.Route(ws.GET("/dashboard/users").To(m.getUsers).
  106. Doc("Get users.").
  107. Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}).
  108. Param(ws.QueryParameter("n", "number of returned users").DataType("int")).
  109. Param(ws.QueryParameter("cursor", "cursor for next page").DataType("string")).
  110. Returns(http.StatusOK, "OK", UserIterator{}).
  111. Writes(UserIterator{}))
  112. // Get popular items
  113. ws.Route(ws.GET("/dashboard/popular/").To(m.getPopular).
  114. Doc("get popular items").
  115. Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}).
  116. Param(ws.QueryParameter("n", "number of returned items").DataType("int")).
  117. Param(ws.QueryParameter("offset", "offset of the list").DataType("int")).
  118. Returns(http.StatusOK, "OK", []ScoredItem{}).
  119. Writes([]ScoredItem{}))
  120. ws.Route(ws.GET("/dashboard/popular/{category}").To(m.getPopular).
  121. Doc("get popular items").
  122. Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}).
  123. Param(ws.PathParameter("category", "category of items").DataType("string")).
  124. Param(ws.QueryParameter("n", "number of returned items").DataType("int")).
  125. Param(ws.QueryParameter("offset", "offset of the list").DataType("int")).
  126. Returns(http.StatusOK, "OK", []ScoredItem{}).
  127. Writes([]ScoredItem{}))
  128. // Get latest items
  129. ws.Route(ws.GET("/dashboard/latest/").To(m.getLatest).
  130. Doc("get latest items").
  131. Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}).
  132. Param(ws.QueryParameter("n", "number of returned items").DataType("int")).
  133. Param(ws.QueryParameter("offset", "offset of the list").DataType("int")).
  134. Returns(http.StatusOK, "OK", []ScoredItem{}).
  135. Writes([]ScoredItem{}))
  136. ws.Route(ws.GET("/dashboard/latest/{category}").To(m.getLatest).
  137. Doc("get latest items").
  138. Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}).
  139. Param(ws.PathParameter("category", "category of items").DataType("string")).
  140. Param(ws.QueryParameter("n", "number of returned items").DataType("int")).
  141. Param(ws.QueryParameter("offset", "offset of the list").DataType("int")).
  142. Returns(http.StatusOK, "OK", []ScoredItem{}).
  143. Writes([]ScoredItem{}))
  144. ws.Route(ws.GET("/dashboard/recommend/{user-id}").To(m.getRecommend).
  145. Doc("Get recommendation for user.").
  146. Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}).
  147. Param(ws.PathParameter("user-id", "identifier of the user").DataType("string")).
  148. Param(ws.QueryParameter("n", "number of returned items").DataType("int")).
  149. Returns(http.StatusOK, "OK", []data.Item{}).
  150. Writes([]data.Item{}))
  151. ws.Route(ws.GET("/dashboard/recommend/{user-id}/{recommender}").To(m.getRecommend).
  152. Doc("Get recommendation for user.").
  153. Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}).
  154. Param(ws.PathParameter("user-id", "identifier of the user").DataType("string")).
  155. Param(ws.PathParameter("recommender", "one of `final`, `collaborative`, `user_based` and `item_based`").DataType("string")).
  156. Param(ws.QueryParameter("n", "number of returned items").DataType("int")).
  157. Returns(http.StatusOK, "OK", []data.Item{}).
  158. Writes([]data.Item{}))
  159. ws.Route(ws.GET("/dashboard/recommend/{user-id}/{recommender}/{category}").To(m.getRecommend).
  160. Doc("Get recommendation for user.").
  161. Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}).
  162. Param(ws.PathParameter("user-id", "identifier of the user").DataType("string")).
  163. Param(ws.PathParameter("recommender", "one of `final`, `collaborative`, `user_based` and `item_based`").DataType("string")).
  164. Param(ws.PathParameter("category", "category of items").DataType("string")).
  165. Param(ws.QueryParameter("n", "number of returned items").DataType("int")).
  166. Returns(http.StatusOK, "OK", []data.Item{}).
  167. Writes([]data.Item{}))
  168. ws.Route(ws.GET("/dashboard/item/{item-id}/neighbors").To(m.getItemNeighbors).
  169. Doc("get neighbors of a item").
  170. Metadata(restfulspec.KeyOpenAPITags, []string{"recommendation"}).
  171. Param(ws.PathParameter("item-id", "identifier of the item").DataType("string")).
  172. Param(ws.QueryParameter("n", "number of returned items").DataType("int")).
  173. Param(ws.QueryParameter("offset", "offset of the list").DataType("int")).
  174. Returns(http.StatusOK, "OK", []ScoredItem{}).
  175. Writes([]ScoredItem{}))
  176. ws.Route(ws.GET("/dashboard/item/{item-id}/neighbors/{category}").To(m.getItemCategorizedNeighbors).
  177. Doc("get neighbors of a item").
  178. Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}).
  179. Param(ws.PathParameter("item-id", "identifier of the item").DataType("string")).
  180. Param(ws.PathParameter("category", "category of items").DataType("string")).
  181. Param(ws.QueryParameter("n", "number of returned items").DataType("int")).
  182. Param(ws.QueryParameter("offset", "offset of the list").DataType("int")).
  183. Returns(http.StatusOK, "OK", []ScoredItem{}).
  184. Writes([]ScoredItem{}))
  185. ws.Route(ws.GET("/dashboard/user/{user-id}/neighbors").To(m.getUserNeighbors).
  186. Doc("get neighbors of a user").
  187. Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}).
  188. Param(ws.PathParameter("user-id", "identifier of the user").DataType("string")).
  189. Param(ws.QueryParameter("n", "number of returned users").DataType("int")).
  190. Param(ws.QueryParameter("offset", "offset of the list").DataType("int")).
  191. Returns(http.StatusOK, "OK", []ScoreUser{}).
  192. Writes([]ScoreUser{}))
  193. }
  // SinglePageAppFileSystem wraps an http.FileSystem so that a request for a
  // file that does not exist is answered with the application's index page.
  type SinglePageAppFileSystem struct {
  	root http.FileSystem
  }

  // Open returns the named file from the wrapped file system. When the wrapped
  // file system reports that the file does not exist, /index.html is opened
  // instead; any other result is passed through unchanged.
  func (spa *SinglePageAppFileSystem) Open(name string) (http.File, error) {
  	file, err := spa.root.Open(name)
  	if !os.IsNotExist(err) {
  		return file, err
  	}
  	return spa.root.Open("/index.html")
  }
  206. func (m *Master) SetOneMode(workerScheduleHandler http.HandlerFunc) {
  207. m.workerScheduleHandler = workerScheduleHandler
  208. }
  209. func (m *Master) StartHttpServer() {
  210. m.CreateWebService()
  211. container := restful.NewContainer()
  212. container.Handle("/", http.HandlerFunc(m.dashboard))
  213. container.Handle("/login", http.HandlerFunc(m.login))
  214. container.Handle("/logout", http.HandlerFunc(m.logout))
  215. container.Handle("/api/purge", http.HandlerFunc(m.purge))
  216. container.Handle("/api/bulk/users", http.HandlerFunc(m.importExportUsers))
  217. container.Handle("/api/bulk/items", http.HandlerFunc(m.importExportItems))
  218. container.Handle("/api/bulk/feedback", http.HandlerFunc(m.importExportFeedback))
  219. if m.workerScheduleHandler == nil {
  220. container.Handle("/api/admin/schedule", http.HandlerFunc(m.scheduleAPIHandler))
  221. } else {
  222. container.Handle("/api/admin/schedule/master", http.HandlerFunc(m.scheduleAPIHandler))
  223. container.Handle("/api/admin/schedule/worker", m.workerScheduleHandler)
  224. }
  225. m.RestServer.StartHttpServer(container)
  226. }
  227. var (
  228. cookieHandler = securecookie.New(
  229. securecookie.GenerateRandomKey(64),
  230. securecookie.GenerateRandomKey(32))
  231. staticFileSystem http.FileSystem
  232. staticFileServer http.Handler
  233. )
  234. func init() {
  235. var err error
  236. staticFileSystem, err = fs.New()
  237. if err != nil {
  238. log.Logger().Fatal("failed to load statik files", zap.Error(err))
  239. }
  240. staticFileServer = http.FileServer(&SinglePageAppFileSystem{staticFileSystem})
  241. // Create temporary directory if not exist
  242. tempDir := os.TempDir()
  243. if err = os.MkdirAll(tempDir, 1777); err != nil {
  244. log.Logger().Fatal("failed to create temporary directory", zap.String("directory", tempDir), zap.Error(err))
  245. }
  246. }
  // Taken from https://github.com/mytrile/nocache

  // noCacheHeaders are the response headers noCache sets to stop clients and
  // proxies from caching a response. Expires is the Unix epoch rendered in UTC
  // with http.TimeFormat, so its value does not depend on the server's local
  // time zone.
  var noCacheHeaders = map[string]string{
  	"Expires":         time.Unix(0, 0).UTC().Format(http.TimeFormat),
  	"Cache-Control":   "no-cache, no-store, no-transform, must-revalidate, private, max-age=0",
  	"Pragma":          "no-cache",
  	"X-Accel-Expires": "0",
  }

  // etagHeaders are the headers noCache strips from the incoming request
  // before the wrapped handler sees it.
  var etagHeaders = []string{
  	"ETag",
  	"If-Modified-Since",
  	"If-Match",
  	"If-None-Match",
  	"If-Range",
  	"If-Unmodified-Since",
  }

  // noCache is a simple piece of middleware that sets a number of HTTP headers to prevent
  // a router (or subrouter) from being cached by an upstream proxy and/or client.
  //
  // It removes the headers listed in etagHeaders from the request and sets on
  // the response:
  //
  //	Expires: Thu, 01 Jan 1970 00:00:00 GMT
  //	Cache-Control: no-cache, no-store, no-transform, must-revalidate, private, max-age=0
  //	X-Accel-Expires: 0
  //	Pragma: no-cache (for HTTP/1.0 proxies/clients)
  func noCache(h http.Handler) http.Handler {
  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  		// Delete any ETag headers that may have been set
  		for _, name := range etagHeaders {
  			r.Header.Del(name)
  		}
  		// Set our noCache headers
  		for name, value := range noCacheHeaders {
  			w.Header().Set(name, value)
  		}
  		h.ServeHTTP(w, r)
  	})
  }
  287. func (m *Master) dashboard(response http.ResponseWriter, request *http.Request) {
  288. _, err := staticFileSystem.Open(request.RequestURI)
  289. if request.RequestURI == "/" || os.IsNotExist(err) {
  290. if !m.checkLogin(request) {
  291. http.Redirect(response, request, "/login", http.StatusFound)
  292. log.Logger().Info(fmt.Sprintf("%s %s", request.Method, request.URL), zap.Int("status_code", http.StatusFound))
  293. return
  294. }
  295. noCache(staticFileServer).ServeHTTP(response, request)
  296. return
  297. }
  298. staticFileServer.ServeHTTP(response, request)
  299. }
  300. func (m *Master) checkToken(token string) (bool, error) {
  301. resp, err := http.Get(fmt.Sprintf("%s/auth/dashboard/%s", m.Config.Master.DashboardAuthServer, token))
  302. if err != nil {
  303. return false, errors.Trace(err)
  304. }
  305. if resp.StatusCode == http.StatusOK {
  306. return true, nil
  307. } else if resp.StatusCode == http.StatusUnauthorized {
  308. return false, nil
  309. } else {
  310. if message, err := io.ReadAll(resp.Body); err != nil {
  311. return false, errors.Trace(err)
  312. } else {
  313. return false, errors.New(string(message))
  314. }
  315. }
  316. }
  317. func (m *Master) login(response http.ResponseWriter, request *http.Request) {
  318. switch request.Method {
  319. case http.MethodGet:
  320. log.Logger().Info("GET /login", zap.Int("status_code", http.StatusOK))
  321. staticFileServer.ServeHTTP(response, request)
  322. case http.MethodPost:
  323. token := request.FormValue("token")
  324. name := request.FormValue("user_name")
  325. pass := request.FormValue("password")
  326. if m.Config.Master.DashboardAuthServer != "" {
  327. // check access token
  328. if isValid, err := m.checkToken(token); err != nil {
  329. server.InternalServerError(restful.NewResponse(response), err)
  330. return
  331. } else if !isValid {
  332. http.Redirect(response, request, "login?msg=incorrect", http.StatusFound)
  333. log.Logger().Info("POST /login", zap.Int("status_code", http.StatusUnauthorized))
  334. return
  335. }
  336. // save token to cache
  337. if encoded, err := cookieHandler.Encode("token", token); err != nil {
  338. server.InternalServerError(restful.NewResponse(response), err)
  339. return
  340. } else {
  341. cookie := &http.Cookie{
  342. Name: "token",
  343. Value: encoded,
  344. Path: "/",
  345. }
  346. http.SetCookie(response, cookie)
  347. http.Redirect(response, request, "/", http.StatusFound)
  348. log.Logger().Info("POST /login", zap.Int("status_code", http.StatusUnauthorized))
  349. return
  350. }
  351. } else if m.Config.Master.DashboardUserName != "" || m.Config.Master.DashboardPassword != "" {
  352. if name != m.Config.Master.DashboardUserName || pass != m.Config.Master.DashboardPassword {
  353. http.Redirect(response, request, "login?msg=incorrect", http.StatusFound)
  354. log.Logger().Info("POST /login", zap.Int("status_code", http.StatusUnauthorized))
  355. return
  356. }
  357. value := map[string]string{
  358. "user_name": name,
  359. "password": pass,
  360. }
  361. if encoded, err := cookieHandler.Encode("session", value); err != nil {
  362. server.InternalServerError(restful.NewResponse(response), err)
  363. return
  364. } else {
  365. cookie := &http.Cookie{
  366. Name: "session",
  367. Value: encoded,
  368. Path: "/",
  369. }
  370. http.SetCookie(response, cookie)
  371. http.Redirect(response, request, "/", http.StatusFound)
  372. log.Logger().Info("POST /login", zap.Int("status_code", http.StatusFound))
  373. return
  374. }
  375. } else {
  376. http.Redirect(response, request, "/", http.StatusFound)
  377. log.Logger().Info("POST /login", zap.Int("status_code", http.StatusFound))
  378. }
  379. default:
  380. server.BadRequest(restful.NewResponse(response), errors.New("unsupported method"))
  381. }
  382. }
  383. func (m *Master) logout(response http.ResponseWriter, request *http.Request) {
  384. cookie := &http.Cookie{
  385. Name: "session",
  386. Value: "",
  387. Path: "/",
  388. MaxAge: -1,
  389. }
  390. http.SetCookie(response, cookie)
  391. http.Redirect(response, request, "/login", http.StatusFound)
  392. log.Logger().Info(fmt.Sprintf("%s %s", request.Method, request.RequestURI), zap.Int("status_code", http.StatusFound))
  393. }
  394. func (m *Master) LoginFilter(req *restful.Request, resp *restful.Response, chain *restful.FilterChain) {
  395. if m.checkLogin(req.Request) {
  396. req.Request.Header.Set("X-API-Key", m.Config.Server.APIKey)
  397. chain.ProcessFilter(req, resp)
  398. } else if !strings.HasPrefix(req.SelectedRoutePath(), "/api/dashboard") {
  399. chain.ProcessFilter(req, resp)
  400. } else {
  401. if err := resp.WriteError(http.StatusUnauthorized, fmt.Errorf("unauthorized")); err != nil {
  402. log.ResponseLogger(resp).Error("failed to write error", zap.Error(err))
  403. }
  404. }
  405. }
  406. func (m *Master) checkLogin(request *http.Request) bool {
  407. if m.Config.Master.AdminAPIKey != "" && m.Config.Master.AdminAPIKey == request.Header.Get("X-Api-Key") {
  408. return true
  409. }
  410. if m.Config.Master.DashboardAuthServer != "" {
  411. if tokenCookie, err := request.Cookie("token"); err == nil {
  412. var token string
  413. if err = cookieHandler.Decode("token", tokenCookie.Value, &token); err == nil {
  414. if isValid, err := m.checkToken(token); err != nil {
  415. log.Logger().Error("failed to check access token", zap.Error(err))
  416. } else if isValid {
  417. return true
  418. }
  419. }
  420. }
  421. return false
  422. } else if m.Config.Master.DashboardUserName != "" || m.Config.Master.DashboardPassword != "" {
  423. if sessionCookie, err := request.Cookie("session"); err == nil {
  424. cookieValue := make(map[string]string)
  425. if err = cookieHandler.Decode("session", sessionCookie.Value, &cookieValue); err == nil {
  426. userName := cookieValue["user_name"]
  427. password := cookieValue["password"]
  428. if userName == m.Config.Master.DashboardUserName && password == m.Config.Master.DashboardPassword {
  429. return true
  430. }
  431. }
  432. }
  433. return false
  434. }
  435. return true
  436. }
  437. func (m *Master) getCategories(request *restful.Request, response *restful.Response) {
  438. ctx := context.Background()
  439. if request != nil && request.Request != nil {
  440. ctx = request.Request.Context()
  441. }
  442. categories, err := m.CacheClient.GetSet(ctx, cache.ItemCategories)
  443. if err != nil {
  444. server.InternalServerError(response, err)
  445. return
  446. }
  447. server.Ok(response, categories)
  448. }
  449. func (m *Master) getCluster(_ *restful.Request, response *restful.Response) {
  450. // collect nodes
  451. workers := make([]*Node, 0)
  452. servers := make([]*Node, 0)
  453. m.nodesInfoMutex.RLock()
  454. for _, info := range m.nodesInfo {
  455. switch info.Type {
  456. case WorkerNode:
  457. workers = append(workers, info)
  458. case ServerNode:
  459. servers = append(servers, info)
  460. }
  461. }
  462. m.nodesInfoMutex.RUnlock()
  463. // return nodes
  464. nodes := make([]*Node, 0)
  465. nodes = append(nodes, workers...)
  466. nodes = append(nodes, servers...)
  467. server.Ok(response, nodes)
  468. }
  469. func formatConfig(configMap map[string]interface{}) map[string]interface{} {
  470. return lo.MapValues(configMap, func(v interface{}, _ string) interface{} {
  471. switch value := v.(type) {
  472. case time.Duration:
  473. s := value.String()
  474. if strings.HasSuffix(s, "m0s") {
  475. s = s[:len(s)-2]
  476. }
  477. if strings.HasSuffix(s, "h0m") {
  478. s = s[:len(s)-2]
  479. }
  480. return s
  481. case map[string]interface{}:
  482. return formatConfig(value)
  483. default:
  484. return v
  485. }
  486. })
  487. }
  488. func (m *Master) getConfig(_ *restful.Request, response *restful.Response) {
  489. var configMap map[string]interface{}
  490. err := mapstructure.Decode(m.Config, &configMap)
  491. if err != nil {
  492. server.InternalServerError(response, err)
  493. return
  494. }
  495. if m.Config.Master.DashboardRedacted {
  496. delete(configMap, "database")
  497. }
  498. server.Ok(response, formatConfig(configMap))
  499. }
  500. type Status struct {
  501. BinaryVersion string
  502. NumServers int
  503. NumWorkers int
  504. NumUsers int
  505. NumItems int
  506. NumUserLabels int
  507. NumItemLabels int
  508. NumTotalPosFeedback int
  509. NumValidPosFeedback int
  510. NumValidNegFeedback int
  511. PopularItemsUpdateTime time.Time
  512. LatestItemsUpdateTime time.Time
  513. MatchingModelFitTime time.Time
  514. MatchingModelScore ranking.Score
  515. RankingModelFitTime time.Time
  516. RankingModelScore click.Score
  517. UserNeighborIndexRecall float32
  518. ItemNeighborIndexRecall float32
  519. MatchingIndexRecall float32
  520. }
  521. func (m *Master) getStats(request *restful.Request, response *restful.Response) {
  522. ctx := context.Background()
  523. if request != nil && request.Request != nil {
  524. ctx = request.Request.Context()
  525. }
  526. status := Status{BinaryVersion: version.Version}
  527. var err error
  528. // read number of users
  529. if status.NumUsers, err = m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.NumUsers)).Integer(); err != nil {
  530. log.ResponseLogger(response).Warn("failed to get number of users", zap.Error(err))
  531. }
  532. // read number of items
  533. if status.NumItems, err = m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.NumItems)).Integer(); err != nil {
  534. log.ResponseLogger(response).Warn("failed to get number of items", zap.Error(err))
  535. }
  536. // read number of user labels
  537. if status.NumUserLabels, err = m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.NumUserLabels)).Integer(); err != nil {
  538. log.ResponseLogger(response).Warn("failed to get number of user labels", zap.Error(err))
  539. }
  540. // read number of item labels
  541. if status.NumItemLabels, err = m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.NumItemLabels)).Integer(); err != nil {
  542. log.ResponseLogger(response).Warn("failed to get number of item labels", zap.Error(err))
  543. }
  544. // read number of total positive feedback
  545. if status.NumTotalPosFeedback, err = m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.NumTotalPosFeedbacks)).Integer(); err != nil {
  546. log.ResponseLogger(response).Warn("failed to get number of total positive feedbacks", zap.Error(err))
  547. }
  548. // read number of valid positive feedback
  549. if status.NumValidPosFeedback, err = m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.NumValidPosFeedbacks)).Integer(); err != nil {
  550. log.ResponseLogger(response).Warn("failed to get number of valid positive feedbacks", zap.Error(err))
  551. }
  552. // read number of valid negative feedback
  553. if status.NumValidNegFeedback, err = m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.NumValidNegFeedbacks)).Integer(); err != nil {
  554. log.ResponseLogger(response).Warn("failed to get number of valid negative feedbacks", zap.Error(err))
  555. }
  556. // count the number of workers and servers
  557. m.nodesInfoMutex.Lock()
  558. for _, node := range m.nodesInfo {
  559. switch node.Type {
  560. case ServerNode:
  561. status.NumServers++
  562. case WorkerNode:
  563. status.NumWorkers++
  564. }
  565. }
  566. m.nodesInfoMutex.Unlock()
  567. // read popular items update time
  568. if status.PopularItemsUpdateTime, err = m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.LastUpdatePopularItemsTime)).Time(); err != nil {
  569. log.ResponseLogger(response).Warn("failed to get popular items update time", zap.Error(err))
  570. }
  571. // read the latest items update time
  572. if status.LatestItemsUpdateTime, err = m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.LastUpdateLatestItemsTime)).Time(); err != nil {
  573. log.ResponseLogger(response).Warn("failed to get latest items update time", zap.Error(err))
  574. }
  575. status.MatchingModelScore = m.rankingScore
  576. status.RankingModelScore = m.clickScore
  577. // read last fit matching model time
  578. if status.MatchingModelFitTime, err = m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.LastFitMatchingModelTime)).Time(); err != nil {
  579. log.ResponseLogger(response).Warn("failed to get last fit matching model time", zap.Error(err))
  580. }
  581. // read last fit ranking model time
  582. if status.RankingModelFitTime, err = m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.LastFitRankingModelTime)).Time(); err != nil {
  583. log.ResponseLogger(response).Warn("failed to get last fit ranking model time", zap.Error(err))
  584. }
  585. // read user neighbor index recall
  586. var temp string
  587. if m.Config.Recommend.UserNeighbors.EnableIndex {
  588. if temp, err = m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.UserNeighborIndexRecall)).String(); err != nil {
  589. log.ResponseLogger(response).Warn("failed to get user neighbor index recall", zap.Error(err))
  590. } else {
  591. status.UserNeighborIndexRecall = encoding.ParseFloat32(temp)
  592. }
  593. }
  594. // read item neighbor index recall
  595. if m.Config.Recommend.ItemNeighbors.EnableIndex {
  596. if temp, err = m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.ItemNeighborIndexRecall)).String(); err != nil {
  597. log.ResponseLogger(response).Warn("failed to get item neighbor index recall", zap.Error(err))
  598. } else {
  599. status.ItemNeighborIndexRecall = encoding.ParseFloat32(temp)
  600. }
  601. }
  602. // read matching index recall
  603. if m.Config.Recommend.Collaborative.EnableIndex {
  604. if temp, err = m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.MatchingIndexRecall)).String(); err != nil {
  605. log.ResponseLogger(response).Warn("failed to get matching index recall", zap.Error(err))
  606. } else {
  607. status.MatchingIndexRecall = encoding.ParseFloat32(temp)
  608. }
  609. }
  610. server.Ok(response, status)
  611. }
  612. func (m *Master) getTasks(_ *restful.Request, response *restful.Response) {
  613. // List workers
  614. workers := mapset.NewSet[string]()
  615. m.nodesInfoMutex.RLock()
  616. for _, info := range m.nodesInfo {
  617. if info.Type == WorkerNode {
  618. workers.Add(info.Name)
  619. }
  620. }
  621. m.nodesInfoMutex.RUnlock()
  622. // List local progress
  623. progressList := m.tracer.List()
  624. // list remote progress
  625. m.remoteProgress.Range(func(key, value interface{}) bool {
  626. if workers.Contains(key.(string)) {
  627. progressList = append(progressList, value.([]progress.Progress)...)
  628. }
  629. return true
  630. })
  631. server.Ok(response, progressList)
  632. }
  633. func (m *Master) getRates(request *restful.Request, response *restful.Response) {
  634. ctx := context.Background()
  635. if request != nil && request.Request != nil {
  636. ctx = request.Request.Context()
  637. }
  638. // Parse parameters
  639. n, err := server.ParseInt(request, "n", 100)
  640. if err != nil {
  641. server.BadRequest(response, err)
  642. return
  643. }
  644. measurements := make(map[string][]cache.TimeSeriesPoint, len(m.Config.Recommend.DataSource.PositiveFeedbackTypes))
  645. for _, feedbackType := range m.Config.Recommend.DataSource.PositiveFeedbackTypes {
  646. measurements[feedbackType], err = m.CacheClient.GetTimeSeriesPoints(ctx, cache.Key(PositiveFeedbackRate, feedbackType),
  647. time.Now().Add(-24*time.Hour*time.Duration(n)), time.Now())
  648. if err != nil {
  649. server.InternalServerError(response, err)
  650. return
  651. }
  652. }
  653. server.Ok(response, measurements)
  654. }
  655. type UserIterator struct {
  656. Cursor string
  657. Users []User
  658. }
  659. type User struct {
  660. data.User
  661. LastActiveTime time.Time
  662. LastUpdateTime time.Time
  663. }
  664. func (m *Master) getUser(request *restful.Request, response *restful.Response) {
  665. ctx := context.Background()
  666. if request != nil && request.Request != nil {
  667. ctx = request.Request.Context()
  668. }
  669. // get user id
  670. userId := request.PathParameter("user-id")
  671. // get user
  672. user, err := m.DataClient.GetUser(ctx, userId)
  673. if err != nil {
  674. if errors.Is(err, errors.NotFound) {
  675. server.PageNotFound(response, err)
  676. } else {
  677. server.InternalServerError(response, err)
  678. }
  679. return
  680. }
  681. detail := User{User: user}
  682. if detail.LastActiveTime, err = m.CacheClient.Get(ctx, cache.Key(cache.LastModifyUserTime, user.UserId)).Time(); err != nil && !errors.Is(err, errors.NotFound) {
  683. server.InternalServerError(response, err)
  684. return
  685. }
  686. if detail.LastUpdateTime, err = m.CacheClient.Get(ctx, cache.Key(cache.LastUpdateUserRecommendTime, user.UserId)).Time(); err != nil && !errors.Is(err, errors.NotFound) {
  687. server.InternalServerError(response, err)
  688. return
  689. }
  690. server.Ok(response, detail)
  691. }
  692. func (m *Master) getUsers(request *restful.Request, response *restful.Response) {
  693. ctx := context.Background()
  694. if request != nil && request.Request != nil {
  695. ctx = request.Request.Context()
  696. }
  697. // Authorize
  698. cursor := request.QueryParameter("cursor")
  699. n, err := server.ParseInt(request, "n", m.Config.Server.DefaultN)
  700. if err != nil {
  701. server.BadRequest(response, err)
  702. return
  703. }
  704. // get all users
  705. cursor, users, err := m.DataClient.GetUsers(ctx, cursor, n)
  706. if err != nil {
  707. server.InternalServerError(response, err)
  708. return
  709. }
  710. details := make([]User, len(users))
  711. for i, user := range users {
  712. details[i].User = user
  713. if details[i].LastActiveTime, err = m.CacheClient.Get(ctx, cache.Key(cache.LastModifyUserTime, user.UserId)).Time(); err != nil && !errors.Is(err, errors.NotFound) {
  714. server.InternalServerError(response, err)
  715. return
  716. }
  717. if details[i].LastUpdateTime, err = m.CacheClient.Get(ctx, cache.Key(cache.LastUpdateUserRecommendTime, user.UserId)).Time(); err != nil && !errors.Is(err, errors.NotFound) {
  718. server.InternalServerError(response, err)
  719. return
  720. }
  721. }
  722. server.Ok(response, UserIterator{Cursor: cursor, Users: details})
  723. }
  724. func (m *Master) getRecommend(request *restful.Request, response *restful.Response) {
  725. ctx := context.Background()
  726. if request != nil && request.Request != nil {
  727. ctx = request.Request.Context()
  728. }
  729. // parse arguments
  730. recommender := request.PathParameter("recommender")
  731. userId := request.PathParameter("user-id")
  732. categories := []string{request.PathParameter("category")}
  733. n, err := server.ParseInt(request, "n", m.Config.Server.DefaultN)
  734. if err != nil {
  735. server.BadRequest(response, err)
  736. return
  737. }
  738. var results []string
  739. switch recommender {
  740. case "offline":
  741. results, err = m.Recommend(ctx, response, userId, categories, n, m.RecommendOffline)
  742. case "collaborative":
  743. results, err = m.Recommend(ctx, response, userId, categories, n, m.RecommendCollaborative)
  744. case "user_based":
  745. results, err = m.Recommend(ctx, response, userId, categories, n, m.RecommendUserBased)
  746. case "item_based":
  747. results, err = m.Recommend(ctx, response, userId, categories, n, m.RecommendItemBased)
  748. case "_":
  749. recommenders := []server.Recommender{m.RecommendOffline}
  750. for _, recommender := range m.Config.Recommend.Online.FallbackRecommend {
  751. switch recommender {
  752. case "collaborative":
  753. recommenders = append(recommenders, m.RecommendCollaborative)
  754. case "item_based":
  755. recommenders = append(recommenders, m.RecommendItemBased)
  756. case "user_based":
  757. recommenders = append(recommenders, m.RecommendUserBased)
  758. case "latest":
  759. recommenders = append(recommenders, m.RecommendLatest)
  760. case "popular":
  761. recommenders = append(recommenders, m.RecommendPopular)
  762. default:
  763. server.InternalServerError(response, fmt.Errorf("unknown fallback recommendation method `%s`", recommender))
  764. return
  765. }
  766. }
  767. results, err = m.Recommend(ctx, response, userId, categories, n, recommenders...)
  768. }
  769. if err != nil {
  770. server.InternalServerError(response, err)
  771. return
  772. }
  773. // Send result
  774. details := make([]data.Item, len(results))
  775. for i := range results {
  776. details[i], err = m.DataClient.GetItem(ctx, results[i])
  777. if err != nil {
  778. server.InternalServerError(response, err)
  779. return
  780. }
  781. }
  782. server.Ok(response, details)
  783. }
  784. type Feedback struct {
  785. FeedbackType string
  786. UserId string
  787. Item data.Item
  788. Timestamp time.Time
  789. Comment string
  790. }
  791. // get feedback by user-id with feedback type
  792. func (m *Master) getTypedFeedbackByUser(request *restful.Request, response *restful.Response) {
  793. ctx := context.Background()
  794. if request != nil && request.Request != nil {
  795. ctx = request.Request.Context()
  796. }
  797. feedbackType := request.PathParameter("feedback-type")
  798. userId := request.PathParameter("user-id")
  799. feedback, err := m.DataClient.GetUserFeedback(ctx, userId, m.Config.Now(), feedbackType)
  800. if err != nil {
  801. server.InternalServerError(response, err)
  802. return
  803. }
  804. details := make([]Feedback, len(feedback))
  805. for i := range feedback {
  806. details[i].FeedbackType = feedback[i].FeedbackType
  807. details[i].UserId = feedback[i].UserId
  808. details[i].Timestamp = feedback[i].Timestamp
  809. details[i].Comment = feedback[i].Comment
  810. details[i].Item, err = m.DataClient.GetItem(ctx, feedback[i].ItemId)
  811. if errors.Is(err, errors.NotFound) {
  812. details[i].Item = data.Item{ItemId: feedback[i].ItemId, Comment: "** This item doesn't exist in Gorse **"}
  813. } else if err != nil {
  814. server.InternalServerError(response, err)
  815. return
  816. }
  817. }
  818. server.Ok(response, details)
  819. }
  820. type ScoredItem struct {
  821. data.Item
  822. Score float64
  823. }
  824. type ScoreUser struct {
  825. data.User
  826. Score float64
  827. }
  828. func (m *Master) searchDocuments(collection, subset, category string, request *restful.Request, response *restful.Response, retType interface{}) {
  829. ctx := context.Background()
  830. if request != nil && request.Request != nil {
  831. ctx = request.Request.Context()
  832. }
  833. var n, offset int
  834. var err error
  835. // read arguments
  836. if offset, err = server.ParseInt(request, "offset", 0); err != nil {
  837. server.BadRequest(response, err)
  838. return
  839. }
  840. if n, err = server.ParseInt(request, "n", m.Config.Server.DefaultN); err != nil {
  841. server.BadRequest(response, err)
  842. return
  843. }
  844. // Get the popular list
  845. scores, err := m.CacheClient.SearchDocuments(ctx, collection, subset, []string{category}, offset, m.Config.Recommend.CacheSize)
  846. if err != nil {
  847. server.InternalServerError(response, err)
  848. return
  849. }
  850. if n > 0 && len(scores) > n {
  851. scores = scores[:n]
  852. }
  853. // Send result
  854. switch retType.(type) {
  855. case data.Item:
  856. details := make([]ScoredItem, len(scores))
  857. for i := range scores {
  858. details[i].Score = scores[i].Score
  859. details[i].Item, err = m.DataClient.GetItem(ctx, scores[i].Id)
  860. if err != nil {
  861. server.InternalServerError(response, err)
  862. return
  863. }
  864. }
  865. server.Ok(response, details)
  866. case data.User:
  867. details := make([]ScoreUser, len(scores))
  868. for i := range scores {
  869. details[i].Score = scores[i].Score
  870. details[i].User, err = m.DataClient.GetUser(ctx, scores[i].Id)
  871. if err != nil {
  872. server.InternalServerError(response, err)
  873. return
  874. }
  875. }
  876. server.Ok(response, details)
  877. default:
  878. log.ResponseLogger(response).Fatal("unknown return type", zap.Any("ret_type", reflect.TypeOf(retType)))
  879. }
  880. }
  881. // getPopular gets popular items from database.
  882. func (m *Master) getPopular(request *restful.Request, response *restful.Response) {
  883. category := request.PathParameter("category")
  884. m.searchDocuments(cache.PopularItems, "", category, request, response, data.Item{})
  885. }
  886. func (m *Master) getLatest(request *restful.Request, response *restful.Response) {
  887. category := request.PathParameter("category")
  888. m.searchDocuments(cache.LatestItems, "", category, request, response, data.Item{})
  889. }
  890. func (m *Master) getItemNeighbors(request *restful.Request, response *restful.Response) {
  891. itemId := request.PathParameter("item-id")
  892. m.searchDocuments(cache.ItemNeighbors, itemId, "", request, response, data.Item{})
  893. }
  894. func (m *Master) getItemCategorizedNeighbors(request *restful.Request, response *restful.Response) {
  895. itemId := request.PathParameter("item-id")
  896. category := request.PathParameter("category")
  897. m.searchDocuments(cache.ItemNeighbors, itemId, category, request, response, data.Item{})
  898. }
  899. func (m *Master) getUserNeighbors(request *restful.Request, response *restful.Response) {
  900. userId := request.PathParameter("user-id")
  901. m.searchDocuments(cache.UserNeighbors, userId, "", request, response, data.User{})
  902. }
  903. func (m *Master) importExportUsers(response http.ResponseWriter, request *http.Request) {
  904. ctx := context.Background()
  905. if request != nil {
  906. ctx = request.Context()
  907. }
  908. if !m.checkLogin(request) {
  909. resp := restful.NewResponse(response)
  910. err := resp.WriteErrorString(http.StatusUnauthorized, "unauthorized")
  911. if err != nil {
  912. server.InternalServerError(resp, err)
  913. return
  914. }
  915. return
  916. }
  917. switch request.Method {
  918. case http.MethodGet:
  919. var err error
  920. response.Header().Set("Content-Type", "text/csv")
  921. response.Header().Set("Content-Disposition", "attachment;filename=users.csv")
  922. // write header
  923. if _, err = response.Write([]byte("user_id,labels\r\n")); err != nil {
  924. server.InternalServerError(restful.NewResponse(response), err)
  925. return
  926. }
  927. // write rows
  928. userChan, errChan := m.DataClient.GetUserStream(ctx, batchSize)
  929. for users := range userChan {
  930. for _, user := range users {
  931. labels, err := json.Marshal(user.Labels)
  932. if err != nil {
  933. server.InternalServerError(restful.NewResponse(response), err)
  934. return
  935. }
  936. if _, err = response.Write([]byte(fmt.Sprintf("%s,%s\r\n",
  937. base.Escape(user.UserId), base.Escape(string(labels))))); err != nil {
  938. server.InternalServerError(restful.NewResponse(response), err)
  939. return
  940. }
  941. }
  942. }
  943. if err = <-errChan; err != nil {
  944. server.InternalServerError(restful.NewResponse(response), errors.Trace(err))
  945. return
  946. }
  947. case http.MethodPost:
  948. hasHeader := formValue(request, "has-header", "true") == "true"
  949. sep := formValue(request, "sep", ",")
  950. // field separator must be a single character
  951. if len(sep) != 1 {
  952. server.BadRequest(restful.NewResponse(response), fmt.Errorf("field separator must be a single character"))
  953. return
  954. }
  955. labelSep := formValue(request, "label-sep", "|")
  956. fmtString := formValue(request, "format", "ul")
  957. file, _, err := request.FormFile("file")
  958. if err != nil {
  959. server.BadRequest(restful.NewResponse(response), err)
  960. return
  961. }
  962. defer file.Close()
  963. m.importUsers(ctx, response, file, hasHeader, sep, labelSep, fmtString)
  964. }
  965. }
  966. func (m *Master) importUsers(ctx context.Context, response http.ResponseWriter, file io.Reader, hasHeader bool, sep, labelSep, fmtString string) {
  967. lineCount := 0
  968. timeStart := time.Now()
  969. users := make([]data.User, 0)
  970. err := base.ReadLines(bufio.NewScanner(file), sep, func(lineNumber int, splits []string) bool {
  971. var err error
  972. // skip header
  973. if hasHeader {
  974. hasHeader = false
  975. return true
  976. }
  977. splits, err = format(fmtString, "ul", splits, lineNumber)
  978. if err != nil {
  979. server.BadRequest(restful.NewResponse(response), err)
  980. return false
  981. }
  982. // 1. user id
  983. if err = base.ValidateId(splits[0]); err != nil {
  984. server.BadRequest(restful.NewResponse(response),
  985. fmt.Errorf("invalid user id `%v` at line %d (%s)", splits[0], lineNumber, err.Error()))
  986. return false
  987. }
  988. user := data.User{UserId: splits[0]}
  989. // 2. labels
  990. if splits[1] != "" {
  991. var labels any
  992. if err = json.Unmarshal([]byte(splits[1]), &labels); err != nil {
  993. server.BadRequest(restful.NewResponse(response),
  994. fmt.Errorf("invalid labels `%v` at line %d (%s)", splits[1], lineNumber, err.Error()))
  995. return false
  996. }
  997. user.Labels = labels
  998. }
  999. users = append(users, user)
  1000. // batch insert
  1001. if len(users) == batchSize {
  1002. err = m.DataClient.BatchInsertUsers(ctx, users)
  1003. if err != nil {
  1004. server.InternalServerError(restful.NewResponse(response), err)
  1005. return false
  1006. }
  1007. users = nil
  1008. }
  1009. lineCount++
  1010. return true
  1011. })
  1012. if err != nil {
  1013. server.BadRequest(restful.NewResponse(response), err)
  1014. return
  1015. }
  1016. if len(users) > 0 {
  1017. err = m.DataClient.BatchInsertUsers(ctx, users)
  1018. if err != nil {
  1019. server.InternalServerError(restful.NewResponse(response), err)
  1020. return
  1021. }
  1022. }
  1023. m.notifyDataImported()
  1024. timeUsed := time.Since(timeStart)
  1025. log.Logger().Info("complete import users",
  1026. zap.Duration("time_used", timeUsed),
  1027. zap.Int("num_users", lineCount))
  1028. server.Ok(restful.NewResponse(response), server.Success{RowAffected: lineCount})
  1029. }
  1030. func (m *Master) importExportItems(response http.ResponseWriter, request *http.Request) {
  1031. ctx := context.Background()
  1032. if request != nil {
  1033. ctx = request.Context()
  1034. }
  1035. if !m.checkLogin(request) {
  1036. resp := restful.NewResponse(response)
  1037. err := resp.WriteErrorString(http.StatusUnauthorized, "unauthorized")
  1038. if err != nil {
  1039. server.InternalServerError(resp, err)
  1040. return
  1041. }
  1042. return
  1043. }
  1044. switch request.Method {
  1045. case http.MethodGet:
  1046. var err error
  1047. response.Header().Set("Content-Type", "text/csv")
  1048. response.Header().Set("Content-Disposition", "attachment;filename=items.csv")
  1049. // write header
  1050. if _, err = response.Write([]byte("item_id,is_hidden,categories,time_stamp,labels,description\r\n")); err != nil {
  1051. server.InternalServerError(restful.NewResponse(response), err)
  1052. return
  1053. }
  1054. // write rows
  1055. itemChan, errChan := m.DataClient.GetItemStream(ctx, batchSize, nil)
  1056. for items := range itemChan {
  1057. for _, item := range items {
  1058. labels, err := json.Marshal(item.Labels)
  1059. if err != nil {
  1060. server.InternalServerError(restful.NewResponse(response), err)
  1061. return
  1062. }
  1063. if _, err = response.Write([]byte(fmt.Sprintf("%s,%t,%s,%v,%s,%s\r\n",
  1064. base.Escape(item.ItemId), item.IsHidden, base.Escape(strings.Join(item.Categories, "|")),
  1065. item.Timestamp, base.Escape(string(labels)), base.Escape(item.Comment)))); err != nil {
  1066. server.InternalServerError(restful.NewResponse(response), err)
  1067. return
  1068. }
  1069. }
  1070. }
  1071. if err = <-errChan; err != nil {
  1072. server.InternalServerError(restful.NewResponse(response), errors.Trace(err))
  1073. return
  1074. }
  1075. case http.MethodPost:
  1076. hasHeader := formValue(request, "has-header", "true") == "true"
  1077. sep := formValue(request, "sep", ",")
  1078. // field separator must be a single character
  1079. if len(sep) != 1 {
  1080. server.BadRequest(restful.NewResponse(response), fmt.Errorf("field separator must be a single character"))
  1081. return
  1082. }
  1083. labelSep := formValue(request, "label-sep", "|")
  1084. fmtString := formValue(request, "format", "ihctld")
  1085. file, _, err := request.FormFile("file")
  1086. if err != nil {
  1087. server.BadRequest(restful.NewResponse(response), err)
  1088. return
  1089. }
  1090. defer file.Close()
  1091. m.importItems(ctx, response, file, hasHeader, sep, labelSep, fmtString)
  1092. default:
  1093. writeError(response, http.StatusMethodNotAllowed, "method not allowed")
  1094. }
  1095. }
  1096. func (m *Master) importItems(ctx context.Context, response http.ResponseWriter, file io.Reader, hasHeader bool, sep, labelSep, fmtString string) {
  1097. lineCount := 0
  1098. timeStart := time.Now()
  1099. items := make([]data.Item, 0)
  1100. err := base.ReadLines(bufio.NewScanner(file), sep, func(lineNumber int, splits []string) bool {
  1101. var err error
  1102. // skip header
  1103. if hasHeader {
  1104. hasHeader = false
  1105. return true
  1106. }
  1107. splits, err = format(fmtString, "ihctld", splits, lineNumber)
  1108. if err != nil {
  1109. server.BadRequest(restful.NewResponse(response), err)
  1110. return false
  1111. }
  1112. // 1. item id
  1113. if err = base.ValidateId(splits[0]); err != nil {
  1114. server.BadRequest(restful.NewResponse(response),
  1115. fmt.Errorf("invalid item id `%v` at line %d (%s)", splits[0], lineNumber, err.Error()))
  1116. return false
  1117. }
  1118. item := data.Item{ItemId: splits[0]}
  1119. // 2. hidden
  1120. if splits[1] != "" {
  1121. item.IsHidden, err = strconv.ParseBool(splits[1])
  1122. if err != nil {
  1123. server.BadRequest(restful.NewResponse(response),
  1124. fmt.Errorf("invalid hidden value `%v` at line %d (%s)", splits[1], lineNumber, err.Error()))
  1125. return false
  1126. }
  1127. }
  1128. // 3. categories
  1129. if splits[2] != "" {
  1130. item.Categories = strings.Split(splits[2], labelSep)
  1131. for _, category := range item.Categories {
  1132. if err = base.ValidateId(category); err != nil {
  1133. server.BadRequest(restful.NewResponse(response),
  1134. fmt.Errorf("invalid category `%v` at line %d (%s)", category, lineNumber, err.Error()))
  1135. return false
  1136. }
  1137. }
  1138. }
  1139. // 4. timestamp
  1140. if splits[3] != "" {
  1141. item.Timestamp, err = dateparse.ParseAny(splits[3])
  1142. if err != nil {
  1143. server.BadRequest(restful.NewResponse(response),
  1144. fmt.Errorf("failed to parse datetime `%v` at line %v", splits[1], lineNumber))
  1145. return false
  1146. }
  1147. }
  1148. // 5. labels
  1149. if splits[4] != "" {
  1150. var labels any
  1151. if err = json.Unmarshal([]byte(splits[4]), &labels); err != nil {
  1152. server.BadRequest(restful.NewResponse(response),
  1153. fmt.Errorf("failed to parse labels `%v` at line %v", splits[4], lineNumber))
  1154. return false
  1155. }
  1156. item.Labels = labels
  1157. }
  1158. // 6. comment
  1159. item.Comment = splits[5]
  1160. items = append(items, item)
  1161. // batch insert
  1162. if len(items) == batchSize {
  1163. err = m.DataClient.BatchInsertItems(ctx, items)
  1164. if err != nil {
  1165. server.InternalServerError(restful.NewResponse(response), err)
  1166. return false
  1167. }
  1168. items = nil
  1169. }
  1170. lineCount++
  1171. return true
  1172. })
  1173. if err != nil {
  1174. server.BadRequest(restful.NewResponse(response), err)
  1175. return
  1176. }
  1177. if len(items) > 0 {
  1178. err = m.DataClient.BatchInsertItems(ctx, items)
  1179. if err != nil {
  1180. server.InternalServerError(restful.NewResponse(response), err)
  1181. return
  1182. }
  1183. }
  1184. m.notifyDataImported()
  1185. timeUsed := time.Since(timeStart)
  1186. log.Logger().Info("complete import items",
  1187. zap.Duration("time_used", timeUsed),
  1188. zap.Int("num_items", lineCount))
  1189. server.Ok(restful.NewResponse(response), server.Success{RowAffected: lineCount})
  1190. }
  1191. func format(inFmt, outFmt string, s []string, lineCount int) ([]string, error) {
  1192. if len(s) < len(inFmt) {
  1193. log.Logger().Error("number of fields mismatch",
  1194. zap.Int("expect", len(inFmt)),
  1195. zap.Int("actual", len(s)))
  1196. return nil, fmt.Errorf("number of fields mismatch at line %v", lineCount)
  1197. }
  1198. if inFmt == outFmt {
  1199. return s, nil
  1200. }
  1201. pool := make(map[uint8]string)
  1202. for i := range inFmt {
  1203. pool[inFmt[i]] = s[i]
  1204. }
  1205. out := make([]string, len(outFmt))
  1206. for i, c := range outFmt {
  1207. out[i] = pool[uint8(c)]
  1208. }
  1209. return out, nil
  1210. }
  1211. func formValue(request *http.Request, fieldName, defaultValue string) string {
  1212. value := request.FormValue(fieldName)
  1213. if value == "" {
  1214. return defaultValue
  1215. }
  1216. return value
  1217. }
  1218. func (m *Master) importExportFeedback(response http.ResponseWriter, request *http.Request) {
  1219. ctx := context.Background()
  1220. if request != nil {
  1221. ctx = request.Context()
  1222. }
  1223. if !m.checkLogin(request) {
  1224. writeError(response, http.StatusUnauthorized, "unauthorized")
  1225. return
  1226. }
  1227. switch request.Method {
  1228. case http.MethodGet:
  1229. var err error
  1230. response.Header().Set("Content-Type", "text/csv")
  1231. response.Header().Set("Content-Disposition", "attachment;filename=feedback.csv")
  1232. // write header
  1233. if _, err = response.Write([]byte("feedback_type,user_id,item_id,time_stamp\r\n")); err != nil {
  1234. server.InternalServerError(restful.NewResponse(response), err)
  1235. return
  1236. }
  1237. // write rows
  1238. feedbackChan, errChan := m.DataClient.GetFeedbackStream(ctx, batchSize, data.WithEndTime(*m.Config.Now()))
  1239. for feedback := range feedbackChan {
  1240. for _, v := range feedback {
  1241. if _, err = response.Write([]byte(fmt.Sprintf("%s,%s,%s,%v\r\n",
  1242. base.Escape(v.FeedbackType), base.Escape(v.UserId), base.Escape(v.ItemId), v.Timestamp))); err != nil {
  1243. server.InternalServerError(restful.NewResponse(response), err)
  1244. return
  1245. }
  1246. }
  1247. }
  1248. if err = <-errChan; err != nil {
  1249. server.InternalServerError(restful.NewResponse(response), errors.Trace(err))
  1250. return
  1251. }
  1252. case http.MethodPost:
  1253. hasHeader := formValue(request, "has-header", "true") == "true"
  1254. sep := formValue(request, "sep", ",")
  1255. // field separator must be a single character
  1256. if len(sep) != 1 {
  1257. server.BadRequest(restful.NewResponse(response), fmt.Errorf("field separator must be a single character"))
  1258. return
  1259. }
  1260. fmtString := formValue(request, "format", "fuit")
  1261. // import items
  1262. file, _, err := request.FormFile("file")
  1263. if err != nil {
  1264. server.BadRequest(restful.NewResponse(response), err)
  1265. return
  1266. }
  1267. defer file.Close()
  1268. m.importFeedback(ctx, response, file, hasHeader, sep, fmtString)
  1269. default:
  1270. writeError(response, http.StatusMethodNotAllowed, "method not allowed")
  1271. }
  1272. }
  1273. func (m *Master) importFeedback(ctx context.Context, response http.ResponseWriter, file io.Reader, hasHeader bool, sep, fmtString string) {
  1274. var err error
  1275. scanner := bufio.NewScanner(file)
  1276. lineCount := 0
  1277. timeStart := time.Now()
  1278. feedbacks := make([]data.Feedback, 0)
  1279. err = base.ReadLines(scanner, sep, func(lineNumber int, splits []string) bool {
  1280. if hasHeader {
  1281. hasHeader = false
  1282. return true
  1283. }
  1284. // reorder fields
  1285. splits, err = format(fmtString, "fuit", splits, lineNumber)
  1286. if err != nil {
  1287. server.BadRequest(restful.NewResponse(response), err)
  1288. return false
  1289. }
  1290. feedback := data.Feedback{}
  1291. // 1. feedback type
  1292. feedback.FeedbackType = splits[0]
  1293. if err = base.ValidateId(splits[0]); err != nil {
  1294. server.BadRequest(restful.NewResponse(response),
  1295. fmt.Errorf("invalid feedback type `%v` at line %d (%s)", splits[0], lineNumber, err.Error()))
  1296. return false
  1297. }
  1298. // 2. user id
  1299. if err = base.ValidateId(splits[1]); err != nil {
  1300. server.BadRequest(restful.NewResponse(response),
  1301. fmt.Errorf("invalid user id `%v` at line %d (%s)", splits[1], lineNumber, err.Error()))
  1302. return false
  1303. }
  1304. feedback.UserId = splits[1]
  1305. // 3. item id
  1306. if err = base.ValidateId(splits[2]); err != nil {
  1307. server.BadRequest(restful.NewResponse(response),
  1308. fmt.Errorf("invalid item id `%v` at line %d (%s)", splits[2], lineNumber, err.Error()))
  1309. return false
  1310. }
  1311. feedback.ItemId = splits[2]
  1312. feedback.Timestamp, err = dateparse.ParseAny(splits[3])
  1313. if err != nil {
  1314. server.BadRequest(restful.NewResponse(response),
  1315. fmt.Errorf("failed to parse datetime `%v` at line %d", splits[3], lineNumber))
  1316. return false
  1317. }
  1318. feedbacks = append(feedbacks, feedback)
  1319. // batch insert
  1320. if len(feedbacks) == batchSize {
  1321. // batch insert to data store
  1322. err = m.DataClient.BatchInsertFeedback(ctx, feedbacks,
  1323. m.Config.Server.AutoInsertUser,
  1324. m.Config.Server.AutoInsertItem, true)
  1325. if err != nil {
  1326. server.InternalServerError(restful.NewResponse(response), err)
  1327. return false
  1328. }
  1329. feedbacks = nil
  1330. }
  1331. lineCount++
  1332. return true
  1333. })
  1334. if err != nil {
  1335. server.BadRequest(restful.NewResponse(response), err)
  1336. return
  1337. }
  1338. // insert to cache store
  1339. if len(feedbacks) > 0 {
  1340. // insert to data store
  1341. err = m.DataClient.BatchInsertFeedback(ctx, feedbacks,
  1342. m.Config.Server.AutoInsertUser,
  1343. m.Config.Server.AutoInsertItem, true)
  1344. if err != nil {
  1345. server.InternalServerError(restful.NewResponse(response), err)
  1346. return
  1347. }
  1348. }
  1349. m.notifyDataImported()
  1350. timeUsed := time.Since(timeStart)
  1351. log.Logger().Info("complete import feedback",
  1352. zap.Duration("time_used", timeUsed),
  1353. zap.Int("num_items", lineCount))
  1354. server.Ok(restful.NewResponse(response), server.Success{RowAffected: lineCount})
  1355. }
  1356. var checkList = mapset.NewSet("delete_users", "delete_items", "delete_feedback", "delete_cache")
  1357. func (m *Master) purge(response http.ResponseWriter, request *http.Request) {
  1358. // check method
  1359. if request.Method != http.MethodPost {
  1360. writeError(response, http.StatusMethodNotAllowed, "method not allowed")
  1361. return
  1362. }
  1363. // check login
  1364. if !m.checkLogin(request) {
  1365. resp := restful.NewResponse(response)
  1366. err := resp.WriteErrorString(http.StatusUnauthorized, "unauthorized")
  1367. if err != nil {
  1368. server.InternalServerError(resp, err)
  1369. return
  1370. }
  1371. return
  1372. }
  1373. // check password
  1374. if m.Config.Master.DashboardPassword == "" {
  1375. writeError(response, http.StatusUnauthorized, "purge is not allowed without dashboard password")
  1376. return
  1377. }
  1378. // check list
  1379. if err := request.ParseForm(); err != nil {
  1380. server.BadRequest(restful.NewResponse(response), err)
  1381. return
  1382. }
  1383. checkedList := strings.Split(request.Form.Get("check_list"), ",")
  1384. if !checkList.Equal(mapset.NewSet(checkedList...)) {
  1385. writeError(response, http.StatusUnauthorized, "please confirm by checking all")
  1386. return
  1387. }
  1388. // purge data
  1389. if err := m.DataClient.Purge(); err != nil {
  1390. writeError(response, http.StatusInternalServerError, err.Error())
  1391. return
  1392. }
  1393. if err := m.CacheClient.Purge(); err != nil {
  1394. writeError(response, http.StatusInternalServerError, err.Error())
  1395. return
  1396. }
  1397. }
  1398. func (m *Master) scheduleAPIHandler(writer http.ResponseWriter, request *http.Request) {
  1399. if !m.checkAdmin(request) {
  1400. writeError(writer, http.StatusUnauthorized, "unauthorized")
  1401. return
  1402. }
  1403. switch request.Method {
  1404. case http.MethodGet:
  1405. writer.WriteHeader(http.StatusOK)
  1406. bytes, err := json.Marshal(m.scheduleState)
  1407. if err != nil {
  1408. writeError(writer, http.StatusInternalServerError, err.Error())
  1409. }
  1410. if _, err = writer.Write(bytes); err != nil {
  1411. writeError(writer, http.StatusInternalServerError, err.Error())
  1412. }
  1413. case http.MethodPost:
  1414. s := request.FormValue("search_model")
  1415. if s != "" {
  1416. if searchModel, err := strconv.ParseBool(s); err != nil {
  1417. writeError(writer, http.StatusBadRequest, err.Error())
  1418. } else {
  1419. m.scheduleState.SearchModel = searchModel
  1420. }
  1421. }
  1422. m.triggerChan.Signal()
  1423. default:
  1424. writeError(writer, http.StatusMethodNotAllowed, "method not allowed")
  1425. }
  1426. }
  1427. func writeError(response http.ResponseWriter, httpStatus int, message string) {
  1428. log.Logger().Error(strings.ToLower(http.StatusText(httpStatus)), zap.String("error", message))
  1429. response.Header().Set("Access-Control-Allow-Origin", "*")
  1430. response.WriteHeader(httpStatus)
  1431. if _, err := response.Write([]byte(message)); err != nil {
  1432. log.Logger().Error("failed to write error", zap.Error(err))
  1433. }
  1434. }
  1435. func (s *Master) checkAdmin(request *http.Request) bool {
  1436. if s.Config.Master.AdminAPIKey == "" {
  1437. return true
  1438. }
  1439. if request.FormValue("X-API-Key") == s.Config.Master.AdminAPIKey {
  1440. return true
  1441. }
  1442. return false
  1443. }