interceptors.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. // Licensed to the LF AI & Data foundation under one
  2. // or more contributor license agreements. See the NOTICE file
  3. // distributed with this work for additional information
  4. // regarding copyright ownership. The ASF licenses this file
  5. // to you under the Apache License, Version 2.0 (the
  6. // "License"); you may not use this file except in compliance
  7. // with the License. You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. package client
import (
	"context"
	"errors"
	"time"

	grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
)
  28. const (
  29. authorizationHeader = `authorization`
  30. identifierHeader = `identifier`
  31. databaseHeader = `dbname`
  32. )
  33. func (c *Client) MetadataUnaryInterceptor() grpc.UnaryClientInterceptor {
  34. return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
  35. ctx = c.metadata(ctx)
  36. ctx = c.state(ctx)
  37. return invoker(ctx, method, req, reply, cc, opts...)
  38. }
  39. }
  40. func (c *Client) metadata(ctx context.Context) context.Context {
  41. for k, v := range c.config.metadataHeaders {
  42. ctx = metadata.AppendToOutgoingContext(ctx, k, v)
  43. }
  44. return ctx
  45. }
  46. func (c *Client) state(ctx context.Context) context.Context {
  47. c.stateMut.RLock()
  48. defer c.stateMut.RUnlock()
  49. if c.currentDB != "" {
  50. ctx = metadata.AppendToOutgoingContext(ctx, databaseHeader, c.currentDB)
  51. }
  52. if c.identifier != "" {
  53. ctx = metadata.AppendToOutgoingContext(ctx, identifierHeader, c.identifier)
  54. }
  55. return ctx
  56. }
  57. // ref: https://github.com/grpc-ecosystem/go-grpc-middleware
  58. type ctxKey int
  59. const (
  60. RetryOnRateLimit ctxKey = iota
  61. )
  62. // RetryOnRateLimitInterceptor returns a new retrying unary client interceptor.
  63. func RetryOnRateLimitInterceptor(maxRetry uint, maxBackoff time.Duration, backoffFunc grpc_retry.BackoffFuncContext) grpc.UnaryClientInterceptor {
  64. return func(parentCtx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
  65. if maxRetry == 0 {
  66. return invoker(parentCtx, method, req, reply, cc, opts...)
  67. }
  68. var lastErr error
  69. for attempt := uint(0); attempt < maxRetry; attempt++ {
  70. _, err := waitRetryBackoff(parentCtx, attempt, maxBackoff, backoffFunc)
  71. if err != nil {
  72. return err
  73. }
  74. lastErr = invoker(parentCtx, method, req, reply, cc, opts...)
  75. rspStatus := getResultStatus(reply)
  76. if retryOnRateLimit(parentCtx) && rspStatus.GetErrorCode() == commonpb.ErrorCode_RateLimit {
  77. continue
  78. }
  79. return lastErr
  80. }
  81. return lastErr
  82. }
  83. }
  84. func retryOnRateLimit(ctx context.Context) bool {
  85. retry, ok := ctx.Value(RetryOnRateLimit).(bool)
  86. if !ok {
  87. return true // default true
  88. }
  89. return retry
  90. }
  91. // getResultStatus returns status of response.
  92. func getResultStatus(reply interface{}) *commonpb.Status {
  93. switch r := reply.(type) {
  94. case *commonpb.Status:
  95. return r
  96. case *milvuspb.MutationResult:
  97. return r.GetStatus()
  98. case *milvuspb.BoolResponse:
  99. return r.GetStatus()
  100. case *milvuspb.SearchResults:
  101. return r.GetStatus()
  102. case *milvuspb.QueryResults:
  103. return r.GetStatus()
  104. case *milvuspb.FlushResponse:
  105. return r.GetStatus()
  106. default:
  107. return nil
  108. }
  109. }
  110. func contextErrToGrpcErr(err error) error {
  111. switch err {
  112. case context.DeadlineExceeded:
  113. return status.Error(codes.DeadlineExceeded, err.Error())
  114. case context.Canceled:
  115. return status.Error(codes.Canceled, err.Error())
  116. default:
  117. return status.Error(codes.Unknown, err.Error())
  118. }
  119. }
  120. func waitRetryBackoff(parentCtx context.Context, attempt uint, maxBackoff time.Duration, backoffFunc grpc_retry.BackoffFuncContext) (time.Duration, error) {
  121. var waitTime time.Duration
  122. if attempt > 0 {
  123. waitTime = backoffFunc(parentCtx, attempt)
  124. }
  125. if waitTime > 0 {
  126. if waitTime > maxBackoff {
  127. waitTime = maxBackoff
  128. }
  129. timer := time.NewTimer(waitTime)
  130. select {
  131. case <-parentCtx.Done():
  132. timer.Stop()
  133. return waitTime, contextErrToGrpcErr(parentCtx.Err())
  134. case <-timer.C:
  135. }
  136. }
  137. return waitTime, nil
  138. }