compiler.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Union, Callable, Dict, Any
import importlib
import torch

from ..pydantic_v1 import validator
from .config_utils import DeepSpeedConfigModel

COMPILE_CONFIG = "compile"


def is_compile_supported():
    """Return True if this version of PyTorch provides ``torch.compiler``."""
    return hasattr(torch, "compiler")


def disable(func):
    """Exclude ``func`` from torch.compile tracing; a no-op when compile is unsupported."""
    if is_compile_supported():
        return torch.compiler.disable(func)
    return func
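
# Illustrative sketch (the decorated function below is hypothetical): ``disable`` can be
# applied as a decorator to keep a helper out of torch.compile tracing.
#
#   @disable
#   def log_step(loss):
#       print(f"loss: {loss.item()}")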

def get_compile_config(param_dict):
    """Build a ``CompileConfig`` from the "compile" section of a DeepSpeed config dict."""
    if COMPILE_CONFIG in param_dict:
        compile_config_dict = param_dict[COMPILE_CONFIG]
    else:
        compile_config_dict = {}
    return CompileConfig(**compile_config_dict)
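
# Illustrative sketch (values are examples, not recommendations): the "compile" section
# of a DeepSpeed config dict that this helper consumes.
#
#   ds_config = {
#       "compile": {
#           "enabled": True,
#           "backend": "inductor",
#           "kwargs": {"mode": "max-autotune"},
#       },
#   }
#   compile_config = get_compile_config(ds_config)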

def get_backend_fn(backend: Union[str, Callable]) -> Union[str, Callable]:
    """Resolve ``backend`` to a value accepted by torch.compile.

    A callable is returned unchanged. A string must either be a known backend name or a
    dotted "module.function" path, which is imported and resolved to a callable.
    """
    if isinstance(backend, Callable):
        return backend

    elif isinstance(backend, str):
        if backend in torch._dynamo.list_backends(exclude_tags=()):
            return backend

        # Get module name from backend name
        module_name = '.'.join(backend.split('.')[:-1])
        fn_name = backend.split('.')[-1]

        try:
            module = importlib.import_module(module_name)
            backend_fn = getattr(module, fn_name)
        except ImportError:
            raise ValueError(
                f"The backend {backend} is not in the list of available backends and could not be imported.")
        return backend_fn

    raise ValueError(f"backend for torch.compile must be a string or Callable: {backend}")
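
# Illustrative sketch: built-in backend names pass through unchanged, while a dotted path
# is imported. "my_pkg.my_backend" is a hypothetical module and function.
#
#   get_backend_fn("inductor")            # known backend name, returned as-is
#   get_backend_fn("my_pkg.my_backend")   # imports my_pkg, returns my_pkg.my_backend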

class CompileConfig(DeepSpeedConfigModel):
    """
    [EXPERIMENTAL] This configuration enables users to activate `torch.compile` within DeepSpeed and customize its settings.
    Please be aware that these features and API designs are experimental and subject to change.
    """

    enabled: bool = False
    """
    Enable torch.compile when True.
    """

    backend: str = "inductor"
    """
    Passed to `backend` argument of torch.compile.
    If the given value is not in torch._dynamo.list_backends(),
    DeepSpeed attempts to import and instantiate the module with the given name.
    """

    kwargs: Dict[str, Any] = {}
    """
    Passed to `kwargs` argument of torch.compile.
    """

    @validator("enabled")
    def validate_enabled(cls, field_value, values):
        if field_value and not is_compile_supported():
            raise ValueError("torch.compile is not supported on this version of PyTorch.")
        return field_value
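
# Illustrative sketch: constructing the config directly. On a PyTorch build without
# torch.compiler, the "enabled" validator above raises a ValueError.
#
#   cfg = CompileConfig(enabled=True, backend="inductor", kwargs={"dynamic": True})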

class CompiledModuleWrapper(torch.nn.Module):

    def __init__(self, module, compile_config: Union[CompileConfig, None] = None):
        super().__init__()

        assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch."

        # The annotation allows None, so fall back to default settings in that case.
        if compile_config is None:
            compile_config = CompileConfig()

        # Track the wrapped module as a submodule and also store it in __dict__ so that
        # plain attribute access does not go through nn.Module.__getattr__.
        modules = self.__dict__.get('_modules')
        modules['wrapped'] = module
        self.__dict__['wrapped'] = module

        self._is_compiled = False
        self._backend = get_backend_fn(compile_config.backend)
        self._compile_kwargs = compile_config.kwargs
        self._compiler_fn = None

    def __getattr__(self, name):
        return getattr(self.__dict__['wrapped'], name)

    def set_backend(self, backend: Union[str, Callable]):
        """Set the backend for torch.compile.

        Args:
            backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module.
                You can directly pass a function that works as a backend.
                See also the `backend` field in `CompileConfig` for more details.
        """
        self._backend = get_backend_fn(backend)
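
    # Illustrative sketch (hypothetical backend): a custom callable can be passed instead
    # of a backend name; torch.compile receives it unchanged.
    #
    #   def my_backend(gm: torch.fx.GraphModule, example_inputs):
    #       return gm.forward  # run the captured graph eagerly
    #
    #   engine.set_backend(my_backend)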

    def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None:
        """Set kwargs for torch.compile. Kwargs that are set in the DeepSpeed config will be overwritten.
        Note that the backend cannot be changed here; use ``set_backend`` instead.

        Args:
            kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile.
        """
        if "backend" in kwargs:
            raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.")
        self._compile_kwargs.update(kwargs)
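
    # Illustrative sketch: extra kwargs are merged with those from the DeepSpeed config and
    # forwarded to torch.compile on the first forward pass.
    #
    #   engine.set_torch_compile_kwargs({"dynamic": True, "fullgraph": False})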

    def set_compiler_fn(self, compiler_fn: Callable) -> None:
        """Set a function to be used for compiling the module.
        This function should take a torch.nn.Module as input and return a compiled module.
        Note that other compile options are ignored when a compiler_fn is set.

        Example:
        ```python
            def my_compiler_fn(module: torch.nn.Module):
                ...
                return torch.compile(module, ...)

            engine.set_compiler_fn(my_compiler_fn)
        ```
        """
        self._compiler_fn = compiler_fn

    def forward(self, *args, **kwargs) -> Any:
        if not self.is_compiled:
            if self._compiler_fn is None:
                self.__dict__['wrapped'] = torch.compile(self.wrapped, backend=self._backend, **self._compile_kwargs)
            else:
                self.__dict__['wrapped'] = self._compiler_fn(self.wrapped)
            self._is_compiled = True

        return self.__dict__['wrapped'](*args, **kwargs)

    @property
    def is_compiled(self) -> bool:
        return self._is_compiled

    @property
    def backend(self) -> Union[str, Callable]:
        return self._backend

    @property
    def torch_compile_kwargs(self) -> Dict[str, Any]:
        return self._compile_kwargs

    @property
    def compiler_fn(self) -> Union[Callable, None]:
        return self._compiler_fn
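
# Illustrative end-to-end sketch ("MyModel" and "inputs" are hypothetical): the wrapper
# compiles lazily, on the first forward call, using the configured backend and kwargs.
#
#   cfg = CompileConfig(enabled=True, backend="inductor")
#   wrapped = CompiledModuleWrapper(MyModel(), compile_config=cfg)
#   out = wrapped(inputs)      # first call triggers torch.compile
#   assert wrapped.is_compiled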