conftest.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. # tests directory-specific settings - this file is run automatically by pytest before any tests are run
  5. import sys
  6. import pytest
  7. import os
  8. from os.path import abspath, dirname, join
  9. import torch
  10. import warnings
  11. # Set this environment variable for the T5 inference unittest(s) (e.g. google/t5-v1_1-small)
  12. os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
  13. # allow having multiple repository checkouts and not needing to remember to rerun
  14. # 'pip install -e .[dev]' when switching between checkouts and running tests.
  15. git_repo_path = abspath(join(dirname(dirname(__file__)), "src"))
  16. sys.path.insert(1, git_repo_path)
  17. def pytest_configure(config):
  18. config.option.color = "yes"
  19. config.option.durations = 0
  20. config.option.durations_min = 1
  21. config.option.verbose = True
  22. def pytest_addoption(parser):
  23. parser.addoption("--torch_ver", default=None, type=str)
  24. parser.addoption("--cuda_ver", default=None, type=str)
  25. def validate_version(expected, found):
  26. version_depth = expected.count('.') + 1
  27. found = '.'.join(found.split('.')[:version_depth])
  28. return found == expected
  29. @pytest.fixture(scope="session", autouse=True)
  30. def check_environment(pytestconfig):
  31. expected_torch_version = pytestconfig.getoption("torch_ver")
  32. expected_cuda_version = pytestconfig.getoption("cuda_ver")
  33. if expected_torch_version is None:
  34. warnings.warn(
  35. "Running test without verifying torch version, please provide an expected torch version with --torch_ver")
  36. elif not validate_version(expected_torch_version, torch.__version__):
  37. pytest.exit(
  38. f"expected torch version {expected_torch_version} did not match found torch version {torch.__version__}",
  39. returncode=2)
  40. if expected_cuda_version is None:
  41. warnings.warn(
  42. "Running test without verifying cuda version, please provide an expected cuda version with --cuda_ver")
  43. elif not validate_version(expected_cuda_version, torch.version.cuda):
  44. pytest.exit(
  45. f"expected cuda version {expected_cuda_version} did not match found cuda version {torch.version.cuda}",
  46. returncode=2)
  47. # Override of pytest "runtest" for DistributedTest class
  48. # This hook is run before the default pytest_runtest_call
  49. @pytest.hookimpl(tryfirst=True)
  50. def pytest_runtest_call(item):
  51. # We want to use our own launching function for distributed tests
  52. if getattr(item.cls, "is_dist_test", False):
  53. dist_test_class = item.cls()
  54. dist_test_class(item._request)
  55. item.runtest = lambda: True # Dummy function so test is not run twice
  56. # We allow DistributedTest to reuse distributed environments. When the last
  57. # test for a class is run, we want to make sure those distributed environments
  58. # are destroyed.
  59. def pytest_runtest_teardown(item, nextitem):
  60. if getattr(item.cls, "reuse_dist_env", False) and not nextitem:
  61. dist_test_class = item.cls()
  62. for num_procs, pool in dist_test_class._pool_cache.items():
  63. dist_test_class._close_pool(pool, num_procs, force=True)
  64. @pytest.hookimpl(tryfirst=True)
  65. def pytest_fixture_setup(fixturedef, request):
  66. if getattr(fixturedef.func, "is_dist_fixture", False):
  67. dist_fixture_class = fixturedef.func()
  68. dist_fixture_class(request)