# language_model_arguments.py
  1. from dataclasses import dataclass, field
  2. @dataclass
  3. class LanguageModelHandlerArguments:
  4. lm_model_name: str = field(
  5. default="HuggingFaceTB/SmolLM-360M-Instruct",
  6. metadata={
  7. "help": "The pretrained language model to use. Default is 'HuggingFaceTB/SmolLM-360M-Instruct'."
  8. },
  9. )
  10. lm_device: str = field(
  11. default="cuda",
  12. metadata={
  13. "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
  14. },
  15. )
  16. lm_torch_dtype: str = field(
  17. default="float16",
  18. metadata={
  19. "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
  20. },
  21. )
  22. user_role: str = field(
  23. default="user",
  24. metadata={
  25. "help": "Role assigned to the user in the chat context. Default is 'user'."
  26. },
  27. )
  28. init_chat_role: str = field(
  29. default="system",
  30. metadata={
  31. "help": "Initial role for setting up the chat context. Default is 'system'."
  32. },
  33. )
  34. init_chat_prompt: str = field(
  35. default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
  36. metadata={
  37. "help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'"
  38. },
  39. )
  40. lm_gen_max_new_tokens: int = field(
  41. default=128,
  42. metadata={
  43. "help": "Maximum number of new tokens to generate in a single completion. Default is 128."
  44. },
  45. )
  46. lm_gen_min_new_tokens: int = field(
  47. default=0,
  48. metadata={
  49. "help": "Minimum number of new tokens to generate in a single completion. Default is 0."
  50. },
  51. )
  52. lm_gen_temperature: float = field(
  53. default=0.0,
  54. metadata={
  55. "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."
  56. },
  57. )
  58. lm_gen_do_sample: bool = field(
  59. default=False,
  60. metadata={
  61. "help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."
  62. },
  63. )
  64. chat_size: int = field(
  65. default=2,
  66. metadata={
  67. "help": "Number of interactions assitant-user to keep for the chat. None for no limitations."
  68. },
  69. )