123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351 |
- // Copyright 2021 gorse Project Authors
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package data
- import (
- "context"
- "encoding/json"
- "reflect"
- "sort"
- "strings"
- "time"
- "github.com/XSAM/otelsql"
- "github.com/juju/errors"
- "github.com/samber/lo"
- "github.com/zhenghaoz/gorse/base/jsonutil"
- "github.com/zhenghaoz/gorse/base/log"
- "github.com/zhenghaoz/gorse/storage"
- "go.mongodb.org/mongo-driver/mongo"
- "go.mongodb.org/mongo-driver/mongo/options"
- "go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
- "go.opentelemetry.io/contrib/instrumentation/go.mongodb.org/mongo-driver/mongo/otelmongo"
- semconv "go.opentelemetry.io/otel/semconv/v1.12.0"
- "gorm.io/driver/mysql"
- "gorm.io/driver/postgres"
- "gorm.io/driver/sqlite"
- "gorm.io/gorm"
- "gorm.io/gorm/logger"
- "moul.io/zapgorm2"
- )
// Connection pool limits. In this file Open applies them to the PostgreSQL
// client: the maximum number of idle connections, the maximum number of open
// connections and the maximum lifetime of a single connection.
const (
	maxIdleConns = 64
	maxOpenConns = 64
	maxLifetime  = time.Minute
)
- var (
- ErrUserNotExist = errors.NotFoundf("user")
- ErrItemNotExist = errors.NotFoundf("item")
- ErrNoDatabase = errors.NotAssignedf("database")
- )
- // ValidateLabels checks if labels are valid. Labels are valid if consists of:
- // - []string slice of strings
- // - []float64 slice of numbers
- // - map[string]any map of strings to valid labels or float64
- func ValidateLabels(o any) error {
- if o == nil {
- return nil
- }
- switch labels := o.(type) {
- case []any: // must be []string or []float64
- if len(labels) == 0 {
- return nil
- }
- switch labels[0].(type) {
- case string:
- for _, val := range labels {
- if _, ok := val.(string); !ok {
- return errors.Errorf("unsupported labels: %v", jsonutil.MustMarshal(labels))
- }
- }
- case json.Number:
- for _, val := range labels {
- if _, ok := val.(json.Number); !ok {
- return errors.Errorf("unsupported labels: %v", jsonutil.MustMarshal(labels))
- }
- }
- default:
- return errors.Errorf("unsupported labels: %v", jsonutil.MustMarshal(labels))
- }
- return nil
- case map[string]any:
- for _, val := range labels {
- if err := ValidateLabels(val); err != nil {
- return err
- }
- }
- return nil
- case string, json.Number:
- return nil
- default:
- return errors.Errorf("unsupported type in labels: %v", reflect.TypeOf(labels))
- }
- }
// Item stores meta data about item.
//
// Fields:
//   - ItemId: identifier of the item; the gorm primary key.
//   - IsHidden: hidden flag of the item.
//   - Categories: category names, persisted through gorm's JSON serializer.
//   - Timestamp: time of the item, stored in the time_stamp column.
//   - Labels: free-form labels, persisted through gorm's JSON serializer
//     (presumably a value accepted by ValidateLabels; confirm with callers).
//   - Comment: free-text comment.
//
// Every field carries a mapstructure tag naming its snake_case key.
type Item struct {
	ItemId     string    `gorm:"primaryKey" mapstructure:"item_id"`
	IsHidden   bool      `mapstructure:"is_hidden"`
	Categories []string  `gorm:"serializer:json" mapstructure:"categories"`
	Timestamp  time.Time `gorm:"column:time_stamp" mapstructure:"timestamp"`
	Labels     any       `gorm:"serializer:json" mapstructure:"labels"`
	Comment    string    `mapstructure:"comment"`
}
// ItemPatch describes a modification to apply to an item.
//
// Each field mirrors a field of Item. Scalar fields are pointers and the
// remaining fields are nil-able, presumably so that an unset field means
// "leave unchanged" (confirm against the ModifyItem implementations).
type ItemPatch struct {
	IsHidden   *bool
	Categories []string
	Timestamp  *time.Time
	Labels     any
	Comment    *string
}
// User stores meta data about user.
//
// Fields:
//   - UserId: identifier of the user; the gorm primary key.
//   - Labels: free-form labels, persisted through gorm's JSON serializer.
//   - Subscribe: list of strings, persisted through gorm's JSON serializer.
//   - Comment: free-text comment.
//
// Every field carries a mapstructure tag naming its snake_case key.
type User struct {
	UserId    string   `gorm:"primaryKey" mapstructure:"user_id"`
	Labels    any      `gorm:"serializer:json" mapstructure:"labels"`
	Subscribe []string `gorm:"serializer:json" mapstructure:"subscribe"`
	Comment   string   `mapstructure:"comment"`
}
// UserPatch describes a modification to apply to a user.
//
// Each field mirrors a field of User and can be left nil, presumably meaning
// "leave unchanged" (confirm against the ModifyUser implementations).
type UserPatch struct {
	Labels    any
	Subscribe []string
	Comment   *string
}
// FeedbackKey identifies a feedback record by its feedback type, the user who
// gave it and the item it refers to. Each field is bound to an explicitly
// named column (feedback_type, user_id, item_id) and to the mapstructure key
// of the same name.
type FeedbackKey struct {
	FeedbackType string `gorm:"column:feedback_type" mapstructure:"feedback_type"`
	UserId       string `gorm:"column:user_id" mapstructure:"user_id"`
	ItemId       string `gorm:"column:item_id" mapstructure:"item_id"`
}
- // Feedback stores feedback.
- type Feedback struct {
- FeedbackKey `gorm:"embedded" mapstructure:",squash"`
- Timestamp time.Time `gorm:"column:time_stamp" mapsstructure:"timestamp"`
- Comment string `gorm:"column:comment" mapsstructure:"comment"`
- }
- // SortFeedbacks sorts feedback from latest to oldest.
- func SortFeedbacks(feedback []Feedback) {
- sort.Sort(feedbackSorter(feedback))
- }
- type feedbackSorter []Feedback
- func (sorter feedbackSorter) Len() int {
- return len(sorter)
- }
- func (sorter feedbackSorter) Less(i, j int) bool {
- return sorter[i].Timestamp.After(sorter[j].Timestamp)
- }
- func (sorter feedbackSorter) Swap(i, j int) {
- sorter[i], sorter[j] = sorter[j], sorter[i]
- }
// ScanOptions collects the optional filters of a scan. A nil pointer or a nil
// slice means the corresponding filter was not set.
//
// Fields:
//   - BeginUserId / EndUserId: user id range.
//   - BeginTime / EndTime: time range.
//   - FeedbackTypes: feedback types to select.
type ScanOptions struct {
	BeginUserId   *string
	EndUserId     *string
	BeginTime     *time.Time
	EndTime       *time.Time
	FeedbackTypes []string
}

// ScanOption is a functional option that mutates a ScanOptions value.
type ScanOption func(options *ScanOptions)
- // WithBeginUserId sets the begin user id. The begin user id is included in the result.
- func WithBeginUserId(userId string) ScanOption {
- return func(options *ScanOptions) {
- options.BeginUserId = &userId
- }
- }
- // WithEndUserId sets the end user id. The end user id is included in the result.
- func WithEndUserId(userId string) ScanOption {
- return func(options *ScanOptions) {
- options.EndUserId = &userId
- }
- }
- // WithBeginTime sets the begin time. The begin time is included in the result.
- func WithBeginTime(t time.Time) ScanOption {
- return func(options *ScanOptions) {
- options.BeginTime = &t
- }
- }
- // WithEndTime sets the end time. The end time is included in the result.
- func WithEndTime(t time.Time) ScanOption {
- return func(options *ScanOptions) {
- options.EndTime = &t
- }
- }
- // WithFeedbackTypes sets the feedback types.
- func WithFeedbackTypes(feedbackTypes ...string) ScanOption {
- return func(options *ScanOptions) {
- options.FeedbackTypes = feedbackTypes
- }
- }
- func NewScanOptions(opts ...ScanOption) ScanOptions {
- options := ScanOptions{}
- for _, opt := range opts {
- if opt != nil {
- opt(&options)
- }
- }
- return options
- }
- type Database interface {
- Init() error
- Ping() error
- Close() error
- Purge() error
- BatchInsertItems(ctx context.Context, items []Item) error
- BatchGetItems(ctx context.Context, itemIds []string) ([]Item, error)
- DeleteItem(ctx context.Context, itemId string) error
- GetItem(ctx context.Context, itemId string) (Item, error)
- ModifyItem(ctx context.Context, itemId string, patch ItemPatch) error
- GetItems(ctx context.Context, cursor string, n int, beginTime *time.Time) (string, []Item, error)
- GetItemFeedback(ctx context.Context, itemId string, feedbackTypes ...string) ([]Feedback, error)
- BatchInsertUsers(ctx context.Context, users []User) error
- DeleteUser(ctx context.Context, userId string) error
- GetUser(ctx context.Context, userId string) (User, error)
- ModifyUser(ctx context.Context, userId string, patch UserPatch) error
- GetUsers(ctx context.Context, cursor string, n int) (string, []User, error)
- GetUserFeedback(ctx context.Context, userId string, endTime *time.Time, feedbackTypes ...string) ([]Feedback, error)
- GetUserItemFeedback(ctx context.Context, userId, itemId string, feedbackTypes ...string) ([]Feedback, error)
- DeleteUserItemFeedback(ctx context.Context, userId, itemId string, feedbackTypes ...string) (int, error)
- BatchInsertFeedback(ctx context.Context, feedback []Feedback, insertUser, insertItem, overwrite bool) error
- GetFeedback(ctx context.Context, cursor string, n int, beginTime, endTime *time.Time, feedbackTypes ...string) (string, []Feedback, error)
- GetUserStream(ctx context.Context, batchSize int) (chan []User, chan error)
- GetItemStream(ctx context.Context, batchSize int, timeLimit *time.Time) (chan []Item, chan error)
- GetFeedbackStream(ctx context.Context, batchSize int, options ...ScanOption) (chan []Feedback, chan error)
- }
- // Open a connection to a database.
- func Open(path, tablePrefix string) (Database, error) {
- var err error
- if strings.HasPrefix(path, storage.MySQLPrefix) {
- name := path[len(storage.MySQLPrefix):]
- // probe isolation variable name
- isolationVarName, err := storage.ProbeMySQLIsolationVariableName(name)
- if err != nil {
- return nil, errors.Trace(err)
- }
- // append parameters
- if name, err = storage.AppendMySQLParams(name, map[string]string{
- "sql_mode": "'ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION'",
- isolationVarName: "'READ-UNCOMMITTED'",
- "parseTime": "true",
- }); err != nil {
- return nil, errors.Trace(err)
- }
- // connect to database
- database := new(SQLDatabase)
- database.driver = MySQL
- database.TablePrefix = storage.TablePrefix(tablePrefix)
- if database.client, err = otelsql.Open("mysql", name,
- otelsql.WithAttributes(semconv.DBSystemMySQL),
- otelsql.WithSpanOptions(otelsql.SpanOptions{DisableErrSkip: true}),
- ); err != nil {
- return nil, errors.Trace(err)
- }
- database.gormDB, err = gorm.Open(mysql.New(mysql.Config{Conn: database.client}), storage.NewGORMConfig(tablePrefix))
- if err != nil {
- return nil, errors.Trace(err)
- }
- return database, nil
- } else if strings.HasPrefix(path, storage.PostgresPrefix) || strings.HasPrefix(path, storage.PostgreSQLPrefix) {
- database := new(SQLDatabase)
- database.driver = Postgres
- database.TablePrefix = storage.TablePrefix(tablePrefix)
- if database.client, err = otelsql.Open("postgres", path,
- otelsql.WithAttributes(semconv.DBSystemPostgreSQL),
- otelsql.WithSpanOptions(otelsql.SpanOptions{DisableErrSkip: true}),
- ); err != nil {
- return nil, errors.Trace(err)
- }
- database.client.SetMaxIdleConns(maxIdleConns)
- database.client.SetMaxOpenConns(maxOpenConns)
- database.client.SetConnMaxLifetime(maxLifetime)
- database.gormDB, err = gorm.Open(postgres.New(postgres.Config{Conn: database.client}), storage.NewGORMConfig(tablePrefix))
- if err != nil {
- return nil, errors.Trace(err)
- }
- return database, nil
- } else if strings.HasPrefix(path, storage.MongoPrefix) || strings.HasPrefix(path, storage.MongoSrvPrefix) {
- // connect to database
- database := new(MongoDB)
- opts := options.Client()
- opts.Monitor = otelmongo.NewMonitor()
- opts.ApplyURI(path)
- if database.client, err = mongo.Connect(context.Background(), opts); err != nil {
- return nil, errors.Trace(err)
- }
- // parse DSN and extract database name
- if cs, err := connstring.ParseAndValidate(path); err != nil {
- return nil, errors.Trace(err)
- } else {
- database.dbName = cs.Database
- database.TablePrefix = storage.TablePrefix(tablePrefix)
- }
- return database, nil
- } else if strings.HasPrefix(path, storage.SQLitePrefix) {
- dataSourceName := path[len(storage.SQLitePrefix):]
- // append parameters
- if dataSourceName, err = storage.AppendURLParams(dataSourceName, []lo.Tuple2[string, string]{
- {"_pragma", "busy_timeout(10000)"},
- {"_pragma", "journal_mode(wal)"},
- }); err != nil {
- return nil, errors.Trace(err)
- }
- // connect to database
- database := new(SQLDatabase)
- database.driver = SQLite
- database.TablePrefix = storage.TablePrefix(tablePrefix)
- if database.client, err = otelsql.Open("sqlite", dataSourceName,
- otelsql.WithAttributes(semconv.DBSystemSqlite),
- otelsql.WithSpanOptions(otelsql.SpanOptions{DisableErrSkip: true}),
- ); err != nil {
- return nil, errors.Trace(err)
- }
- gormConfig := storage.NewGORMConfig(tablePrefix)
- gormConfig.Logger = &zapgorm2.Logger{
- ZapLogger: log.Logger(),
- LogLevel: logger.Warn,
- SlowThreshold: 10 * time.Second,
- SkipCallerLookup: false,
- IgnoreRecordNotFoundError: false,
- }
- database.gormDB, err = gorm.Open(sqlite.Dialector{Conn: database.client}, gormConfig)
- if err != nil {
- return nil, errors.Trace(err)
- }
- return database, nil
- }
- return nil, errors.Errorf("Unknown database: %s", path)
- }
|