token_bucket.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. import threading
  2. import time
  3. class TokenBucket:
  4. def __init__(self, tpm, timeout=None):
  5. self.capacity = int(tpm) # 令牌桶容量
  6. self.tokens = 0 # 初始令牌数为0
  7. self.rate = int(tpm) / 60 # 令牌每秒生成速率
  8. self.timeout = timeout # 等待令牌超时时间
  9. self.cond = threading.Condition() # 条件变量
  10. self.is_running = True
  11. # 开启令牌生成线程
  12. threading.Thread(target=self._generate_tokens).start()
  13. def _generate_tokens(self):
  14. """生成令牌"""
  15. while self.is_running:
  16. with self.cond:
  17. if self.tokens < self.capacity:
  18. self.tokens += 1
  19. self.cond.notify() # 通知获取令牌的线程
  20. time.sleep(1 / self.rate)
  21. def get_token(self):
  22. """获取令牌"""
  23. with self.cond:
  24. while self.tokens <= 0:
  25. flag = self.cond.wait(self.timeout)
  26. if not flag: # 超时
  27. return False
  28. self.tokens -= 1
  29. return True
  30. def close(self):
  31. self.is_running = False
  32. if __name__ == "__main__":
  33. token_bucket = TokenBucket(20, None) # 创建一个每分钟生产20个tokens的令牌桶
  34. # token_bucket = TokenBucket(20, 0.1)
  35. for i in range(3):
  36. if token_bucket.get_token():
  37. print(f"第{i+1}次请求成功")
  38. token_bucket.close()