params.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. // Copyright 2020 gorse Project Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package model
  15. import (
  16. "encoding/json"
  17. "reflect"
  18. "github.com/zhenghaoz/gorse/base/log"
  19. "go.uber.org/zap"
  20. )
  21. /* ParamName */
  22. // ParamName is the type of hyper-parameter names.
  23. type ParamName string
  24. // Predefined hyper-parameter names
  25. const (
  26. Lr ParamName = "Lr" // learning rate
  27. Reg ParamName = "Reg" // regularization strength
  28. NEpochs ParamName = "NEpochs" // number of epochs
  29. NFactors ParamName = "NFactors" // number of factors
  30. RandomState ParamName = "RandomState" // random state (seed)
  31. InitMean ParamName = "InitMean" // mean of gaussian initial parameter
  32. InitStdDev ParamName = "InitStdDev" // standard deviation of gaussian initial parameter
  33. Alpha ParamName = "Alpha" // weight for negative samples in ALS
  34. Similarity ParamName = "Similarity"
  35. UseFeature ParamName = "UseFeature"
  36. BatchSize ParamName = "BatchSize"
  37. HiddenLayers ParamName = "HiddenLayers"
  38. Optimizer ParamName = "Optimizer"
  39. SGD = "sgd"
  40. Adam = "adam"
  41. )
  42. // Params stores hyper-parameters for an model. It is a map between strings
  43. // (names) and interface{}s (values). For example, hyper-parameters for SVD
  44. // is given by:
  45. //
  46. // base.Params{
  47. // base.Lr: 0.007,
  48. // base.NEpochs: 100,
  49. // base.NFactors: 80,
  50. // base.Reg: 0.1,
  51. // }
  52. type Params map[ParamName]interface{}
  53. // Copy hyper-parameters.
  54. func (parameters Params) Copy() Params {
  55. newParams := make(Params)
  56. for k, v := range parameters {
  57. newParams[k] = v
  58. }
  59. return newParams
  60. }
  61. // GetBool gets a boolean parameter by name. Returns _default if not exists or type doesn't match.
  62. func (parameters Params) GetBool(name ParamName, _default bool) bool {
  63. if val, exist := parameters[name]; exist {
  64. switch val := val.(type) {
  65. case bool:
  66. return val
  67. default:
  68. log.Logger().Error("type mismatch",
  69. zap.String("param_name", string(name)),
  70. zap.String("actual_type", reflect.TypeOf(name).Name()))
  71. }
  72. }
  73. return _default
  74. }
  75. // GetInt gets a integer parameter by name. Returns _default if not exists or type doesn't match.
  76. func (parameters Params) GetInt(name ParamName, _default int) int {
  77. if val, exist := parameters[name]; exist {
  78. switch val := val.(type) {
  79. case int:
  80. return val
  81. default:
  82. log.Logger().Error("type mismatch",
  83. zap.String("param_name", string(name)),
  84. zap.String("actual_type", reflect.TypeOf(name).Name()))
  85. }
  86. }
  87. return _default
  88. }
  89. // GetInt64 gets a int64 parameter by name. Returns _default if not exists or type doesn't match. The
  90. // type will be converted if given int.
  91. func (parameters Params) GetInt64(name ParamName, _default int64) int64 {
  92. if val, exist := parameters[name]; exist {
  93. switch val := val.(type) {
  94. case int64:
  95. return val
  96. case int:
  97. return int64(val)
  98. default:
  99. log.Logger().Error("type mismatch",
  100. zap.String("param_name", string(name)),
  101. zap.String("actual_type", reflect.TypeOf(name).Name()))
  102. }
  103. }
  104. return _default
  105. }
  106. func (parameters Params) GetFloat32(name ParamName, _default float32) float32 {
  107. if val, exist := parameters[name]; exist {
  108. switch val := val.(type) {
  109. case float32:
  110. return val
  111. case float64:
  112. return float32(val)
  113. case int:
  114. return float32(val)
  115. default:
  116. log.Logger().Error("type mismatch",
  117. zap.String("param_name", string(name)),
  118. zap.String("actual_type", reflect.TypeOf(name).Name()))
  119. }
  120. }
  121. return _default
  122. }
  123. // GetString gets a string parameter
  124. func (parameters Params) GetString(name ParamName, _default string) string {
  125. if val, exist := parameters[name]; exist {
  126. return val.(string)
  127. }
  128. return _default
  129. }
  130. func (parameters Params) GetIntSlice(name ParamName, _default []int) []int {
  131. if val, exist := parameters[name]; exist {
  132. switch val := val.(type) {
  133. case []int:
  134. return val
  135. default:
  136. log.Logger().Error("type mismatch",
  137. zap.String("param_name", string(name)),
  138. zap.String("actual_type", reflect.TypeOf(name).Name()))
  139. }
  140. }
  141. return _default
  142. }
  143. func (parameters Params) Overwrite(params Params) Params {
  144. merged := make(Params)
  145. for k, v := range parameters {
  146. merged[k] = v
  147. }
  148. for k, v := range params {
  149. merged[k] = v
  150. }
  151. return merged
  152. }
  153. func (parameters Params) ToString() string {
  154. b, err := json.Marshal(parameters)
  155. if err != nil {
  156. log.Logger().Fatal("failed to convert to string", zap.Error(err))
  157. }
  158. return string(b)
  159. }
// ParamsGrid contains candidates for grid search: each hyper-parameter name
// maps to the list of candidate values to try for it.
type ParamsGrid map[ParamName][]interface{}

// Len returns the number of hyper-parameters in the grid. This is not the
// number of combinations; see NumCombinations for that.
func (grid ParamsGrid) Len() int {
	return len(grid)
}
  165. func (grid ParamsGrid) NumCombinations() int {
  166. count := 1
  167. for _, values := range grid {
  168. count *= len(values)
  169. }
  170. return count
  171. }
  172. func (grid ParamsGrid) Fill(_default ParamsGrid) {
  173. for param, values := range _default {
  174. if _, exist := grid[param]; !exist {
  175. grid[param] = values
  176. }
  177. }
  178. }