test_dependency_torch.py 985 B

1234567891011121314151617181920212223242526272829303132333435
  1. #!/usr/bin/env python
  2. import os
  3. import sys
  4. if __name__ == "__main__":
  5. # Do not import torch for testing purposes.
  6. os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] = "1"
  7. from ray.rllib.agents.a3c import A2CTrainer
  8. assert "torch" not in sys.modules, \
  9. "`torch` initially present, when it shouldn't!"
  10. # Note: No ray.init(), to test it works without Ray
  11. trainer = A2CTrainer(
  12. env="CartPole-v0",
  13. config={
  14. "framework": "tf",
  15. "num_workers": 0,
  16. # Disable the logger due to a sort-import attempt of torch
  17. # inside the tensorboardX.SummaryWriter class.
  18. "logger_config": {
  19. "type": "ray.tune.logger.NoopLogger",
  20. },
  21. })
  22. trainer.train()
  23. assert "torch" not in sys.modules, \
  24. "`torch` should not be imported after creating and " \
  25. "training A3CTrainer!"
  26. # Clean up.
  27. del os.environ["RLLIB_TEST_NO_TORCH_IMPORT"]
  28. print("ok")