random.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. // Copyright 2020 gorse Project Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package base
  15. import (
  16. "math/rand"
  17. "sync"
  18. mapset "github.com/deckarep/golang-set/v2"
  19. )
  20. // RandomGenerator is the random generator for gorse.
  21. type RandomGenerator struct {
  22. *rand.Rand
  23. }
  24. // NewRandomGenerator creates a RandomGenerator.
  25. func NewRandomGenerator(seed int64) RandomGenerator {
  26. return RandomGenerator{rand.New(rand.NewSource(int64(seed)))}
  27. }
  28. // UniformVector makes a vec filled with uniform random floats,
  29. func (rng RandomGenerator) UniformVector(size int, low, high float32) []float32 {
  30. ret := make([]float32, size)
  31. scale := high - low
  32. for i := 0; i < len(ret); i++ {
  33. ret[i] = rng.Float32()*scale + low
  34. }
  35. return ret
  36. }
  37. // NewNormalVector makes a vec filled with normal random floats.
  38. func (rng RandomGenerator) NewNormalVector(size int, mean, stdDev float32) []float32 {
  39. ret := make([]float32, size)
  40. for i := 0; i < len(ret); i++ {
  41. ret[i] = float32(rng.NormFloat64())*stdDev + mean
  42. }
  43. return ret
  44. }
  45. // NormalMatrix makes a matrix filled with normal random floats.
  46. func (rng RandomGenerator) NormalMatrix(row, col int, mean, stdDev float32) [][]float32 {
  47. ret := make([][]float32, row)
  48. for i := range ret {
  49. ret[i] = rng.NewNormalVector(col, mean, stdDev)
  50. }
  51. return ret
  52. }
  53. func (rng RandomGenerator) NormalVector(size int, mean, stdDev float32) []float32 {
  54. ret := make([]float32, size)
  55. for i := 0; i < len(ret); i++ {
  56. ret[i] = float32(rng.NormFloat64())*stdDev + mean
  57. }
  58. return ret
  59. }
  60. // UniformMatrix makes a matrix filled with uniform random floats.
  61. func (rng RandomGenerator) UniformMatrix(row, col int, low, high float32) [][]float32 {
  62. ret := make([][]float32, row)
  63. for i := range ret {
  64. ret[i] = rng.UniformVector(col, low, high)
  65. }
  66. return ret
  67. }
  68. // NormalVector64 makes a vec filled with normal random floats.
  69. func (rng RandomGenerator) NormalVector64(size int, mean, stdDev float64) []float64 {
  70. ret := make([]float64, size)
  71. for i := 0; i < len(ret); i++ {
  72. ret[i] = rng.NormFloat64()*stdDev + mean
  73. }
  74. return ret
  75. }
  76. // Sample n values between low and high, but not in exclude.
  77. func (rng RandomGenerator) Sample(low, high, n int, exclude ...mapset.Set[int]) []int {
  78. intervalLength := high - low
  79. excludeSet := mapset.NewSet[int]()
  80. for _, set := range exclude {
  81. excludeSet = excludeSet.Union(set)
  82. }
  83. sampled := make([]int, 0, n)
  84. if n >= intervalLength-excludeSet.Cardinality() {
  85. for i := low; i < high; i++ {
  86. if !excludeSet.Contains(i) {
  87. sampled = append(sampled, i)
  88. excludeSet.Add(i)
  89. }
  90. }
  91. } else {
  92. for len(sampled) < n {
  93. v := rng.Intn(intervalLength) + low
  94. if !excludeSet.Contains(v) {
  95. sampled = append(sampled, v)
  96. excludeSet.Add(v)
  97. }
  98. }
  99. }
  100. return sampled
  101. }
  102. // SampleInt32 n 32bit values between low and high, but not in exclude.
  103. func (rng RandomGenerator) SampleInt32(low, high int32, n int, exclude ...mapset.Set[int32]) []int32 {
  104. intervalLength := high - low
  105. excludeSet := mapset.NewSet[int32]()
  106. for _, set := range exclude {
  107. excludeSet = excludeSet.Union(set)
  108. }
  109. sampled := make([]int32, 0, n)
  110. if n >= int(intervalLength)-excludeSet.Cardinality() {
  111. for i := low; i < high; i++ {
  112. if !excludeSet.Contains(i) {
  113. sampled = append(sampled, i)
  114. excludeSet.Add(i)
  115. }
  116. }
  117. } else {
  118. for len(sampled) < n {
  119. v := rng.Int31n(intervalLength) + low
  120. if !excludeSet.Contains(v) {
  121. sampled = append(sampled, v)
  122. excludeSet.Add(v)
  123. }
  124. }
  125. }
  126. return sampled
  127. }
  128. // lockedSource allows a random number generator to be used by multiple goroutines concurrently.
  129. // The code is very similar to math/rand.lockedSource, which is unfortunately not exposed.
  130. type lockedSource struct {
  131. mut sync.Mutex
  132. src rand.Source
  133. }
  134. // NewRand returns a rand.Rand that is threadsafe.
  135. func NewRand(seed int64) *rand.Rand {
  136. return rand.New(&lockedSource{src: rand.NewSource(seed)})
  137. }
  138. func (r *lockedSource) Int63() (n int64) {
  139. r.mut.Lock()
  140. n = r.src.Int63()
  141. r.mut.Unlock()
  142. return
  143. }
  144. func (r *lockedSource) Seed(seed int64) {
  145. r.mut.Lock()
  146. r.src.Seed(seed)
  147. r.mut.Unlock()
  148. }