# File: __init__.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
  4. import sys
  5. import types
  6. import json
  7. from typing import Optional, Union
  8. import torch
  9. from torch.optim import Optimizer
  10. from torch.optim.lr_scheduler import _LRScheduler
  11. from packaging import version as pkg_version
  12. # Skip Triton import for AMD due to pytorch-triton-rocm module breaking device API in DeepSpeed
  13. if not (hasattr(torch.version, 'hip') and torch.version.hip is not None):
  14. try:
  15. import triton # noqa: F401 # type: ignore
  16. HAS_TRITON = True
  17. except ImportError:
  18. HAS_TRITON = False
  19. else:
  20. HAS_TRITON = False
  21. from . import ops
  22. from . import module_inject
  23. from .accelerator import get_accelerator
  24. from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT
  25. from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable
  26. from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
  27. from .runtime.hybrid_engine import DeepSpeedHybridEngine
  28. from .runtime.pipe.engine import PipelineEngine
  29. from .inference.engine import InferenceEngine
  30. from .inference.config import DeepSpeedInferenceConfig
  31. from .runtime.lr_schedules import add_tuning_arguments
  32. from .runtime.config import DeepSpeedConfig, DeepSpeedConfigError
  33. from .runtime.activation_checkpointing import checkpointing
  34. from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
  35. from .module_inject import replace_transformer_layer, revert_transformer_layer
  36. from .utils import log_dist, OnDevice, logger
  37. from .comm.comm import init_distributed
  38. from .runtime import zero
  39. from .runtime.compiler import is_compile_supported
  40. from .pipe import PipelineModule
  41. from .git_version_info import version, git_hash, git_branch
  42. def _parse_version(version_str):
  43. '''Parse a version string and extract the major, minor, and patch versions.'''
  44. ver = pkg_version.parse(version_str)
  45. return ver.major, ver.minor, ver.micro
  46. # Export version information
  47. __version__ = version
  48. __version_major__, __version_minor__, __version_patch__ = _parse_version(__version__)
  49. __git_hash__ = git_hash
  50. __git_branch__ = git_branch
  51. # Set to torch's distributed package or deepspeed.comm based inside DeepSpeedEngine init
  52. dist = None
  53. def initialize(args=None,
  54. model: torch.nn.Module = None,
  55. optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]] = None,
  56. model_parameters: Optional[torch.nn.Module] = None,
  57. training_data: Optional[torch.utils.data.Dataset] = None,
  58. lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]] = None,
  59. distributed_port: int = TORCH_DISTRIBUTED_DEFAULT_PORT,
  60. mpu=None,
  61. dist_init_required: Optional[bool] = None,
  62. collate_fn=None,
  63. config=None,
  64. config_params=None):
  65. """Initialize the DeepSpeed Engine.
  66. Arguments:
  67. args: an object containing local_rank and deepspeed_config fields.
  68. This is optional if `config` is passed.
  69. model: Required: nn.module class before apply any wrappers
  70. optimizer: Optional: a user defined Optimizer or Callable that returns an Optimizer object.
  71. This overrides any optimizer definition in the DeepSpeed json config.
  72. model_parameters: Optional: An iterable of torch.Tensors or dicts.
  73. Specifies what Tensors should be optimized.
  74. training_data: Optional: Dataset of type torch.utils.data.Dataset
  75. lr_scheduler: Optional: Learning Rate Scheduler Object or a Callable that takes an Optimizer and returns a Scheduler object.
  76. The scheduler object should define a get_lr(), step(), state_dict(), and load_state_dict() methods
  77. distributed_port: Optional: Master node (rank 0)'s free port that needs to be used for communication during distributed training
  78. mpu: Optional: A model parallelism unit object that implements
  79. get_{model,data}_parallel_{rank,group,world_size}()
  80. dist_init_required: Optional: None will auto-initialize torch distributed if needed,
  81. otherwise the user can force it to be initialized or not via boolean.
  82. collate_fn: Optional: Merges a list of samples to form a
  83. mini-batch of Tensor(s). Used when using batched loading from a
  84. map-style dataset.
  85. config: Optional: Instead of requiring args.deepspeed_config you can pass your deepspeed config
  86. as an argument instead, as a path or a dictionary.
  87. config_params: Optional: Same as `config`, kept for backwards compatibility.
  88. Returns:
  89. A tuple of ``engine``, ``optimizer``, ``training_dataloader``, ``lr_scheduler``
  90. * ``engine``: DeepSpeed runtime engine which wraps the client model for distributed training.
  91. * ``optimizer``: Wrapped optimizer if a user defined ``optimizer`` is supplied, or if
  92. optimizer is specified in json config else ``None``.
  93. * ``training_dataloader``: DeepSpeed dataloader if ``training_data`` was supplied,
  94. otherwise ``None``.
  95. * ``lr_scheduler``: Wrapped lr scheduler if user ``lr_scheduler`` is passed, or
  96. if ``lr_scheduler`` specified in JSON configuration. Otherwise ``None``.
  97. """
  98. log_dist("DeepSpeed info: version={}, git-hash={}, git-branch={}".format(__version__, __git_hash__,
  99. __git_branch__),
  100. ranks=[0])
  101. # Disable zero.Init context if it's currently enabled
  102. zero.partition_parameters.shutdown_init_context()
  103. assert model is not None, "deepspeed.initialize requires a model"
  104. global dist
  105. from deepspeed import comm as dist
  106. dist_backend = get_accelerator().communication_backend_name()
  107. dist.init_distributed(dist_backend=dist_backend,
  108. distributed_port=distributed_port,
  109. dist_init_required=dist_init_required)
  110. # Set config using config_params for backwards compat
  111. if config is None and config_params is not None:
  112. config = config_params
  113. # Check for deepscale_config for backwards compat
  114. if hasattr(args, "deepscale_config") and args.deepscale_config is not None:
  115. logger.warning("************ --deepscale_config is deprecated, please use --deepspeed_config ************")
  116. if hasattr(args, "deepspeed_config"):
  117. assert (args.deepspeed_config is
  118. None), "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config"
  119. args.deepspeed_config = args.deepscale_config
  120. args.deepscale_config = None
  121. # Check that we have only one config passed
  122. if hasattr(args, "deepspeed_config") and args.deepspeed_config is not None:
  123. assert config is None, "Not sure how to proceed, we were given deepspeed configs in the deepspeed arguments and deepspeed.initialize() function call"
  124. config = args.deepspeed_config
  125. assert config is not None, "DeepSpeed requires --deepspeed_config to specify configuration file"
  126. if not isinstance(model, PipelineModule):
  127. config_class = DeepSpeedConfig(config, mpu)
  128. if config_class.hybrid_engine.enabled:
  129. engine = DeepSpeedHybridEngine(args=args,
  130. model=model,
  131. optimizer=optimizer,
  132. model_parameters=model_parameters,
  133. training_data=training_data,
  134. lr_scheduler=lr_scheduler,
  135. mpu=mpu,
  136. dist_init_required=dist_init_required,
  137. collate_fn=collate_fn,
  138. config=config,
  139. config_class=config_class)
  140. else:
  141. engine = DeepSpeedEngine(args=args,
  142. model=model,
  143. optimizer=optimizer,
  144. model_parameters=model_parameters,
  145. training_data=training_data,
  146. lr_scheduler=lr_scheduler,
  147. mpu=mpu,
  148. dist_init_required=dist_init_required,
  149. collate_fn=collate_fn,
  150. config=config,
  151. config_class=config_class)
  152. else:
  153. assert mpu is None, "mpu must be None with pipeline parallelism"
  154. mpu = model.mpu()
  155. config_class = DeepSpeedConfig(config, mpu)
  156. engine = PipelineEngine(args=args,
  157. model=model,
  158. optimizer=optimizer,
  159. model_parameters=model_parameters,
  160. training_data=training_data,
  161. lr_scheduler=lr_scheduler,
  162. mpu=mpu,
  163. dist_init_required=dist_init_required,
  164. collate_fn=collate_fn,
  165. config=config,
  166. config_class=config_class)
  167. # Restore zero.Init context if necessary
  168. zero.partition_parameters.restore_init_context()
  169. return_items = [engine, engine.optimizer, engine.training_dataloader, engine.lr_scheduler]
  170. return tuple(return_items)
  171. def _add_core_arguments(parser):
  172. r"""Helper (internal) function to update an argument parser with an argument group of the core DeepSpeed arguments.
  173. The core set of DeepSpeed arguments include the following:
  174. 1) --deepspeed: boolean flag to enable DeepSpeed
  175. 2) --deepspeed_config <json file path>: path of a json configuration file to configure DeepSpeed runtime.
  176. This is a helper function to the public add_config_arguments()
  177. Arguments:
  178. parser: argument parser
  179. Return:
  180. parser: Updated Parser
  181. """
  182. group = parser.add_argument_group('DeepSpeed', 'DeepSpeed configurations')
  183. group.add_argument('--deepspeed',
  184. default=False,
  185. action='store_true',
  186. help='Enable DeepSpeed (helper flag for user code, no impact on DeepSpeed backend)')
  187. group.add_argument('--deepspeed_config', default=None, type=str, help='DeepSpeed json configuration file.')
  188. group.add_argument('--deepscale',
  189. default=False,
  190. action='store_true',
  191. help='Deprecated enable DeepSpeed (helper flag for user code, no impact on DeepSpeed backend)')
  192. group.add_argument('--deepscale_config',
  193. default=None,
  194. type=str,
  195. help='Deprecated DeepSpeed json configuration file.')
  196. return parser
  197. def add_config_arguments(parser):
  198. r"""Update the argument parser to enabling parsing of DeepSpeed command line arguments.
  199. The set of DeepSpeed arguments include the following:
  200. 1) --deepspeed: boolean flag to enable DeepSpeed
  201. 2) --deepspeed_config <json file path>: path of a json configuration file to configure DeepSpeed runtime.
  202. Arguments:
  203. parser: argument parser
  204. Return:
  205. parser: Updated Parser
  206. """
  207. parser = _add_core_arguments(parser)
  208. return parser
  209. def default_inference_config():
  210. """
  211. Return a default DeepSpeed inference configuration dictionary.
  212. """
  213. return DeepSpeedInferenceConfig().dict()
  214. def init_inference(model, config=None, **kwargs):
  215. """Initialize the DeepSpeed InferenceEngine.
  216. Description: all four cases are valid and supported in DS init_inference() API.
  217. # Case 1: user provides no config and no kwargs. Default config will be used.
  218. .. code-block:: python
  219. generator.model = deepspeed.init_inference(generator.model)
  220. string = generator("DeepSpeed is")
  221. print(string)
  222. # Case 2: user provides a config and no kwargs. User supplied config will be used.
  223. .. code-block:: python
  224. generator.model = deepspeed.init_inference(generator.model, config=config)
  225. string = generator("DeepSpeed is")
  226. print(string)
  227. # Case 3: user provides no config and uses keyword arguments (kwargs) only.
  228. .. code-block:: python
  229. generator.model = deepspeed.init_inference(generator.model,
  230. tensor_parallel={"tp_size": world_size},
  231. dtype=torch.half,
  232. replace_with_kernel_inject=True)
  233. string = generator("DeepSpeed is")
  234. print(string)
  235. # Case 4: user provides config and keyword arguments (kwargs). Both config and kwargs are merged and kwargs take precedence.
  236. .. code-block:: python
  237. generator.model = deepspeed.init_inference(generator.model, config={"dtype": torch.half}, replace_with_kernel_inject=True)
  238. string = generator("DeepSpeed is")
  239. print(string)
  240. Arguments:
  241. model: Required: original nn.module object without any wrappers
  242. config: Optional: instead of arguments, you can pass in a DS inference config dict or path to JSON file
  243. Returns:
  244. A deepspeed.InferenceEngine wrapped model.
  245. """
  246. log_dist("DeepSpeed info: version={}, git-hash={}, git-branch={}".format(__version__, __git_hash__,
  247. __git_branch__),
  248. ranks=[0])
  249. # Load config_dict from config first
  250. if config is None:
  251. config = {}
  252. if isinstance(config, str):
  253. with open(config, "r") as f:
  254. config_dict = json.load(f)
  255. elif isinstance(config, dict):
  256. config_dict = config
  257. else:
  258. raise ValueError(f"'config' argument expected string or dictionary, got {type(config)}")
  259. # Update with values from kwargs, ensuring no conflicting overlap between config and kwargs
  260. overlap_keys = set(config_dict.keys()).intersection(kwargs.keys())
  261. # If there is overlap, error out if values are different
  262. for key in overlap_keys:
  263. if config_dict[key] != kwargs[key]:
  264. raise ValueError(f"Conflicting argument '{key}' in 'config':{config_dict[key]} and kwargs:{kwargs[key]}")
  265. config_dict.update(kwargs)
  266. ds_inference_config = DeepSpeedInferenceConfig(**config_dict)
  267. engine = InferenceEngine(model, config=ds_inference_config)
  268. return engine