util.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. import torch
  2. from deepspeed.git_version_info import torch_info
  3. def required_torch_version():
  4. TORCH_MAJOR = int(torch.__version__.split('.')[0])
  5. TORCH_MINOR = int(torch.__version__.split('.')[1])
  6. if TORCH_MAJOR >= 1 and TORCH_MINOR >= 8:
  7. return True
  8. else:
  9. return False
  10. def bf16_required_version_check():
  11. split_version = lambda x: map(int, x.split('.')[:2])
  12. TORCH_MAJOR, TORCH_MINOR = split_version(torch_info['version'])
  13. NCCL_MAJOR, NCCL_MINOR = split_version(torch_info['nccl_version'])
  14. CUDA_MAJOR, CUDA_MINOR = split_version(torch_info['cuda_version'])
  15. if (TORCH_MAJOR > 1 or
  16. (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)) and (CUDA_MAJOR >= 11) and (
  17. NCCL_MAJOR > 2 or
  18. (NCCL_MAJOR == 2 and NCCL_MINOR >= 10)) and torch_info['bf16_support']:
  19. return True
  20. else:
  21. return False
  22. def required_minimum_torch_version(major_version, minor_version):
  23. TORCH_MAJOR = int(torch.__version__.split('.')[0])
  24. TORCH_MINOR = int(torch.__version__.split('.')[1])
  25. if TORCH_MAJOR < major_version:
  26. return False
  27. return TORCH_MAJOR > major_version or TORCH_MINOR >= minor_version
  28. def required_maximum_torch_version(major_version, minor_version):
  29. TORCH_MAJOR = int(torch.__version__.split('.')[0])
  30. TORCH_MINOR = int(torch.__version__.split('.')[1])
  31. if TORCH_MAJOR > major_version:
  32. return False
  33. return TORCH_MAJOR < major_version or TORCH_MINOR <= minor_version