model.go 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. // Copyright 2020 gorse Project Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package model
  15. import (
  16. "github.com/zhenghaoz/gorse/base"
  17. )
  18. // Model is the interface for all models. Any model in this
  19. // package should implement it.
  20. type Model interface {
  21. SetParams(params Params)
  22. GetParams() Params
  23. GetParamsGrid(withSize bool) ParamsGrid
  24. Clear()
  25. Invalid() bool
  26. }
  27. // BaseModel model must be included by every recommendation model. Hyper-parameters,
  28. // ID sets, random generator and fitting options are managed the BaseModel model.
  29. type BaseModel struct {
  30. Params Params // Hyper-parameters
  31. rng base.RandomGenerator // Random generator
  32. randState int64 // Random seed
  33. }
  34. // SetParams sets hyper-parameters for the BaseModel model.
  35. func (model *BaseModel) SetParams(params Params) {
  36. model.Params = params
  37. model.randState = model.Params.GetInt64(RandomState, 0)
  38. model.rng = base.NewRandomGenerator(model.randState)
  39. }
  40. // GetParams returns all hyper-parameters.
  41. func (model *BaseModel) GetParams() Params {
  42. return model.Params
  43. }
  44. func (model *BaseModel) GetRandomGenerator() base.RandomGenerator {
  45. return model.rng
  46. }