test_utils.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. from __future__ import annotations
  2. from pathlib import Path
  3. from unittest import mock
  4. import pytest
  5. from sweagent import REPO_ROOT
  6. from sweagent.utils.config import Config, convert_path_to_abspath, convert_paths_to_abspath
  def test_config_retrieval_fails():
      """Indexing a Config with an unknown key must raise KeyError."""
      cfg = Config()
      missing_key = "DOESNTEXIST"
      with pytest.raises(KeyError):
          cfg[missing_key]
  def test_config_retrieval_get():
      """``Config.get`` returns the supplied default for an unknown key."""
      fallback = "default"
      result = Config().get("asdfasdf", fallback)
      assert result == fallback
  def test_retrieve_from_file(tmp_path):
      """A key defined only in keys.cfg is read from that file.

      ``MY_KEY`` is removed from the environment for the duration of the test:
      an environment variable wins over keys.cfg (see ``test_retrieve_from_env``),
      so without this the result would depend on the caller's shell.
      """
      import os

      keys_cfg = tmp_path / "keys.cfg"
      keys_cfg.write_text("MY_KEY: 'VALUE'\n")
      # patch.dict restores os.environ on exit, so popping inside it is safe.
      with mock.patch.dict("os.environ"):
          os.environ.pop("MY_KEY", None)
          config = Config(keys_cfg_path=keys_cfg)
          assert config["MY_KEY"] == "VALUE"
  def test_retrieve_from_env(tmp_path):
      """An environment variable wins over the same key defined in keys.cfg."""
      cfg_file = tmp_path / "keys.cfg"
      cfg_file.write_text("MY_KEY: 'other VALUE'\n")
      with mock.patch.dict("os.environ", {"MY_KEY": "VALUE"}):
          assert Config(keys_cfg_path=cfg_file)["MY_KEY"] == "VALUE"
  def test_retrieve_choices():
      """Check that a ValueError is raised if the value is not in the choices.

      Covers three sources of the value: an explicit default, no default
      argument at all, and an environment variable.
      """
      match = "Value.*not in.*"
      choices = ["a", "b", "c"]
      config = Config()
      # Explicit default outside the allowed choices.
      with pytest.raises(ValueError, match=match):
          config.get("DOESNTEXIST", default="x", choices=choices)
      # No ``default`` argument: expected to raise as well.
      with pytest.raises(ValueError, match=match):
          config.get("DOESNTEXIST", choices=choices)
      # Environment-provided value outside the allowed choices. The queried key
      # must be the patched one; otherwise this silently re-tests the default path.
      # Config is built inside the patch so the env value is visible to it.
      with mock.patch.dict("os.environ", {"MY_KEY": "VALUE"}):
          env_config = Config()
          with pytest.raises(ValueError, match=match):
              env_config.get("MY_KEY", choices=choices)
  def test_retrieve_choices_config_file(tmp_path):
      """A keys.cfg value that is not in ``choices`` raises ValueError.

      ``MY_KEY`` is removed from the environment for the duration of the test,
      so the value is guaranteed to come from the file. An environment variable
      would win over keys.cfg, as ``test_retrieve_from_env`` shows.
      """
      import os

      match = "Value.*not in.*"
      keys_cfg = tmp_path / "keys.cfg"
      keys_cfg.write_text("MY_KEY: 'VALUE'\n")
      # patch.dict restores os.environ on exit, so popping inside it is safe.
      with mock.patch.dict("os.environ"):
          os.environ.pop("MY_KEY", None)
          config = Config(keys_cfg_path=keys_cfg)
          with pytest.raises(ValueError, match=match):
              config.get("MY_KEY", choices=["a", "b", "c"])
  def test_convert_path_to_abspath():
      """Relative paths are anchored at REPO_ROOT; absolute paths pass through."""
      relative, absolute = "sadf", "/sadf"
      assert convert_path_to_abspath(relative) == REPO_ROOT / relative
      assert convert_path_to_abspath(absolute) == Path(absolute)
  def test_convert_paths_to_abspath():
      """The list variant converts every entry and accepts both Path and str."""
      inputs = [Path("sadf"), "/sadf"]
      expected = [REPO_ROOT / "sadf", Path("/sadf")]
      assert convert_paths_to_abspath(inputs) == expected