dml_channels_test.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. // Licensed to the LF AI & Data foundation under one
  2. // or more contributor license agreements. See the NOTICE file
  3. // distributed with this work for additional information
  4. // regarding copyright ownership. The ASF licenses this file
  5. // to you under the Apache License, Version 2.0 (the
  6. // "License"); you may not use this file except in compliance
  7. // with the License. You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. package rootcoord
  17. import (
  18. "container/heap"
  19. "context"
  20. "math/rand"
  21. "sync"
  22. "testing"
  23. "github.com/cockroachdb/errors"
  24. "github.com/stretchr/testify/assert"
  25. "github.com/stretchr/testify/require"
  26. "github.com/milvus-io/milvus/internal/util/dependency"
  27. "github.com/milvus-io/milvus/pkg/mq/common"
  28. "github.com/milvus-io/milvus/pkg/mq/msgstream"
  29. "github.com/milvus-io/milvus/pkg/util/funcutil"
  30. "github.com/milvus-io/milvus/pkg/util/paramtable"
  31. )
  32. func TestDmlMsgStream(t *testing.T) {
  33. t.Run("RefCnt", func(t *testing.T) {
  34. dms := &dmlMsgStream{refcnt: 0}
  35. assert.Equal(t, int64(0), dms.RefCnt())
  36. assert.Equal(t, int64(0), dms.Used())
  37. dms.IncRefcnt()
  38. assert.Equal(t, int64(1), dms.RefCnt())
  39. dms.BookUsage()
  40. assert.Equal(t, int64(1), dms.Used())
  41. dms.DecRefCnt()
  42. assert.Equal(t, int64(0), dms.RefCnt())
  43. assert.Equal(t, int64(1), dms.Used())
  44. dms.DecRefCnt()
  45. assert.Equal(t, int64(0), dms.RefCnt())
  46. assert.Equal(t, int64(1), dms.Used())
  47. })
  48. }
  49. func TestChannelsHeap(t *testing.T) {
  50. chanNum := 16
  51. var h channelsHeap
  52. h = make([]*dmlMsgStream, 0, chanNum)
  53. for i := int64(0); i < int64(chanNum); i++ {
  54. dms := &dmlMsgStream{
  55. refcnt: 0,
  56. used: 0,
  57. idx: i,
  58. pos: int(i),
  59. }
  60. h = append(h, dms)
  61. }
  62. check := func(h channelsHeap) bool {
  63. for i := 0; i < chanNum; i++ {
  64. if h[i].pos != i {
  65. return false
  66. }
  67. if i*2+1 < chanNum {
  68. if !h.Less(i, i*2+1) {
  69. t.Log("left", i)
  70. return false
  71. }
  72. }
  73. if i*2+2 < chanNum {
  74. if !h.Less(i, i*2+2) {
  75. t.Log("right", i)
  76. return false
  77. }
  78. }
  79. }
  80. return true
  81. }
  82. heap.Init(&h)
  83. assert.True(t, check(h))
  84. // add usage for all
  85. for i := 0; i < chanNum; i++ {
  86. h[0].BookUsage()
  87. h[0].IncRefcnt()
  88. heap.Fix(&h, 0)
  89. }
  90. assert.True(t, check(h))
  91. for i := 0; i < chanNum; i++ {
  92. assert.EqualValues(t, 1, h[i].RefCnt())
  93. assert.EqualValues(t, 1, h[i].Used())
  94. }
  95. randIdx := rand.Intn(chanNum)
  96. target := h[randIdx]
  97. h[randIdx].DecRefCnt()
  98. heap.Fix(&h, randIdx)
  99. assert.EqualValues(t, 0, target.pos)
  100. next := heap.Pop(&h).(*dmlMsgStream)
  101. assert.Equal(t, target, next)
  102. }
  103. func TestDmlChannels(t *testing.T) {
  104. const (
  105. dmlChanPrefix = "rootcoord-dml"
  106. totalDmlChannelNum = 2
  107. )
  108. ctx, cancel := context.WithCancel(context.Background())
  109. defer cancel()
  110. factory := dependency.NewDefaultFactory(true)
  111. dml := newDmlChannels(ctx, factory, dmlChanPrefix, totalDmlChannelNum)
  112. chanNames := dml.listChannels()
  113. assert.Equal(t, 0, len(chanNames))
  114. randStr := funcutil.RandomString(8)
  115. dml.addChannels(randStr)
  116. assert.Error(t, dml.broadcast([]string{randStr}, nil))
  117. {
  118. _, err := dml.broadcastMark([]string{randStr}, nil)
  119. assert.Error(t, err)
  120. }
  121. dml.removeChannels(randStr)
  122. chans0 := dml.getChannelNames(2)
  123. dml.addChannels(chans0...)
  124. assert.Equal(t, 2, dml.getChannelNum())
  125. chans1 := dml.getChannelNames(1)
  126. dml.addChannels(chans1...)
  127. assert.Equal(t, 2, dml.getChannelNum())
  128. chans2 := dml.getChannelNames(totalDmlChannelNum + 1)
  129. assert.Nil(t, chans2)
  130. dml.removeChannels(chans1...)
  131. assert.Equal(t, 2, dml.getChannelNum())
  132. dml.removeChannels(chans0...)
  133. assert.Equal(t, 0, dml.getChannelNum())
  134. paramtable.Get().Save(Params.CommonCfg.PreCreatedTopicEnabled.Key, "true")
  135. paramtable.Get().Save(Params.CommonCfg.TopicNames.Key, "topic1,topic2")
  136. defer paramtable.Get().Reset(Params.CommonCfg.PreCreatedTopicEnabled.Key)
  137. defer paramtable.Get().Reset(Params.CommonCfg.TopicNames.Key)
  138. assert.Panics(t, func() { newDmlChannels(ctx, factory, dmlChanPrefix, totalDmlChannelNum) })
  139. }
  140. func TestDmChannelsFailure(t *testing.T) {
  141. var wg sync.WaitGroup
  142. wg.Add(1)
  143. t.Run("Test newDmlChannels", func(t *testing.T) {
  144. defer wg.Done()
  145. mockFactory := &FailMessageStreamFactory{}
  146. assert.Panics(t, func() { newDmlChannels(context.TODO(), mockFactory, "test-newdmlchannel-root", 1) })
  147. })
  148. wg.Add(1)
  149. t.Run("Test broadcast", func(t *testing.T) {
  150. defer wg.Done()
  151. mockFactory := &FailMessageStreamFactory{errBroadcast: true}
  152. dml := newDmlChannels(context.TODO(), mockFactory, "test-newdmlchannel-root", 1)
  153. chanName0 := dml.getChannelNames(1)[0]
  154. dml.addChannels(chanName0)
  155. require.Equal(t, 1, dml.getChannelNum())
  156. err := dml.broadcast([]string{chanName0}, nil)
  157. assert.Error(t, err)
  158. v, err := dml.broadcastMark([]string{chanName0}, nil)
  159. assert.Empty(t, v)
  160. assert.Error(t, err)
  161. })
  162. wg.Wait()
  163. }
  164. func TestGetNeedChanNum(t *testing.T) {
  165. paramtable.Get().Save(Params.CommonCfg.PreCreatedTopicEnabled.Key, "true")
  166. defer paramtable.Get().Reset(Params.CommonCfg.PreCreatedTopicEnabled.Key)
  167. chans := map[UniqueID][]string{}
  168. var wg sync.WaitGroup
  169. wg.Add(1)
  170. t.Run("topic were empty", func(t *testing.T) {
  171. defer wg.Done()
  172. paramtable.Get().Save(Params.CommonCfg.TopicNames.Key, "")
  173. defer paramtable.Get().Reset(Params.CommonCfg.TopicNames.Key)
  174. assert.Panics(t, func() {
  175. getNeedChanNum(10, chans)
  176. })
  177. })
  178. wg.Add(1)
  179. t.Run("duplicated topics", func(t *testing.T) {
  180. defer wg.Done()
  181. paramtable.Get().Save(Params.CommonCfg.TopicNames.Key, "topic1,topic1")
  182. defer paramtable.Get().Reset(Params.CommonCfg.TopicNames.Key)
  183. assert.Panics(t, func() {
  184. getNeedChanNum(10, chans)
  185. })
  186. })
  187. wg.Add(1)
  188. t.Run("invalid channel channel that not in the list", func(t *testing.T) {
  189. defer wg.Done()
  190. paramtable.Get().Save(Params.CommonCfg.TopicNames.Key, "topic1,topic2")
  191. defer paramtable.Get().Reset(Params.CommonCfg.TopicNames.Key)
  192. chans[UniqueID(100)] = []string{"rootcoord-dml_0"}
  193. assert.Panics(t, func() {
  194. getNeedChanNum(10, chans)
  195. })
  196. })
  197. wg.Add(1)
  198. t.Run("normal case when pre-created topic", func(t *testing.T) {
  199. defer wg.Done()
  200. paramtable.Get().Save(Params.CommonCfg.TopicNames.Key, "topic1,topic2")
  201. defer paramtable.Get().Reset(Params.CommonCfg.TopicNames.Key)
  202. chans[UniqueID(100)] = []string{"topic1"}
  203. assert.Equal(t, getNeedChanNum(10, chans), 0)
  204. })
  205. wg.Add(1)
  206. t.Run("normal case", func(t *testing.T) {
  207. defer wg.Done()
  208. paramtable.Get().Save(Params.CommonCfg.PreCreatedTopicEnabled.Key, "false")
  209. paramtable.Get().Save(Params.CommonCfg.RootCoordDml.Key, "rootcoord-dml")
  210. defer paramtable.Get().Reset(Params.CommonCfg.RootCoordDml.Key)
  211. chans[UniqueID(100)] = []string{"rootcoord-dml_99"}
  212. assert.Equal(t, getNeedChanNum(10, chans), 100)
  213. })
  214. wg.Wait()
  215. }
  216. // FailMessageStreamFactory mock MessageStreamFactory failure
  217. type FailMessageStreamFactory struct {
  218. msgstream.Factory
  219. errBroadcast bool
  220. }
  221. func (f *FailMessageStreamFactory) NewMsgStream(ctx context.Context) (msgstream.MsgStream, error) {
  222. if f.errBroadcast {
  223. return &FailMsgStream{errBroadcast: true}, nil
  224. }
  225. return nil, errors.New("mocked failure")
  226. }
  227. func (f *FailMessageStreamFactory) NewTtMsgStream(ctx context.Context) (msgstream.MsgStream, error) {
  228. return nil, errors.New("mocked failure")
  229. }
  230. type FailMsgStream struct {
  231. msgstream.MsgStream
  232. errBroadcast bool
  233. }
  234. func (ms *FailMsgStream) Close() {}
  235. func (ms *FailMsgStream) Chan() <-chan *msgstream.MsgPack { return nil }
  236. func (ms *FailMsgStream) AsProducer(channels []string) {}
  237. func (ms *FailMsgStream) AsReader(channels []string, subName string) {}
  238. func (ms *FailMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error {
  239. return nil
  240. }
  241. func (ms *FailMsgStream) SetRepackFunc(repackFunc msgstream.RepackFunc) {}
  242. func (ms *FailMsgStream) GetProduceChannels() []string { return nil }
  243. func (ms *FailMsgStream) Produce(*msgstream.MsgPack) error { return nil }
  244. func (ms *FailMsgStream) Broadcast(*msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
  245. if ms.errBroadcast {
  246. return nil, errors.New("broadcast error")
  247. }
  248. return nil, nil
  249. }
  250. func (ms *FailMsgStream) Consume() *msgstream.MsgPack { return nil }
  251. func (ms *FailMsgStream) Seek(ctx context.Context, msgPositions []*msgstream.MsgPosition, includeCurrentMsg bool) error {
  252. return nil
  253. }
  254. func (ms *FailMsgStream) GetLatestMsgID(channel string) (msgstream.MessageID, error) {
  255. return nil, nil
  256. }