test_models.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. from __future__ import annotations
  2. from unittest.mock import MagicMock, Mock, patch
  3. import pytest
  4. from sweagent.agent.models import ModelArguments, OpenAIModel, TogetherModel
  5. @pytest.fixture()
  6. def openai_mock_client():
  7. model = Mock()
  8. response = Mock()
  9. choice = Mock()
  10. choice.message.content = "test"
  11. response.choices = [choice]
  12. response.usage.prompt_tokens = 10
  13. response.usage.completion_tokens = 10
  14. model.chat.completions.create = MagicMock(return_value=response)
  15. return model
  16. @pytest.fixture()
  17. def mock_together_response():
  18. return {
  19. "choices": [{"text": "<human>Hello</human>"}],
  20. "usage": {"prompt_tokens": 10, "completion_tokens": 10},
  21. }
  22. TEST_HISTORY = [{"role": "system", "content": "Hello, how are you?"}]
  23. def test_openai_model(openai_mock_client):
  24. for model_name in list(OpenAIModel.MODELS) + list(OpenAIModel.SHORTCUTS):
  25. TEST_MODEL_ARGUMENTS = ModelArguments(model_name)
  26. with patch("sweagent.agent.models.keys_config"), patch("sweagent.agent.models.OpenAI"):
  27. model = OpenAIModel(TEST_MODEL_ARGUMENTS, [])
  28. model.client = openai_mock_client
  29. model.query(TEST_HISTORY)
  30. @pytest.mark.parametrize("model_name", list(TogetherModel.MODELS) + list(TogetherModel.SHORTCUTS))
  31. def test_together_model(mock_together_response, model_name):
  32. with patch("sweagent.agent.models.keys_config"), patch("sweagent.agent.models.together") as mock_together:
  33. mock_together.version = "1.1.0"
  34. mock_together.Complete.create.return_value = mock_together_response
  35. model_args = ModelArguments(model_name)
  36. model = TogetherModel(model_args, [])
  37. model.query(TEST_HISTORY)