# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
  4. from packaging import version as pkg_version
  5. import torch
  6. def required_torch_version(min_version=None, max_version=None):
  7. assert min_version or max_version, "Must provide a min_version or max_version argument"
  8. torch_version = pkg_version.parse(torch.__version__)
  9. if min_version and pkg_version.parse(str(min_version)) > torch_version:
  10. return False
  11. if max_version and pkg_version.parse(str(max_version)) < torch_version:
  12. return False
  13. return True