# types.py (2.5 KB)
  1. """This file has types/dataclass definitions that are used in the SWE agent
  2. for exchanging data between different modules/functions/classes.
  3. They oftentimes cannot be defined in the same file where they are used
  4. because of circular dependencies.
  5. """
  6. from __future__ import annotations
  7. import copy
  8. from dataclasses import dataclass
  9. from typing import Any, Literal, TypedDict
  10. from simple_parsing.helpers.serialization.serializable import FrozenSerializable
class TrajectoryStep(TypedDict):
    """Typed dictionary describing one element of a ``Trajectory`` list.

    Declared without ``total=False``, so type checkers treat every key
    below as required.
    """

    action: str
    observation: str
    response: str
    # The only key in this dict whose value is allowed to be None.
    state: str | None
    thought: str
    # NOTE(review): presumably a duration in seconds -- the unit is not
    # stated anywhere in this file; confirm against the producer.
    execution_time: float
class _HistoryItem(TypedDict):
    """Required part of ``HistoryItem``.

    ``HistoryItem`` subclasses this with ``total=False``; keys declared here
    stay mandatory while the keys added by the subclass are optional.
    """

    role: str
class HistoryItem(_HistoryItem, total=False):
    """Typed dictionary describing one element of a ``History`` list.

    Inherits the required ``role`` key from ``_HistoryItem``. Because of
    ``total=False``, every key declared in this class may be omitted.
    """

    # May be absent, and when present may be None.
    content: str | None
    agent: str
    is_demo: bool
    thought: str
    # May be absent, and when present may be None.
    action: str | None
# Type aliases for lists of the typed dicts defined in this module.
# These are ordinary runtime assignments (not annotations), so they are not
# deferred by ``from __future__ import annotations``; subscripting the builtin
# ``list`` here requires Python 3.9+.
History = list[HistoryItem]
Trajectory = list[TrajectoryStep]
# todo: Make this actually have the dataclasses instead of dict versions
class AgentInfo(TypedDict, total=False):
    """Typed dictionary with information about an agent run.

    Declared with ``total=False``, so every key is optional. Nested
    structures are stored as plain dicts rather than as dataclass instances
    (see the todo above this class).
    """

    # same as `APIStats` from models.py
    model_stats: dict[str, float]
    exit_status: str
    # May be missing or None; ``ReviewSubmission.to_format_dict`` emits ""
    # for this key in that case.
    submission: str | None
    # same as `ReviewerResult`
    review: dict[str, Any]
    # NOTE(review): the meaning of the 30/50/70 suffixes is not evident from
    # this file -- confirm with the code that fills these keys.
    edited_files30: str
    edited_files50: str
    edited_files70: str
    # only if summarizer is used
    summarizer: dict
  41. @dataclass
  42. class ReviewSubmission:
  43. """Information that's passed to the reviewer"""
  44. #: Total trajectory (including several retries)
  45. trajectory: Trajectory
  46. #: Aggregate info dict (including several retries)
  47. info: AgentInfo
  48. def to_format_dict(self, *, suffix="") -> dict[str, Any]:
  49. """Return all the data that is used to format the
  50. messages. Trajectory is excluded because it needs special treatment.
  51. """
  52. out = {}
  53. info = copy.deepcopy(self.info)
  54. if not info.get("submission"):
  55. # Observed that not all exit_cost lead to autosubmission
  56. # so sometimes this might be missing.
  57. info["submission"] = ""
  58. for k, v in info.items():
  59. if isinstance(v, str):
  60. out[f"{k}{suffix}"] = v
  61. elif isinstance(v, dict):
  62. for k2, v2 in v.items():
  63. out[f"{k}_{k2}{suffix}"] = v2
  64. return out
@dataclass(frozen=True)
class ReviewerResult(FrozenSerializable):
    """Frozen (immutable) dataclass holding a reviewer's result.

    Subclasses ``FrozenSerializable`` from ``simple_parsing``.
    """

    # NOTE(review): presumably True when the reviewer accepts the
    # submission -- confirm against the code that builds this object.
    accept: bool
    output: str
    messages: list[dict[str, str]]
@dataclass(frozen=True)
class BinaryReviewerResult(FrozenSerializable):
    """Frozen (immutable) dataclass holding a binary reviewer's result.

    Subclasses ``FrozenSerializable`` from ``simple_parsing``.
    """

    # Annotated as exactly 0 or 1.
    # NOTE(review): presumably the index of the preferred candidate out of
    # two -- confirm against the code that builds this object.
    choice: Literal[0, 1]
    output: str
    messages: list[dict[str, str]]