database.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. // Copyright 2021 gorse Project Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package data
  15. import (
  16. "context"
  17. "encoding/json"
  18. "reflect"
  19. "sort"
  20. "strings"
  21. "time"
  22. "github.com/XSAM/otelsql"
  23. "github.com/juju/errors"
  24. "github.com/samber/lo"
  25. "github.com/zhenghaoz/gorse/base/jsonutil"
  26. "github.com/zhenghaoz/gorse/base/log"
  27. "github.com/zhenghaoz/gorse/storage"
  28. "go.mongodb.org/mongo-driver/mongo"
  29. "go.mongodb.org/mongo-driver/mongo/options"
  30. "go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
  31. "go.opentelemetry.io/contrib/instrumentation/go.mongodb.org/mongo-driver/mongo/otelmongo"
  32. semconv "go.opentelemetry.io/otel/semconv/v1.12.0"
  33. "gorm.io/driver/mysql"
  34. "gorm.io/driver/postgres"
  35. "gorm.io/driver/sqlite"
  36. "gorm.io/gorm"
  37. "gorm.io/gorm/logger"
  38. "moul.io/zapgorm2"
  39. )
// Connection pool settings. Within this file they are applied to the
// PostgreSQL client created by Open.
const (
	// maxIdleConns is the value passed to SetMaxIdleConns.
	maxIdleConns = 64
	// maxOpenConns is the value passed to SetMaxOpenConns.
	maxOpenConns = 64
	// maxLifetime is the value passed to SetConnMaxLifetime.
	maxLifetime = time.Minute
)
// Sentinel errors exported by this package.
// NOTE(review): the conditions under which the Database implementations
// return them are not visible in this file — confirm against the backends.
var (
	// ErrUserNotExist is a not-found error for "user".
	ErrUserNotExist = errors.NotFoundf("user")
	// ErrItemNotExist is a not-found error for "item".
	ErrItemNotExist = errors.NotFoundf("item")
	// ErrNoDatabase is a not-assigned error for "database".
	ErrNoDatabase = errors.NotAssignedf("database")
)
  50. // ValidateLabels checks if labels are valid. Labels are valid if consists of:
  51. // - []string slice of strings
  52. // - []float64 slice of numbers
  53. // - map[string]any map of strings to valid labels or float64
  54. func ValidateLabels(o any) error {
  55. if o == nil {
  56. return nil
  57. }
  58. switch labels := o.(type) {
  59. case []any: // must be []string or []float64
  60. if len(labels) == 0 {
  61. return nil
  62. }
  63. switch labels[0].(type) {
  64. case string:
  65. for _, val := range labels {
  66. if _, ok := val.(string); !ok {
  67. return errors.Errorf("unsupported labels: %v", jsonutil.MustMarshal(labels))
  68. }
  69. }
  70. case json.Number:
  71. for _, val := range labels {
  72. if _, ok := val.(json.Number); !ok {
  73. return errors.Errorf("unsupported labels: %v", jsonutil.MustMarshal(labels))
  74. }
  75. }
  76. default:
  77. return errors.Errorf("unsupported labels: %v", jsonutil.MustMarshal(labels))
  78. }
  79. return nil
  80. case map[string]any:
  81. for _, val := range labels {
  82. if err := ValidateLabels(val); err != nil {
  83. return err
  84. }
  85. }
  86. return nil
  87. case string, json.Number:
  88. return nil
  89. default:
  90. return errors.Errorf("unsupported type in labels: %v", reflect.TypeOf(labels))
  91. }
  92. }
  93. // Item stores meta data about item.
  94. type Item struct {
  95. ItemId string `gorm:"primaryKey" mapstructure:"item_id"`
  96. IsHidden bool `mapstructure:"is_hidden"`
  97. Categories []string `gorm:"serializer:json" mapstructure:"categories"`
  98. Timestamp time.Time `gorm:"column:time_stamp" mapstructure:"timestamp"`
  99. Labels any `gorm:"serializer:json" mapstructure:"labels"`
  100. Comment string `mapsstructure:"comment"`
  101. }
// ItemPatch is the modification on an item.
//
// NOTE(review): the pointer fields presumably mean "leave unchanged" when
// nil; the code applying a patch is not in this file — confirm against the
// ModifyItem implementations.
type ItemPatch struct {
	// IsHidden is the new hidden flag.
	IsHidden *bool
	// Categories is the new category list.
	Categories []string
	// Timestamp is the new timestamp.
	Timestamp *time.Time
	// Labels is the new labels value.
	Labels any
	// Comment is the new comment.
	Comment *string
}
// User stores meta data about user.
//
// The gorm tags describe how the struct is persisted and the mapstructure
// tags give the key each field is decoded from.
type User struct {
	// UserId identifies the user; it is the primary key.
	UserId string `gorm:"primaryKey" mapstructure:"user_id"`
	// Labels holds free-form labels; serialized as JSON by gorm.
	Labels any `gorm:"serializer:json" mapstructure:"labels"`
	// Subscribe is a list of strings serialized as JSON by gorm.
	// NOTE(review): presumably the categories the user subscribes to — confirm.
	Subscribe []string `gorm:"serializer:json" mapstructure:"subscribe"`
	// Comment is a free-text note.
	Comment string `mapstructure:"comment"`
}
// UserPatch is the modification on a user.
//
// NOTE(review): nil fields presumably mean "leave unchanged"; the code
// applying a patch is not in this file — confirm against the ModifyUser
// implementations.
type UserPatch struct {
	// Labels is the new labels value.
	Labels any
	// Subscribe is the new subscription list.
	Subscribe []string
	// Comment is the new comment.
	Comment *string
}
  123. // FeedbackKey identifies feedback.
  124. type FeedbackKey struct {
  125. FeedbackType string `gorm:"column:feedback_type" mapstructure:"feedback_type"`
  126. UserId string `gorm:"column:user_id" mapstructure:"user_id"`
  127. ItemId string `gorm:"column:item_id" mapstructure:"item_id"`
  128. }
  129. // Feedback stores feedback.
  130. type Feedback struct {
  131. FeedbackKey `gorm:"embedded" mapstructure:",squash"`
  132. Timestamp time.Time `gorm:"column:time_stamp" mapsstructure:"timestamp"`
  133. Comment string `gorm:"column:comment" mapsstructure:"comment"`
  134. }
  135. // SortFeedbacks sorts feedback from latest to oldest.
  136. func SortFeedbacks(feedback []Feedback) {
  137. sort.Sort(feedbackSorter(feedback))
  138. }
  139. type feedbackSorter []Feedback
  140. func (sorter feedbackSorter) Len() int {
  141. return len(sorter)
  142. }
  143. func (sorter feedbackSorter) Less(i, j int) bool {
  144. return sorter[i].Timestamp.After(sorter[j].Timestamp)
  145. }
  146. func (sorter feedbackSorter) Swap(i, j int) {
  147. sorter[i], sorter[j] = sorter[j], sorter[i]
  148. }
// ScanOptions collects the optional filters configured through ScanOption
// functions. A nil pointer field means that no option has set it.
type ScanOptions struct {
	// BeginUserId is set by WithBeginUserId.
	BeginUserId *string
	// EndUserId is set by WithEndUserId.
	EndUserId *string
	// BeginTime is set by WithBeginTime.
	BeginTime *time.Time
	// EndTime is set by WithEndTime.
	EndTime *time.Time
	// FeedbackTypes is set by WithFeedbackTypes.
	FeedbackTypes []string
}

// ScanOption is a function that modifies a ScanOptions value in place.
type ScanOption func(options *ScanOptions)
  157. // WithBeginUserId sets the begin user id. The begin user id is included in the result.
  158. func WithBeginUserId(userId string) ScanOption {
  159. return func(options *ScanOptions) {
  160. options.BeginUserId = &userId
  161. }
  162. }
  163. // WithEndUserId sets the end user id. The end user id is included in the result.
  164. func WithEndUserId(userId string) ScanOption {
  165. return func(options *ScanOptions) {
  166. options.EndUserId = &userId
  167. }
  168. }
  169. // WithBeginTime sets the begin time. The begin time is included in the result.
  170. func WithBeginTime(t time.Time) ScanOption {
  171. return func(options *ScanOptions) {
  172. options.BeginTime = &t
  173. }
  174. }
  175. // WithEndTime sets the end time. The end time is included in the result.
  176. func WithEndTime(t time.Time) ScanOption {
  177. return func(options *ScanOptions) {
  178. options.EndTime = &t
  179. }
  180. }
  181. // WithFeedbackTypes sets the feedback types.
  182. func WithFeedbackTypes(feedbackTypes ...string) ScanOption {
  183. return func(options *ScanOptions) {
  184. options.FeedbackTypes = feedbackTypes
  185. }
  186. }
  187. func NewScanOptions(opts ...ScanOption) ScanOptions {
  188. options := ScanOptions{}
  189. for _, opt := range opts {
  190. if opt != nil {
  191. opt(&options)
  192. }
  193. }
  194. return options
  195. }
// Database is the storage interface for items, users and feedback. Open
// returns an implementation chosen from the connection string prefix.
//
// NOTE(review): the comments below are derived from method names and
// signatures only; exact semantics (cursor format, error values, ordering)
// live in the implementations, which are not in this file.
type Database interface {
	// Lifecycle and maintenance.
	Init() error
	Ping() error
	Close() error
	Purge() error

	// Item operations. GetItems returns a string alongside the items,
	// presumably the cursor for the next call — confirm.
	BatchInsertItems(ctx context.Context, items []Item) error
	BatchGetItems(ctx context.Context, itemIds []string) ([]Item, error)
	DeleteItem(ctx context.Context, itemId string) error
	GetItem(ctx context.Context, itemId string) (Item, error)
	ModifyItem(ctx context.Context, itemId string, patch ItemPatch) error
	GetItems(ctx context.Context, cursor string, n int, beginTime *time.Time) (string, []Item, error)
	GetItemFeedback(ctx context.Context, itemId string, feedbackTypes ...string) ([]Feedback, error)

	// User operations. GetUsers returns a string alongside the users,
	// presumably the cursor for the next call — confirm.
	BatchInsertUsers(ctx context.Context, users []User) error
	DeleteUser(ctx context.Context, userId string) error
	GetUser(ctx context.Context, userId string) (User, error)
	ModifyUser(ctx context.Context, userId string, patch UserPatch) error
	GetUsers(ctx context.Context, cursor string, n int) (string, []User, error)
	GetUserFeedback(ctx context.Context, userId string, endTime *time.Time, feedbackTypes ...string) ([]Feedback, error)

	// Feedback operations. DeleteUserItemFeedback returns an int alongside
	// the error, presumably the number of deleted rows — confirm.
	GetUserItemFeedback(ctx context.Context, userId, itemId string, feedbackTypes ...string) ([]Feedback, error)
	DeleteUserItemFeedback(ctx context.Context, userId, itemId string, feedbackTypes ...string) (int, error)
	BatchInsertFeedback(ctx context.Context, feedback []Feedback, insertUser, insertItem, overwrite bool) error
	GetFeedback(ctx context.Context, cursor string, n int, beginTime, endTime *time.Time, feedbackTypes ...string) (string, []Feedback, error)

	// Streaming readers. Each returns a channel of batches and a channel of
	// errors.
	GetUserStream(ctx context.Context, batchSize int) (chan []User, chan error)
	GetItemStream(ctx context.Context, batchSize int, timeLimit *time.Time) (chan []Item, chan error)
	GetFeedbackStream(ctx context.Context, batchSize int, options ...ScanOption) (chan []Feedback, chan error)
}
  222. // Open a connection to a database.
  223. func Open(path, tablePrefix string) (Database, error) {
  224. var err error
  225. if strings.HasPrefix(path, storage.MySQLPrefix) {
  226. name := path[len(storage.MySQLPrefix):]
  227. // probe isolation variable name
  228. isolationVarName, err := storage.ProbeMySQLIsolationVariableName(name)
  229. if err != nil {
  230. return nil, errors.Trace(err)
  231. }
  232. // append parameters
  233. if name, err = storage.AppendMySQLParams(name, map[string]string{
  234. "sql_mode": "'ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION'",
  235. isolationVarName: "'READ-UNCOMMITTED'",
  236. "parseTime": "true",
  237. }); err != nil {
  238. return nil, errors.Trace(err)
  239. }
  240. // connect to database
  241. database := new(SQLDatabase)
  242. database.driver = MySQL
  243. database.TablePrefix = storage.TablePrefix(tablePrefix)
  244. if database.client, err = otelsql.Open("mysql", name,
  245. otelsql.WithAttributes(semconv.DBSystemMySQL),
  246. otelsql.WithSpanOptions(otelsql.SpanOptions{DisableErrSkip: true}),
  247. ); err != nil {
  248. return nil, errors.Trace(err)
  249. }
  250. database.gormDB, err = gorm.Open(mysql.New(mysql.Config{Conn: database.client}), storage.NewGORMConfig(tablePrefix))
  251. if err != nil {
  252. return nil, errors.Trace(err)
  253. }
  254. return database, nil
  255. } else if strings.HasPrefix(path, storage.PostgresPrefix) || strings.HasPrefix(path, storage.PostgreSQLPrefix) {
  256. database := new(SQLDatabase)
  257. database.driver = Postgres
  258. database.TablePrefix = storage.TablePrefix(tablePrefix)
  259. if database.client, err = otelsql.Open("postgres", path,
  260. otelsql.WithAttributes(semconv.DBSystemPostgreSQL),
  261. otelsql.WithSpanOptions(otelsql.SpanOptions{DisableErrSkip: true}),
  262. ); err != nil {
  263. return nil, errors.Trace(err)
  264. }
  265. database.client.SetMaxIdleConns(maxIdleConns)
  266. database.client.SetMaxOpenConns(maxOpenConns)
  267. database.client.SetConnMaxLifetime(maxLifetime)
  268. database.gormDB, err = gorm.Open(postgres.New(postgres.Config{Conn: database.client}), storage.NewGORMConfig(tablePrefix))
  269. if err != nil {
  270. return nil, errors.Trace(err)
  271. }
  272. return database, nil
  273. } else if strings.HasPrefix(path, storage.MongoPrefix) || strings.HasPrefix(path, storage.MongoSrvPrefix) {
  274. // connect to database
  275. database := new(MongoDB)
  276. opts := options.Client()
  277. opts.Monitor = otelmongo.NewMonitor()
  278. opts.ApplyURI(path)
  279. if database.client, err = mongo.Connect(context.Background(), opts); err != nil {
  280. return nil, errors.Trace(err)
  281. }
  282. // parse DSN and extract database name
  283. if cs, err := connstring.ParseAndValidate(path); err != nil {
  284. return nil, errors.Trace(err)
  285. } else {
  286. database.dbName = cs.Database
  287. database.TablePrefix = storage.TablePrefix(tablePrefix)
  288. }
  289. return database, nil
  290. } else if strings.HasPrefix(path, storage.SQLitePrefix) {
  291. dataSourceName := path[len(storage.SQLitePrefix):]
  292. // append parameters
  293. if dataSourceName, err = storage.AppendURLParams(dataSourceName, []lo.Tuple2[string, string]{
  294. {"_pragma", "busy_timeout(10000)"},
  295. {"_pragma", "journal_mode(wal)"},
  296. }); err != nil {
  297. return nil, errors.Trace(err)
  298. }
  299. // connect to database
  300. database := new(SQLDatabase)
  301. database.driver = SQLite
  302. database.TablePrefix = storage.TablePrefix(tablePrefix)
  303. if database.client, err = otelsql.Open("sqlite", dataSourceName,
  304. otelsql.WithAttributes(semconv.DBSystemSqlite),
  305. otelsql.WithSpanOptions(otelsql.SpanOptions{DisableErrSkip: true}),
  306. ); err != nil {
  307. return nil, errors.Trace(err)
  308. }
  309. gormConfig := storage.NewGORMConfig(tablePrefix)
  310. gormConfig.Logger = &zapgorm2.Logger{
  311. ZapLogger: log.Logger(),
  312. LogLevel: logger.Warn,
  313. SlowThreshold: 10 * time.Second,
  314. SkipCallerLookup: false,
  315. IgnoreRecordNotFoundError: false,
  316. }
  317. database.gormDB, err = gorm.Open(sqlite.Dialector{Conn: database.client}, gormConfig)
  318. if err != nil {
  319. return nil, errors.Trace(err)
  320. }
  321. return database, nil
  322. }
  323. return nil, errors.Errorf("Unknown database: %s", path)
  324. }