# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import sys
import types
import json
from typing import Optional, Union

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from packaging import version as pkg_version

# Skip Triton import for AMD due to pytorch-triton-rocm module breaking device API in DeepSpeed
if not (hasattr(torch.version, 'hip') and torch.version.hip is not None):
    try:
        import triton  # noqa: F401 # type: ignore
        HAS_TRITON = True
    except ImportError:
        HAS_TRITON = False
else:
    HAS_TRITON = False
from . import ops
from . import module_inject

from .accelerator import get_accelerator
from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable
from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
from .runtime.hybrid_engine import DeepSpeedHybridEngine
from .runtime.pipe.engine import PipelineEngine
from .inference.engine import InferenceEngine
from .inference.config import DeepSpeedInferenceConfig
from .runtime.lr_schedules import add_tuning_arguments
from .runtime.config import DeepSpeedConfig, DeepSpeedConfigError
from .runtime.activation_checkpointing import checkpointing
from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from .module_inject import replace_transformer_layer, revert_transformer_layer
from .utils import log_dist, OnDevice, logger
from .comm.comm import init_distributed

from .runtime import zero
from .runtime.compiler import is_compile_supported

from .pipe import PipelineModule

from .git_version_info import version, git_hash, git_branch


def _parse_version(version_str):
    '''Parse a version string and extract the major, minor, and patch versions.'''
    ver = pkg_version.parse(version_str)
    return ver.major, ver.minor, ver.micro
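
# For example (illustration): _parse_version("0.14.2") returns (0, 14, 2).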

# Export version information
__version__ = version
__version_major__, __version_minor__, __version_patch__ = _parse_version(__version__)
__git_hash__ = git_hash
__git_branch__ = git_branch

# Rebound to deepspeed.comm (which wraps torch's distributed package) during engine initialization
dist = None


def initialize(args=None,
               model: torch.nn.Module = None,
               optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]] = None,
               model_parameters: Optional[torch.nn.Module] = None,
               training_data: Optional[torch.utils.data.Dataset] = None,
               lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]] = None,
               distributed_port: int = TORCH_DISTRIBUTED_DEFAULT_PORT,
               mpu=None,
               dist_init_required: Optional[bool] = None,
               collate_fn=None,
               config=None,
               mesh_param=None,
               config_params=None):
    """Initialize the DeepSpeed Engine.

    Arguments:
        args: an object containing local_rank and deepspeed_config fields.
            This is optional if `config` is passed.

        model: Required: nn.Module instance, before applying any wrappers

        optimizer: Optional: a user defined Optimizer or Callable that returns an Optimizer object.
            This overrides any optimizer definition in the DeepSpeed json config.

        model_parameters: Optional: An iterable of torch.Tensors or dicts.
            Specifies which Tensors should be optimized.

        training_data: Optional: Dataset of type torch.utils.data.Dataset

        lr_scheduler: Optional: Learning Rate Scheduler Object or a Callable that takes an Optimizer and returns a Scheduler object.
            The scheduler object should define get_lr(), step(), state_dict(), and load_state_dict() methods.

        distributed_port: Optional: Master node (rank 0)'s free port that needs to be used for communication during distributed training

        mpu: Optional: A model parallelism unit object that implements
            get_{model,data}_parallel_{rank,group,world_size}()

        dist_init_required: Optional: None will auto-initialize torch distributed if needed,
            otherwise the user can force it to be initialized or not via boolean.

        collate_fn: Optional: Merges a list of samples to form a
            mini-batch of Tensor(s). Used when using batched loading from a
            map-style dataset.

        config: Optional: Instead of requiring args.deepspeed_config you can pass your deepspeed config
            as an argument instead, as a path or a dictionary.

        mesh_param: Optional: Shape used to initialize a ("data_parallel", "sequence_parallel")
            device mesh.

        config_params: Optional: Same as `config`, kept for backwards compatibility.

    Returns:
        A tuple of ``engine``, ``optimizer``, ``training_dataloader``, ``lr_scheduler``

        * ``engine``: DeepSpeed runtime engine which wraps the client model for distributed training.

        * ``optimizer``: Wrapped optimizer if a user defined ``optimizer`` is supplied, or if
          optimizer is specified in json config else ``None``.

        * ``training_dataloader``: DeepSpeed dataloader if ``training_data`` was supplied,
          otherwise ``None``.

        * ``lr_scheduler``: Wrapped lr scheduler if user ``lr_scheduler`` is passed, or
          if ``lr_scheduler`` is specified in the JSON configuration. Otherwise ``None``.
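
    Example (a minimal sketch; the config values shown are illustrative, not required defaults):

    .. code-block:: python

        import torch
        import deepspeed

        model = torch.nn.Linear(8, 2)
        engine, optimizer, _, _ = deepspeed.initialize(
            model=model,
            model_parameters=model.parameters(),
            config={
                "train_batch_size": 1,
                "optimizer": {
                    "type": "Adam",
                    "params": {
                        "lr": 1e-3
                    }
                }
            })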
  98. """
    log_dist("DeepSpeed info: version={}, git-hash={}, git-branch={}".format(__version__, __git_hash__,
                                                                             __git_branch__),
             ranks=[0])

    # Disable zero.Init context if it's currently enabled
    zero.partition_parameters.shutdown_init_context()

    assert model is not None, "deepspeed.initialize requires a model"

    global dist
    from deepspeed import comm as dist
    dist_backend = get_accelerator().communication_backend_name()
    dist.init_distributed(dist_backend=dist_backend,
                          distributed_port=distributed_port,
                          dist_init_required=dist_init_required)

    # TODO: combine/reuse mpu as the mesh device, and vice versa

    # Set config using config_params for backwards compat
    if config is None and config_params is not None:
        config = config_params

    mesh_device = None
    if mesh_param:
        logger.info(f"mesh_param to initialize mesh device: {mesh_param}")
        mesh_device = dist.initialize_mesh_device(mesh_param, ("data_parallel", "sequence_parallel"))
    # If the config specifies both data and sequence parallel sizes, use them to initialize the mesh device
    elif config is not None:
        if "sequence_parallel_size" in config and "data_parallel_size" in config:
            logger.info(f"config to initialize mesh device: {config}")
            mesh_device = dist.initialize_mesh_device((config["data_parallel_size"], config["sequence_parallel_size"]),
                                                      ("data_parallel", "sequence_parallel"))

    # Check for deepscale_config for backwards compat
    if hasattr(args, "deepscale_config") and args.deepscale_config is not None:
        logger.warning("************ --deepscale_config is deprecated, please use --deepspeed_config ************")
        if hasattr(args, "deepspeed_config"):
            assert (args.deepspeed_config is
                    None), "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config"
        args.deepspeed_config = args.deepscale_config
        args.deepscale_config = None

    # Check that we have only one config passed
    if hasattr(args, "deepspeed_config") and args.deepspeed_config is not None:
        assert config is None, "Not sure how to proceed, we were given deepspeed configs in the deepspeed arguments and deepspeed.initialize() function call"
        config = args.deepspeed_config

    assert config is not None, "DeepSpeed requires --deepspeed_config to specify configuration file"

    if not isinstance(model, PipelineModule):
        config_class = DeepSpeedConfig(config, mpu, mesh_device=mesh_device)
        if config_class.hybrid_engine.enabled:
            engine = DeepSpeedHybridEngine(args=args,
                                           model=model,
                                           optimizer=optimizer,
                                           model_parameters=model_parameters,
                                           training_data=training_data,
                                           lr_scheduler=lr_scheduler,
                                           mpu=mpu,
                                           dist_init_required=dist_init_required,
                                           collate_fn=collate_fn,
                                           config=config,
                                           config_class=config_class)
        else:
            engine = DeepSpeedEngine(args=args,
                                     model=model,
                                     optimizer=optimizer,
                                     model_parameters=model_parameters,
                                     training_data=training_data,
                                     lr_scheduler=lr_scheduler,
                                     mpu=mpu,
                                     dist_init_required=dist_init_required,
                                     collate_fn=collate_fn,
                                     config=config,
                                     mesh_device=mesh_device,
                                     config_class=config_class)
    else:
        assert mpu is None, "mpu must be None with pipeline parallelism"
        mpu = model.mpu()
        config_class = DeepSpeedConfig(config, mpu)
        engine = PipelineEngine(args=args,
                                model=model,
                                optimizer=optimizer,
                                model_parameters=model_parameters,
                                training_data=training_data,
                                lr_scheduler=lr_scheduler,
                                mpu=mpu,
                                dist_init_required=dist_init_required,
                                collate_fn=collate_fn,
                                config=config,
                                config_class=config_class)

    # Restore zero.Init context if necessary
    zero.partition_parameters.restore_init_context()

    return_items = [
        engine,
        engine.optimizer,
        engine.training_dataloader,
        engine.lr_scheduler,
    ]
    return tuple(return_items)


def _add_core_arguments(parser):
    r"""Helper (internal) function to update an argument parser with an argument group of the core DeepSpeed arguments.
    The core set of DeepSpeed arguments includes the following:
    1) --deepspeed: boolean flag to enable DeepSpeed
    2) --deepspeed_config <json file path>: path of a json configuration file to configure DeepSpeed runtime.

    This is a helper function to the public add_config_arguments()

    Arguments:
        parser: argument parser

    Return:
        parser: Updated Parser
    """
    group = parser.add_argument_group('DeepSpeed', 'DeepSpeed configurations')

    group.add_argument('--deepspeed',
                       default=False,
                       action='store_true',
                       help='Enable DeepSpeed (helper flag for user code, no impact on DeepSpeed backend)')

    group.add_argument('--deepspeed_config', default=None, type=str, help='DeepSpeed json configuration file.')

    group.add_argument('--deepscale',
                       default=False,
                       action='store_true',
                       help='Deprecated: enable DeepSpeed (helper flag for user code, no impact on DeepSpeed backend)')

    group.add_argument('--deepscale_config',
                       default=None,
                       type=str,
                       help='Deprecated: DeepSpeed json configuration file.')

    return parser


def add_config_arguments(parser):
    r"""Update the argument parser to enable parsing of DeepSpeed command line arguments.
    The set of DeepSpeed arguments includes the following:
    1) --deepspeed: boolean flag to enable DeepSpeed
    2) --deepspeed_config <json file path>: path of a json configuration file to configure DeepSpeed runtime.

    Arguments:
        parser: argument parser

    Return:
        parser: Updated Parser
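
    Example (a minimal sketch; ``--local_rank`` is shown only as a typical companion
    argument for distributed launchers, not something this helper adds):

    .. code-block:: python

        import argparse
        import deepspeed

        parser = argparse.ArgumentParser(description='My training script')
        parser.add_argument('--local_rank', type=int, default=-1)
        parser = deepspeed.add_config_arguments(parser)
        # parser now also accepts --deepspeed and --deepspeed_config
        args = parser.parse_args()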
  224. """
  225. parser = _add_core_arguments(parser)
  226. return parser


def default_inference_config():
    """
    Return a default DeepSpeed inference configuration dictionary.
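
    Example (a minimal sketch; overriding ``dtype`` is an illustrative choice,
    and ``model`` is assumed to be an existing ``nn.Module``):

    .. code-block:: python

        import torch
        import deepspeed

        config = deepspeed.default_inference_config()
        config["dtype"] = torch.half
        engine = deepspeed.init_inference(model, config=config)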
  230. """
  231. return DeepSpeedInferenceConfig().dict()


def init_inference(model, config=None, **kwargs):
    """Initialize the DeepSpeed InferenceEngine.

    All four of the following cases are valid and supported by the DeepSpeed
    ``init_inference()`` API:

    # Case 1: user provides no config and no kwargs. Default config will be used.

    .. code-block:: python

        generator.model = deepspeed.init_inference(generator.model)
        string = generator("DeepSpeed is")
        print(string)

    # Case 2: user provides a config and no kwargs. User supplied config will be used.

    .. code-block:: python

        generator.model = deepspeed.init_inference(generator.model, config=config)
        string = generator("DeepSpeed is")
        print(string)

    # Case 3: user provides no config and uses keyword arguments (kwargs) only.

    .. code-block:: python

        generator.model = deepspeed.init_inference(generator.model,
                                                   tensor_parallel={"tp_size": world_size},
                                                   dtype=torch.half,
                                                   replace_with_kernel_inject=True)
        string = generator("DeepSpeed is")
        print(string)

    # Case 4: user provides config and keyword arguments (kwargs). Both config and kwargs are merged and kwargs take precedence.

    .. code-block:: python

        generator.model = deepspeed.init_inference(generator.model,
                                                   config={"dtype": torch.half},
                                                   replace_with_kernel_inject=True)
        string = generator("DeepSpeed is")
        print(string)

    Arguments:
        model: Required: original nn.Module object without any wrappers

        config: Optional: instead of arguments, you can pass in a DS inference config dict or path to JSON file

    Returns:
        A deepspeed.InferenceEngine wrapped model.
    """
    log_dist("DeepSpeed info: version={}, git-hash={}, git-branch={}".format(__version__, __git_hash__,
                                                                             __git_branch__),
             ranks=[0])

    # Load config_dict from config first
    if config is None:
        config = {}
    if isinstance(config, str):
        with open(config, "r") as f:
            config_dict = json.load(f)
    elif isinstance(config, dict):
        config_dict = config
    else:
        raise ValueError(f"'config' argument expected string or dictionary, got {type(config)}")

    # Update with values from kwargs, ensuring no conflicting overlap between config and kwargs
    overlap_keys = set(config_dict.keys()).intersection(kwargs.keys())
    # If there is overlap, error out if values are different
    for key in overlap_keys:
        if config_dict[key] != kwargs[key]:
            raise ValueError(f"Conflicting argument '{key}' in 'config':{config_dict[key]} and kwargs:{kwargs[key]}")

    config_dict.update(kwargs)
    ds_inference_config = DeepSpeedInferenceConfig(**config_dict)

    engine = InferenceEngine(model, config=ds_inference_config)

    return engine