sql.go 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835
  1. // Copyright 2021 gorse Project Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package data
  15. import (
  16. "context"
  17. "database/sql"
  18. "encoding/base64"
  19. "fmt"
  20. "time"
  21. mapset "github.com/deckarep/golang-set/v2"
  22. _ "github.com/go-sql-driver/mysql"
  23. "github.com/juju/errors"
  24. _ "github.com/lib/pq"
  25. "github.com/samber/lo"
  26. "github.com/zhenghaoz/gorse/base/jsonutil"
  27. "github.com/zhenghaoz/gorse/base/log"
  28. "github.com/zhenghaoz/gorse/storage"
  29. "gorm.io/gorm"
  30. "gorm.io/gorm/clause"
  31. _ "modernc.org/sqlite"
  32. )
  // bufSize is the capacity of the batch channels returned by the Get*Stream methods.
const bufSize = 1

// SQLDriver identifies the SQL dialect an SQLDatabase talks to.
type SQLDriver int

// Supported SQL dialects. The numeric values are spelled out explicitly.
const (
	MySQL    SQLDriver = 0
	Postgres SQLDriver = 1
	SQLite   SQLDriver = 2
)
  // SQLItem is the row form of an item in the items table.
// Categories and Labels hold JSON text (see NewSQLItem).
type SQLItem struct {
	ItemId     string    `gorm:"column:item_id;primaryKey"` // primary key
	IsHidden   bool      `gorm:"column:is_hidden"`
	Categories string    `gorm:"column:categories"` // JSON-encoded
	Timestamp  time.Time `gorm:"column:time_stamp"`
	Labels     string    `gorm:"column:labels"` // JSON-encoded
	Comment    string    `gorm:"column:comment"`
}
  48. func NewSQLItem(item Item) (sqlItem SQLItem) {
  49. var buf []byte
  50. sqlItem.ItemId = item.ItemId
  51. sqlItem.IsHidden = item.IsHidden
  52. buf, _ = jsonutil.Marshal(item.Categories)
  53. sqlItem.Categories = string(buf)
  54. sqlItem.Timestamp = item.Timestamp
  55. buf, _ = jsonutil.Marshal(item.Labels)
  56. sqlItem.Labels = string(buf)
  57. sqlItem.Comment = item.Comment
  58. return
  59. }
  // SQLUser is the row form of a user in the users table.
// Labels and Subscribe hold JSON text (see NewSQLUser).
type SQLUser struct {
	UserId    string `gorm:"column:user_id;primaryKey"` // primary key
	Labels    string `gorm:"column:labels"`            // JSON-encoded
	Subscribe string `gorm:"column:subscribe"`         // JSON-encoded
	Comment   string `gorm:"column:comment"`
}
  66. func NewSQLUser(user User) (sqlUser SQLUser) {
  67. var buf []byte
  68. sqlUser.UserId = user.UserId
  69. buf, _ = jsonutil.Marshal(user.Labels)
  70. sqlUser.Labels = string(buf)
  71. buf, _ = jsonutil.Marshal(user.Subscribe)
  72. sqlUser.Subscribe = string(buf)
  73. sqlUser.Comment = user.Comment
  74. return
  75. }
  76. // SQLDatabase use MySQL as data storage.
  77. type SQLDatabase struct {
  78. storage.TablePrefix
  79. gormDB *gorm.DB
  80. client *sql.DB
  81. driver SQLDriver
  82. }
  83. // Init tables and indices in MySQL.
  84. func (d *SQLDatabase) Init() error {
  85. switch d.driver {
  86. case MySQL:
  87. // create tables
  88. type Items struct {
  89. ItemId string `gorm:"column:item_id;type:varchar(256) not null;primaryKey"`
  90. IsHidden bool `gorm:"column:is_hidden;type:bool;not null"`
  91. Categories []string `gorm:"column:categories;type:json;not null"`
  92. Timestamp time.Time `gorm:"column:time_stamp;type:datetime;not null"`
  93. Labels []string `gorm:"column:labels;type:json;not null"`
  94. Comment string `gorm:"column:comment;type:text;not null"`
  95. }
  96. type Users struct {
  97. UserId string `gorm:"column:user_id;type:varchar(256);not null;primaryKey"`
  98. Labels []string `gorm:"column:labels;type:json;not null"`
  99. Subscribe []string `gorm:"column:subscribe;type:json;not null"`
  100. Comment string `gorm:"column:comment;type:text;not null"`
  101. }
  102. type Feedback struct {
  103. FeedbackType string `gorm:"column:feedback_type;type:varchar(256);not null;primaryKey"`
  104. UserId string `gorm:"column:user_id;type:varchar(256);not null;primaryKey;index:user_id"`
  105. ItemId string `gorm:"column:item_id;type:varchar(256);not null;primaryKey;index:item_id"`
  106. Timestamp time.Time `gorm:"column:time_stamp;type:datetime;not null"`
  107. Comment string `gorm:"column:comment;type:text;not null"`
  108. }
  109. err := d.gormDB.Set("gorm:table_options", "ENGINE=InnoDB").AutoMigrate(Users{}, Items{}, Feedback{})
  110. if err != nil {
  111. return errors.Trace(err)
  112. }
  113. case Postgres:
  114. // create tables
  115. type Items struct {
  116. ItemId string `gorm:"column:item_id;type:varchar(256);not null;primaryKey"`
  117. IsHidden bool `gorm:"column:is_hidden;type:bool;not null;default:false"`
  118. Categories string `gorm:"column:categories;type:json;not null;default:'[]'"`
  119. Timestamp time.Time `gorm:"column:time_stamp;type:timestamptz;not null"`
  120. Labels string `gorm:"column:labels;type:json;not null;default:'[]'"`
  121. Comment string `gorm:"column:comment;type:text;not null;default:''"`
  122. }
  123. type Users struct {
  124. UserId string `gorm:"column:user_id;type:varchar(256) not null;primaryKey"`
  125. Labels string `gorm:"column:labels;type:json;not null;default:'[]'"`
  126. Subscribe string `gorm:"column:subscribe;type:json;not null;default:'[]'"`
  127. Comment string `gorm:"column:comment;type:text;not null;default:''"`
  128. }
  129. type Feedback struct {
  130. FeedbackType string `gorm:"column:feedback_type;type:varchar(256);not null;primaryKey"`
  131. UserId string `gorm:"column:user_id;type:varchar(256);not null;primaryKey;index:user_id_index"`
  132. ItemId string `gorm:"column:item_id;type:varchar(256);not null;primaryKey;index:item_id_index"`
  133. Timestamp time.Time `gorm:"column:time_stamp;type:timestamptz;not null"`
  134. Comment string `gorm:"column:comment;type:text;not null;default:''"`
  135. }
  136. err := d.gormDB.AutoMigrate(Users{}, Items{}, Feedback{})
  137. if err != nil {
  138. return errors.Trace(err)
  139. }
  140. case SQLite:
  141. // create tables
  142. type Items struct {
  143. ItemId string `gorm:"column:item_id;type:varchar(256);not null;primaryKey"`
  144. IsHidden bool `gorm:"column:is_hidden;type:bool;not null;default:false"`
  145. Categories string `gorm:"column:categories;type:json;not null;default:'[]'"`
  146. Timestamp string `gorm:"column:time_stamp;type:datetime;not null;default:'0001-01-01'"`
  147. Labels string `gorm:"column:labels;type:json;not null;default:'[]'"`
  148. Comment string `gorm:"column:comment;type:text;not null;default:''"`
  149. }
  150. type Users struct {
  151. UserId string `gorm:"column:user_id;type:varchar(256) not null;primaryKey"`
  152. Labels string `gorm:"column:labels;type:json;not null;default:'null'"`
  153. Subscribe string `gorm:"column:subscribe;type:json;not null;default:'null'"`
  154. Comment string `gorm:"column:comment;type:text;not null;default:''"`
  155. }
  156. type Feedback struct {
  157. FeedbackType string `gorm:"column:feedback_type;type:varchar(256);not null;primaryKey"`
  158. UserId string `gorm:"column:user_id;type:varchar(256);not null;primaryKey;index:user_id_index"`
  159. ItemId string `gorm:"column:item_id;type:varchar(256);not null;primaryKey;index:item_id_index"`
  160. Timestamp string `gorm:"column:time_stamp;type:datetime;not null;default:'0001-01-01'"`
  161. Comment string `gorm:"column:comment;type:text;not null;default:''"`
  162. }
  163. err := d.gormDB.AutoMigrate(Users{}, Items{}, Feedback{})
  164. if err != nil {
  165. return errors.Trace(err)
  166. }
  167. }
  168. return nil
  169. }
  170. func (d *SQLDatabase) Ping() error {
  171. return d.client.Ping()
  172. }
  173. // Close MySQL connection.
  174. func (d *SQLDatabase) Close() error {
  175. return d.client.Close()
  176. }
  177. func (d *SQLDatabase) Purge() error {
  178. tables := []string{d.ItemsTable(), d.FeedbackTable(), d.UsersTable()}
  179. for _, tableName := range tables {
  180. err := d.gormDB.Exec(fmt.Sprintf("DELETE FROM %s", tableName)).Error
  181. if err != nil {
  182. return errors.Trace(err)
  183. }
  184. }
  185. return nil
  186. }
  187. // BatchInsertItems inserts a batch of items into MySQL.
  188. func (d *SQLDatabase) BatchInsertItems(ctx context.Context, items []Item) error {
  189. if len(items) == 0 {
  190. return nil
  191. }
  192. rows := make([]SQLItem, 0, len(items))
  193. memo := mapset.NewSet[string]()
  194. for _, item := range items {
  195. if !memo.Contains(item.ItemId) {
  196. memo.Add(item.ItemId)
  197. row := NewSQLItem(item)
  198. if d.driver == SQLite {
  199. row.Timestamp = row.Timestamp.In(time.UTC)
  200. }
  201. rows = append(rows, row)
  202. }
  203. }
  204. err := d.gormDB.WithContext(ctx).Clauses(clause.OnConflict{
  205. Columns: []clause.Column{{Name: "item_id"}},
  206. DoUpdates: clause.AssignmentColumns([]string{"is_hidden", "categories", "time_stamp", "labels", "comment"}),
  207. }).Create(rows).Error
  208. return errors.Trace(err)
  209. }
  210. func (d *SQLDatabase) BatchGetItems(ctx context.Context, itemIds []string) ([]Item, error) {
  211. if len(itemIds) == 0 {
  212. return nil, nil
  213. }
  214. result, err := d.gormDB.WithContext(ctx).Table(d.ItemsTable()).
  215. Select("item_id, is_hidden, categories, time_stamp, labels, comment").
  216. Where("item_id IN ?", itemIds).Rows()
  217. if err != nil {
  218. return nil, errors.Trace(err)
  219. }
  220. defer result.Close()
  221. var items []Item
  222. for result.Next() {
  223. var item Item
  224. if err = d.gormDB.ScanRows(result, &item); err != nil {
  225. return nil, errors.Trace(err)
  226. }
  227. items = append(items, item)
  228. }
  229. return items, nil
  230. }
  231. // DeleteItem deletes a item from MySQL.
  232. func (d *SQLDatabase) DeleteItem(ctx context.Context, itemId string) error {
  233. if err := d.gormDB.WithContext(ctx).Delete(&SQLItem{ItemId: itemId}).Error; err != nil {
  234. return errors.Trace(err)
  235. }
  236. if err := d.gormDB.WithContext(ctx).Delete(&Feedback{}, "item_id = ?", itemId).Error; err != nil {
  237. return errors.Trace(err)
  238. }
  239. return nil
  240. }
  241. // GetItem get a item from MySQL.
  242. func (d *SQLDatabase) GetItem(ctx context.Context, itemId string) (Item, error) {
  243. var result *sql.Rows
  244. var err error
  245. result, err = d.gormDB.WithContext(ctx).Table(d.ItemsTable()).Select("item_id, is_hidden, categories, time_stamp, labels, comment").Where("item_id = ?", itemId).Rows()
  246. if err != nil {
  247. return Item{}, errors.Trace(err)
  248. }
  249. defer result.Close()
  250. if result.Next() {
  251. var item Item
  252. if err = d.gormDB.ScanRows(result, &item); err != nil {
  253. return Item{}, errors.Trace(err)
  254. }
  255. return item, nil
  256. }
  257. return Item{}, errors.Annotate(ErrItemNotExist, itemId)
  258. }
  259. // ModifyItem modify an item in MySQL.
  260. func (d *SQLDatabase) ModifyItem(ctx context.Context, itemId string, patch ItemPatch) error {
  261. // ignore empty patch
  262. if patch.IsHidden == nil && patch.Categories == nil && patch.Labels == nil && patch.Comment == nil && patch.Timestamp == nil {
  263. log.Logger().Debug("empty item patch")
  264. return nil
  265. }
  266. attributes := make(map[string]any)
  267. if patch.IsHidden != nil {
  268. if *patch.IsHidden {
  269. attributes["is_hidden"] = 1
  270. } else {
  271. attributes["is_hidden"] = 0
  272. }
  273. }
  274. if patch.Categories != nil {
  275. text, _ := jsonutil.Marshal(patch.Categories)
  276. attributes["categories"] = string(text)
  277. }
  278. if patch.Comment != nil {
  279. attributes["comment"] = *patch.Comment
  280. }
  281. if patch.Labels != nil {
  282. text, _ := jsonutil.Marshal(patch.Labels)
  283. attributes["labels"] = string(text)
  284. }
  285. if patch.Timestamp != nil {
  286. switch d.driver {
  287. case SQLite:
  288. attributes["time_stamp"] = patch.Timestamp.In(time.UTC)
  289. default:
  290. attributes["time_stamp"] = patch.Timestamp
  291. }
  292. }
  293. err := d.gormDB.WithContext(ctx).Model(&SQLItem{ItemId: itemId}).Updates(attributes).Error
  294. return errors.Trace(err)
  295. }
  296. // GetItems returns items from MySQL.
  297. func (d *SQLDatabase) GetItems(ctx context.Context, cursor string, n int, timeLimit *time.Time) (string, []Item, error) {
  298. buf, err := base64.StdEncoding.DecodeString(cursor)
  299. if err != nil {
  300. return "", nil, errors.Trace(err)
  301. }
  302. cursorItem := string(buf)
  303. tx := d.gormDB.WithContext(ctx).Table(d.ItemsTable()).Select("item_id, is_hidden, categories, time_stamp, labels, comment")
  304. if cursorItem != "" {
  305. tx.Where("item_id >= ?", cursorItem)
  306. }
  307. if timeLimit != nil {
  308. tx.Where("time_stamp >= ?", *timeLimit)
  309. }
  310. result, err := tx.Order("item_id").Limit(n + 1).Rows()
  311. if err != nil {
  312. return "", nil, errors.Trace(err)
  313. }
  314. items := make([]Item, 0)
  315. defer result.Close()
  316. for result.Next() {
  317. var item Item
  318. if err = d.gormDB.ScanRows(result, &item); err != nil {
  319. return "", nil, errors.Trace(err)
  320. }
  321. items = append(items, item)
  322. }
  323. if len(items) == n+1 {
  324. return base64.StdEncoding.EncodeToString([]byte(items[len(items)-1].ItemId)), items[:len(items)-1], nil
  325. }
  326. return "", items, nil
  327. }
  328. // GetItemStream reads items by stream.
  329. func (d *SQLDatabase) GetItemStream(ctx context.Context, batchSize int, timeLimit *time.Time) (chan []Item, chan error) {
  330. itemChan := make(chan []Item, bufSize)
  331. errChan := make(chan error, 1)
  332. go func() {
  333. defer close(itemChan)
  334. defer close(errChan)
  335. // send query
  336. tx := d.gormDB.WithContext(ctx).Table(d.ItemsTable()).Select("item_id, is_hidden, categories, time_stamp, labels, comment")
  337. if timeLimit != nil {
  338. tx.Where("time_stamp >= ?", *timeLimit)
  339. }
  340. result, err := tx.Rows()
  341. if err != nil {
  342. errChan <- errors.Trace(err)
  343. return
  344. }
  345. // fetch result
  346. items := make([]Item, 0, batchSize)
  347. defer result.Close()
  348. for result.Next() {
  349. var item Item
  350. if err = d.gormDB.ScanRows(result, &item); err != nil {
  351. errChan <- errors.Trace(err)
  352. return
  353. }
  354. items = append(items, item)
  355. if len(items) == batchSize {
  356. itemChan <- items
  357. items = make([]Item, 0, batchSize)
  358. }
  359. }
  360. if len(items) > 0 {
  361. itemChan <- items
  362. }
  363. errChan <- nil
  364. }()
  365. return itemChan, errChan
  366. }
  367. // GetItemFeedback returns feedback of a item from MySQL.
  368. func (d *SQLDatabase) GetItemFeedback(ctx context.Context, itemId string, feedbackTypes ...string) ([]Feedback, error) {
  369. tx := d.gormDB.WithContext(ctx).Table(d.FeedbackTable()).Select("user_id, item_id, feedback_type, time_stamp")
  370. switch d.driver {
  371. case SQLite:
  372. tx.Where("time_stamp <= DATETIME() AND item_id = ?", itemId)
  373. default:
  374. tx.Where("time_stamp <= NOW() AND item_id = ?", itemId)
  375. }
  376. if len(feedbackTypes) > 0 {
  377. tx.Where("feedback_type IN ?", feedbackTypes)
  378. }
  379. result, err := tx.Rows()
  380. if err != nil {
  381. return nil, errors.Trace(err)
  382. }
  383. feedbacks := make([]Feedback, 0)
  384. defer result.Close()
  385. for result.Next() {
  386. var feedback Feedback
  387. if err = d.gormDB.ScanRows(result, &feedback); err != nil {
  388. return nil, errors.Trace(err)
  389. }
  390. feedbacks = append(feedbacks, feedback)
  391. }
  392. return feedbacks, nil
  393. }
  394. // BatchInsertUsers inserts users into MySQL.
  395. func (d *SQLDatabase) BatchInsertUsers(ctx context.Context, users []User) error {
  396. if len(users) == 0 {
  397. return nil
  398. }
  399. rows := make([]SQLUser, 0, len(users))
  400. memo := mapset.NewSet[string]()
  401. for _, user := range users {
  402. if !memo.Contains(user.UserId) {
  403. memo.Add(user.UserId)
  404. rows = append(rows, NewSQLUser(user))
  405. }
  406. }
  407. err := d.gormDB.WithContext(ctx).Clauses(clause.OnConflict{
  408. Columns: []clause.Column{{Name: "user_id"}},
  409. DoUpdates: clause.AssignmentColumns([]string{"labels", "subscribe", "comment"}),
  410. }).Create(rows).Error
  411. return errors.Trace(err)
  412. }
  413. // DeleteUser deletes a user from MySQL.
  414. func (d *SQLDatabase) DeleteUser(ctx context.Context, userId string) error {
  415. if err := d.gormDB.WithContext(ctx).Delete(&SQLUser{UserId: userId}).Error; err != nil {
  416. return errors.Trace(err)
  417. }
  418. if err := d.gormDB.WithContext(ctx).Delete(&Feedback{}, "user_id = ?", userId).Error; err != nil {
  419. return errors.Trace(err)
  420. }
  421. return nil
  422. }
  423. // GetUser returns a user from MySQL.
  424. func (d *SQLDatabase) GetUser(ctx context.Context, userId string) (User, error) {
  425. var result *sql.Rows
  426. var err error
  427. result, err = d.gormDB.WithContext(ctx).Table(d.UsersTable()).
  428. Select("user_id, labels, subscribe, comment").
  429. Where("user_id = ?", userId).Rows()
  430. if err != nil {
  431. return User{}, errors.Trace(err)
  432. }
  433. defer result.Close()
  434. if result.Next() {
  435. var user User
  436. if err = d.gormDB.ScanRows(result, &user); err != nil {
  437. return User{}, errors.Trace(err)
  438. }
  439. return user, nil
  440. }
  441. return User{}, errors.Annotate(ErrUserNotExist, userId)
  442. }
  443. // ModifyUser modify a user in MySQL.
  444. func (d *SQLDatabase) ModifyUser(ctx context.Context, userId string, patch UserPatch) error {
  445. // ignore empty patch
  446. if patch.Labels == nil && patch.Subscribe == nil && patch.Comment == nil {
  447. log.Logger().Debug("empty user patch")
  448. return nil
  449. }
  450. attributes := make(map[string]any)
  451. if patch.Comment != nil {
  452. attributes["comment"] = *patch.Comment
  453. }
  454. if patch.Labels != nil {
  455. text, _ := jsonutil.Marshal(patch.Labels)
  456. attributes["labels"] = string(text)
  457. }
  458. if patch.Subscribe != nil {
  459. text, _ := jsonutil.Marshal(patch.Subscribe)
  460. attributes["subscribe"] = string(text)
  461. }
  462. err := d.gormDB.WithContext(ctx).Model(&SQLUser{UserId: userId}).Updates(attributes).Error
  463. return errors.Trace(err)
  464. }
  465. // GetUsers returns users from MySQL.
  466. func (d *SQLDatabase) GetUsers(ctx context.Context, cursor string, n int) (string, []User, error) {
  467. buf, err := base64.StdEncoding.DecodeString(cursor)
  468. if err != nil {
  469. return "", nil, errors.Trace(err)
  470. }
  471. cursorUser := string(buf)
  472. tx := d.gormDB.WithContext(ctx).Table(d.UsersTable()).Select("user_id, labels, subscribe, comment")
  473. if cursorUser != "" {
  474. tx.Where("user_id >= ?", cursorUser)
  475. }
  476. result, err := tx.Order("user_id").Limit(n + 1).Rows()
  477. if err != nil {
  478. return "", nil, errors.Trace(err)
  479. }
  480. users := make([]User, 0)
  481. defer result.Close()
  482. for result.Next() {
  483. var user User
  484. if err = d.gormDB.ScanRows(result, &user); err != nil {
  485. return "", nil, errors.Trace(err)
  486. }
  487. users = append(users, user)
  488. }
  489. if len(users) == n+1 {
  490. return base64.StdEncoding.EncodeToString([]byte(users[len(users)-1].UserId)), users[:len(users)-1], nil
  491. }
  492. return "", users, nil
  493. }
  494. // GetUserStream read users by stream.
  495. func (d *SQLDatabase) GetUserStream(ctx context.Context, batchSize int) (chan []User, chan error) {
  496. userChan := make(chan []User, bufSize)
  497. errChan := make(chan error, 1)
  498. go func() {
  499. defer close(userChan)
  500. defer close(errChan)
  501. // send query
  502. result, err := d.gormDB.WithContext(ctx).Table(d.UsersTable()).Select("user_id, labels, subscribe, comment").Rows()
  503. if err != nil {
  504. errChan <- errors.Trace(err)
  505. return
  506. }
  507. // fetch result
  508. users := make([]User, 0, batchSize)
  509. defer result.Close()
  510. for result.Next() {
  511. var user User
  512. if err = d.gormDB.ScanRows(result, &user); err != nil {
  513. errChan <- errors.Trace(err)
  514. return
  515. }
  516. users = append(users, user)
  517. if len(users) == batchSize {
  518. userChan <- users
  519. users = make([]User, 0, batchSize)
  520. }
  521. }
  522. if len(users) > 0 {
  523. userChan <- users
  524. }
  525. errChan <- nil
  526. }()
  527. return userChan, errChan
  528. }
  529. // GetUserFeedback returns feedback of a user from MySQL.
  530. func (d *SQLDatabase) GetUserFeedback(ctx context.Context, userId string, endTime *time.Time, feedbackTypes ...string) ([]Feedback, error) {
  531. tx := d.gormDB.WithContext(ctx).Table(d.FeedbackTable()).
  532. Select("feedback_type, user_id, item_id, time_stamp, comment").
  533. Where("user_id = ?", userId)
  534. if endTime != nil {
  535. tx.Where("time_stamp <= ?", d.convertTimeZone(endTime))
  536. }
  537. if len(feedbackTypes) > 0 {
  538. tx.Where("feedback_type IN ?", feedbackTypes)
  539. }
  540. result, err := tx.Rows()
  541. if err != nil {
  542. return nil, errors.Trace(err)
  543. }
  544. feedbacks := make([]Feedback, 0)
  545. defer result.Close()
  546. for result.Next() {
  547. var feedback Feedback
  548. if err = d.gormDB.ScanRows(result, &feedback); err != nil {
  549. return nil, errors.Trace(err)
  550. }
  551. feedbacks = append(feedbacks, feedback)
  552. }
  553. return feedbacks, nil
  554. }
  555. // BatchInsertFeedback insert a batch feedback into MySQL.
  556. // If insertUser set, new users will be inserted to user table.
  557. // If insertItem set, new items will be inserted to item table.
  558. func (d *SQLDatabase) BatchInsertFeedback(ctx context.Context, feedback []Feedback, insertUser, insertItem, overwrite bool) error {
  559. tx := d.gormDB.WithContext(ctx)
  560. // skip empty list
  561. if len(feedback) == 0 {
  562. return nil
  563. }
  564. // collect users and items
  565. users := mapset.NewSet[string]()
  566. items := mapset.NewSet[string]()
  567. for _, v := range feedback {
  568. users.Add(v.UserId)
  569. items.Add(v.ItemId)
  570. }
  571. // insert users
  572. if insertUser {
  573. userList := users.ToSlice()
  574. err := tx.Clauses(clause.OnConflict{
  575. Columns: []clause.Column{{Name: "user_id"}},
  576. DoNothing: true,
  577. }).Create(lo.Map(userList, func(userId string, _ int) SQLUser {
  578. return SQLUser{
  579. UserId: userId,
  580. Labels: "null",
  581. Subscribe: "null",
  582. }
  583. })).Error
  584. if err != nil {
  585. return errors.Trace(err)
  586. }
  587. } else {
  588. for _, user := range users.ToSlice() {
  589. rs, err := tx.Table(d.UsersTable()).Select("user_id").Where("user_id = ?", user).Rows()
  590. if err != nil {
  591. return errors.Trace(err)
  592. } else if !rs.Next() {
  593. users.Remove(user)
  594. }
  595. if err = rs.Close(); err != nil {
  596. return errors.Trace(err)
  597. }
  598. }
  599. }
  600. // insert items
  601. if insertItem {
  602. itemList := items.ToSlice()
  603. err := tx.Clauses(clause.OnConflict{
  604. Columns: []clause.Column{{Name: "item_id"}},
  605. DoNothing: true,
  606. }).Create(lo.Map(itemList, func(itemId string, _ int) SQLItem {
  607. return SQLItem{
  608. ItemId: itemId,
  609. Labels: "null",
  610. Categories: "null",
  611. }
  612. })).Error
  613. if err != nil {
  614. return errors.Trace(err)
  615. }
  616. } else {
  617. for _, item := range items.ToSlice() {
  618. rs, err := tx.Table(d.ItemsTable()).Select("item_id").Where("item_id = ?", item).Rows()
  619. if err != nil {
  620. return errors.Trace(err)
  621. } else if !rs.Next() {
  622. items.Remove(item)
  623. }
  624. if err = rs.Close(); err != nil {
  625. return errors.Trace(err)
  626. }
  627. }
  628. }
  629. // insert feedback
  630. rows := make([]Feedback, 0, len(feedback))
  631. memo := make(map[lo.Tuple3[string, string, string]]struct{})
  632. for _, f := range feedback {
  633. if users.Contains(f.UserId) && items.Contains(f.ItemId) {
  634. if _, exist := memo[lo.Tuple3[string, string, string]{f.FeedbackType, f.UserId, f.ItemId}]; !exist {
  635. memo[lo.Tuple3[string, string, string]{f.FeedbackType, f.UserId, f.ItemId}] = struct{}{}
  636. if d.driver == SQLite {
  637. f.Timestamp = f.Timestamp.In(time.UTC)
  638. }
  639. rows = append(rows, f)
  640. }
  641. }
  642. }
  643. if len(rows) == 0 {
  644. return nil
  645. }
  646. err := tx.Clauses(clause.OnConflict{
  647. Columns: []clause.Column{{Name: "feedback_type"}, {Name: "user_id"}, {Name: "item_id"}},
  648. DoNothing: !overwrite,
  649. DoUpdates: lo.If(overwrite, clause.AssignmentColumns([]string{"time_stamp", "comment"})).Else(nil),
  650. }).Create(rows).Error
  651. return errors.Trace(err)
  652. }
  653. // GetFeedback returns feedback from MySQL.
  654. func (d *SQLDatabase) GetFeedback(ctx context.Context, cursor string, n int, beginTime, endTime *time.Time, feedbackTypes ...string) (string, []Feedback, error) {
  655. buf, err := base64.StdEncoding.DecodeString(cursor)
  656. if err != nil {
  657. return "", nil, errors.Trace(err)
  658. }
  659. tx := d.gormDB.WithContext(ctx).Table(d.FeedbackTable()).Select("feedback_type, user_id, item_id, time_stamp, comment")
  660. if len(buf) > 0 {
  661. var cursorKey FeedbackKey
  662. if err := jsonutil.Unmarshal(buf, &cursorKey); err != nil {
  663. return "", nil, err
  664. }
  665. tx.Where("(feedback_type, user_id, item_id) >= (?,?,?)", cursorKey.FeedbackType, cursorKey.UserId, cursorKey.ItemId)
  666. }
  667. if len(feedbackTypes) > 0 {
  668. tx.Where("feedback_type IN ?", feedbackTypes)
  669. }
  670. if beginTime != nil {
  671. tx.Where("time_stamp >= ?", d.convertTimeZone(beginTime))
  672. }
  673. if endTime != nil {
  674. tx.Where("time_stamp <= ?", d.convertTimeZone(endTime))
  675. }
  676. tx.Order("feedback_type, user_id, item_id").Limit(n + 1)
  677. result, err := tx.Rows()
  678. if err != nil {
  679. return "", nil, errors.Trace(err)
  680. }
  681. feedbacks := make([]Feedback, 0)
  682. defer result.Close()
  683. for result.Next() {
  684. var feedback Feedback
  685. if err = d.gormDB.ScanRows(result, &feedback); err != nil {
  686. return "", nil, errors.Trace(err)
  687. }
  688. feedbacks = append(feedbacks, feedback)
  689. }
  690. if len(feedbacks) == n+1 {
  691. nextCursorKey := feedbacks[len(feedbacks)-1].FeedbackKey
  692. nextCursor, err := jsonutil.Marshal(nextCursorKey)
  693. if err != nil {
  694. return "", nil, errors.Trace(err)
  695. }
  696. return base64.StdEncoding.EncodeToString(nextCursor), feedbacks[:len(feedbacks)-1], nil
  697. }
  698. return "", feedbacks, nil
  699. }
  700. // GetFeedbackStream reads feedback by stream.
  701. func (d *SQLDatabase) GetFeedbackStream(ctx context.Context, batchSize int, scanOptions ...ScanOption) (chan []Feedback, chan error) {
  702. scan := NewScanOptions(scanOptions...)
  703. feedbackChan := make(chan []Feedback, bufSize)
  704. errChan := make(chan error, 1)
  705. go func() {
  706. defer close(feedbackChan)
  707. defer close(errChan)
  708. // send query
  709. tx := d.gormDB.WithContext(ctx).Table(d.FeedbackTable()).
  710. Select("feedback_type, user_id, item_id, time_stamp, comment").
  711. Order("feedback_type, user_id, item_id")
  712. if len(scan.FeedbackTypes) > 0 {
  713. tx.Where("feedback_type IN ?", scan.FeedbackTypes)
  714. }
  715. if scan.BeginTime != nil {
  716. tx.Where("time_stamp >= ?", d.convertTimeZone(scan.BeginTime))
  717. }
  718. if scan.EndTime != nil {
  719. tx.Where("time_stamp <= ?", d.convertTimeZone(scan.EndTime))
  720. }
  721. if scan.BeginUserId != nil {
  722. tx.Where("user_id >= ?", scan.BeginUserId)
  723. }
  724. if scan.EndUserId != nil {
  725. tx.Where("user_id <= ?", scan.EndUserId)
  726. }
  727. result, err := tx.Rows()
  728. if err != nil {
  729. errChan <- errors.Trace(err)
  730. return
  731. }
  732. // fetch result
  733. feedbacks := make([]Feedback, 0, batchSize)
  734. defer result.Close()
  735. for result.Next() {
  736. var feedback Feedback
  737. if err = d.gormDB.ScanRows(result, &feedback); err != nil {
  738. errChan <- errors.Trace(err)
  739. return
  740. }
  741. feedbacks = append(feedbacks, feedback)
  742. if len(feedbacks) == batchSize {
  743. feedbackChan <- feedbacks
  744. feedbacks = make([]Feedback, 0, batchSize)
  745. }
  746. }
  747. if len(feedbacks) > 0 {
  748. feedbackChan <- feedbacks
  749. }
  750. errChan <- nil
  751. }()
  752. return feedbackChan, errChan
  753. }
  754. // GetUserItemFeedback gets a feedback by user id and item id from MySQL.
  755. func (d *SQLDatabase) GetUserItemFeedback(ctx context.Context, userId, itemId string, feedbackTypes ...string) ([]Feedback, error) {
  756. tx := d.gormDB.WithContext(ctx).Table(d.FeedbackTable()).
  757. Select("feedback_type, user_id, item_id, time_stamp, comment").
  758. Where("user_id = ? AND item_id = ?", userId, itemId)
  759. if len(feedbackTypes) > 0 {
  760. tx.Where("feedback_type IN ?", feedbackTypes)
  761. }
  762. result, err := tx.Rows()
  763. if err != nil {
  764. return nil, errors.Trace(err)
  765. }
  766. feedbacks := make([]Feedback, 0)
  767. defer result.Close()
  768. for result.Next() {
  769. var feedback Feedback
  770. if err = d.gormDB.ScanRows(result, &feedback); err != nil {
  771. return nil, errors.Trace(err)
  772. }
  773. feedbacks = append(feedbacks, feedback)
  774. }
  775. return feedbacks, nil
  776. }
  777. // DeleteUserItemFeedback deletes a feedback by user id and item id from MySQL.
  778. func (d *SQLDatabase) DeleteUserItemFeedback(ctx context.Context, userId, itemId string, feedbackTypes ...string) (int, error) {
  779. tx := d.gormDB.WithContext(ctx).Where("user_id = ? AND item_id = ?", userId, itemId)
  780. if len(feedbackTypes) > 0 {
  781. tx.Where("feedback_type IN ?", feedbackTypes)
  782. }
  783. tx.Delete(&Feedback{})
  784. if tx.Error != nil {
  785. return 0, errors.Trace(tx.Error)
  786. }
  787. if tx.Error != nil {
  788. return 0, errors.Trace(tx.Error)
  789. }
  790. return int(tx.RowsAffected), nil
  791. }
  792. func (d *SQLDatabase) convertTimeZone(timestamp *time.Time) time.Time {
  793. switch d.driver {
  794. case SQLite:
  795. return timestamp.In(time.UTC)
  796. default:
  797. return *timestamp
  798. }
  799. }