# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import sys
from typing import List

import deepspeed
import torch
from deepspeed import comm as dist
from deepspeed.runtime.zero.utils import is_zero_param
from deepspeed.runtime.zero.mics_utils import (MiCS_CommGroups, create_mics_comm_groups, scale_tensors)
from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
from deepspeed.runtime.zero.partition_parameters import Init, AllGatherCoalescedHandle, ZeroParamStatus
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
from deepspeed.utils import instrument_w_nvtx, log_dist
from deepspeed.accelerator import get_accelerator
from torch import Tensor
from torch.nn import Parameter


def has_hierarchical_all_gather_groups(comm_groups: MiCS_CommGroups):
    result = False
    if comm_groups.param_intra_node_group is not None and comm_groups.param_inter_node_shard_group is not None:
        result = True
    return result


class MiCS_AllGatherCoalescedHandle(AllGatherCoalescedHandle):
    """ This handle assumes that there is no need to
    copy data out of a contiguous tensor.
    """

    def __init__(self, allgather_handle, params: List[Parameter], partitions: List[Tensor], world_size: int) -> None:
        super().__init__(allgather_handle, params, partitions, world_size)

    def wait(self) -> None:
        """ Wait for the all-gather to complete and mark the gathered params as available.
        """
        # let the current stream wait on the all-gather op
        try:
            # print("HANDLE", self.allgather_handle)
            instrument_w_nvtx(self.allgather_handle.wait)()
        except (ValueError, RuntimeError) as e:
            log_dist(
                "WARNING: Runtime error while waiting on the collective all-gather, possibly due to the _IllegalWork",
                ranks=[0])
            log_dist(f"Error message: {e}", ranks=[0])

        if self.complete:
            return

        for _, param in enumerate(self.params):
            assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight"
            param.ds_status = ZeroParamStatus.AVAILABLE
        self.complete = True


class MiCS_Init(Init):

    def __init__(self,
                 module=None,
                 data_parallel_group=None,
                 sequence_data_parallel_group=None,
                 mem_efficient_linear=True,
                 remote_device=None,
                 pin_memory=False,
                 config_dict_or_path=None,
                 config=None,
                 enabled=True,
                 dtype=None,
                 mpu=None):
  62. """A context manager to partition the model parameters during the model
  63. construction with MiCS partition strategy. Model states are partitioned
  64. to the number of devices specified via ``mics_shard_size`` field in the
  65. deepspeed config json file. The context manager also introduces
  66. hierarchical communication method to reduce the cost of inter-node
  67. communications, which can be enabled with
  68. ``mics_hierarchical_params_gather`` field in deepspeed config.
  69. Args:
  70. module (``torch.nn.Module``, optional): If provided, partition the model as
  71. if it was constructed in the context.
  72. data_parallel_group (``deepspeed.comm`` process group, optional):
  73. The group of processes to partition among. Defaults to all processes.
  74. mem_efficient_linear (bool, optional): Replace
  75. torch.nn.functional.linear with an implementation that allows
  76. DeepSpeed to partition parameters. Defaults to ``True``.
  77. remote_device (string, optional): The initial device to store model
  78. weights e.g., ``cpu``, ``nvme``. Passing ``"cpu"`` will create the model in CPU
  79. memory. The model may still be moved to GPU based on the
  80. offload settings for training. Defaults to param offload device if a config is
  81. defined, otherwise GPU.
  82. pin_memory (bool, optional): Potentially increase performance by
  83. using pinned memory for model weights. ``remote_device`` must be
  84. ``"cpu"``. Defaults to pin_memory value in config, otherwise ``False``.
  85. config_dict_or_path (dict or ``json file``, optional): If provided, provides configuration
  86. for swapping fp16 params to NVMe.
  87. config (dict or ``json file``, optional): Deprecated, use config_dict_or_path instead.
  88. enabled (bool, optional): If ``False``, this context has no
  89. effect. Defaults to ``True``.
  90. dtype (``dtype``, optional): Can be used to change the data type of the parameters.
  91. Supported options are ``torch.half`` and ``torch.float``. Defaults to ``None``
  92. mpu (``object``, optional): A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,world_size}.
  93. This context follows the same logic as ``deepspeed.zero.Init()``, but
  94. with the modification for partition size of each parameter.
  95. Examples
  96. --------
  97. #. Allocate a model and partition it among all processes:
  98. .. code-block:: python
  99. # the config_dict_or_path is required to let the context manager know
  100. # how partition the parameters.
  101. # The configuration has to include the field ``mics_shard_size``
  102. with deepspeed.zero.MiCS_Init(config_dict_or_path=ds_config):
  103. model = MyLargeModel()
  104. #. Allocate a model in pinned CPU memory and partition it among a subgroup of processes:
  105. .. code-block:: python
  106. with deepspeed.zero.MiCS_Init(data_parallel_group=mpu.get_data_parallel_group(),
  107. remote_device="cpu",
  108. pin_memory=True
  109. config_dict_or_path=ds_config):
  110. model = MyLargeModel()
  111. #. Partition an already-allocated model in CPU memory:
  112. .. code-block:: python
  113. model = deepspeed.zero.MiCS_Init(module=model,
  114. config_dict_or_path=ds_config)
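
        #. A minimal, illustrative DeepSpeed configuration for MiCS; the values here
           are placeholders, and only ``mics_shard_size`` is required by this context
           manager:

            .. code-block:: python

                ds_config = {
                    "train_micro_batch_size_per_gpu": 1,
                    "zero_optimization": {
                        "stage": 3,
                        # shard model states across groups of 4 devices
                        "mics_shard_size": 4,
                    },
                }

                with deepspeed.zero.MiCS_Init(config_dict_or_path=ds_config):
                    model = MyLargeModel()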
  115. """
  116. assert config_dict_or_path is not None, "Must provide configuration for MiCS Initialization"
  117. _ds_config = deepspeed.runtime.config.DeepSpeedConfig(config_dict_or_path, mpu)
  118. if not dist.is_initialized():
  119. dist.init_distributed()
  120. assert dist.is_initialized(), "Parameters cannot be scattered without initializing deepspeed.comm"
  121. if data_parallel_group is None and sequence_data_parallel_group is None:
  122. ds_process_group = dist.get_world_group()
  123. elif sequence_data_parallel_group is not None:
  124. ds_process_group = sequence_data_parallel_group
  125. elif data_parallel_group is not None:
  126. ds_process_group = data_parallel_group
  127. else: # both given
  128. raise ValueError(
  129. "Both 'data_parallel_group' and 'sequence_data_parallel_group' were specified. Please provide only one of these arguments."
  130. )
  131. self.mics_comm_groups = create_mics_comm_groups(
  132. _ds_config.mics_shard_size,
  133. ds_process_group,
  134. hierarchical_allgather=_ds_config.mics_hierarchial_params_gather,
  135. mpu=mpu)
  136. super().__init__(module, data_parallel_group, mem_efficient_linear, remote_device, pin_memory,
  137. config_dict_or_path, config, enabled, dtype, mpu)

    def _convert_to_deepspeed_param(self, param):
        super()._convert_to_deepspeed_param(param)
        # attach communication groups to every param
        param.comm = self.mics_comm_groups

        # record the existing all_gather_coalesced implementation
        # so that we can fall back to it later
        old_all_gather_coalesced = param.all_gather_coalesced

        def _param_all_gather_coalesced(params, param_buffers=None, **kwargs):
            """ All-gather the given params using the MiCS communication groups. """
            mics_comm_groups: MiCS_CommGroups = params[0].comm
            hierarchical_all_gather = has_hierarchical_all_gather_groups(mics_comm_groups)
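            # Prefer the hierarchical two-level all-gather when both intra- and
            # inter-node shard groups exist and the backend supports coalescing;
            # otherwise use the flat coalesced all-gather, and finally fall back
            # to ZeRO's stock implementation.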
            if dist.has_coalescing_manager() and hierarchical_all_gather:
                return self._hierarchical_all_gather_params(params, param_buffers)
            elif dist.has_coalescing_manager():
                return self._flat_all_gather_with_coalescing_manager(params, param_buffers)
            else:
                return old_all_gather_coalesced(params, **kwargs)

        # change the all_gather_coalesced method
        param.all_gather_coalesced = _param_all_gather_coalesced

    def _pre_all_gather(self, params, params_buffers=None):
        # fetches from nvme if the partition is not available and in nvme
        self._ensure_availability_of_partitioned_params(params)

        for param in params:
            if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
                raise RuntimeError(param.ds_summary())
            param.ds_status = ZeroParamStatus.INFLIGHT

        # ensure that each rank has params in the same order. the allgather
        # is done by flattening the parameter list into a single tensor that
        # can be allgathered in a single call - this means that if each rank
        # gives a list of the same parameters in a different order we will
        # silently get incorrect parameter values, and correctness issues that
        # are very difficult to debug.
        params = sorted(params, key=lambda p: p.ds_id)
        return params, params_buffers

    def _flat_all_gather_with_coalescing_manager(self, params, params_buffers=None):
        """ All-gather the shards of ``params`` in a single coalesced call over the full shard group. """
        # change the status of the params and ensure they are on the device
        params, params_buffers = self._pre_all_gather(params, params_buffers)

        mics_comm_groups: MiCS_CommGroups = params[0].comm
        param_shard_size = mics_comm_groups.param_shard_size

        output_tensors = []
        input_tensors = []
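        # for each param: the output buffer receives the fully gathered parameter
        # (shard numel * shard group size), while the input is this rank's local shard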
        for i, p in enumerate(params):
            t_size = p.ds_tensor.ds_numel * param_shard_size
            if params_buffers is not None and params_buffers[i] is not None:
                assert params_buffers[i].numel(
                ) == t_size, f'params_to_gather_buffers[{i}] size {params_buffers[i].numel()} does not match with t_size {t_size}'
                flat_out = params_buffers[i]
            else:
                flat_out = torch.empty(t_size, dtype=p.dtype, device=self.local_device, requires_grad=False).view(-1)
            output_tensors.append(flat_out)
            _flat_input = p.ds_tensor.data.view(-1)
            input_tensors.append(_flat_input)

        all_gather_handle = dist.all_gather_coalesced(output_tensors,
                                                      input_tensors,
                                                      group=mics_comm_groups.param_shard_group,
                                                      async_op=True)

        for idx, param in enumerate(params):
            param.data = output_tensors[idx].narrow(0, 0, param.ds_numel).view(param.ds_shape).data

        return MiCS_AllGatherCoalescedHandle(allgather_handle=all_gather_handle,
                                             params=params,
                                             partitions=[],
                                             world_size=param_shard_size)

    def _hierarchical_all_gather_params(self, params, params_buffers=None):
        """ Hierarchical all-gather: gather across nodes first, then within each node. """
        params, params_buffers = self._pre_all_gather(params, params_buffers)

        mics_comm_groups: MiCS_CommGroups = params[0].comm
        local_rank = dist.get_rank(group=mics_comm_groups.param_intra_node_group)
        inter_node_comm_group = mics_comm_groups.param_inter_node_shard_group
        intra_node_comm_group = mics_comm_groups.param_intra_node_group
        param_shard_size = mics_comm_groups.param_shard_size

        inter_node_size = dist.get_world_size(group=inter_node_comm_group)
        intra_node_size = dist.get_world_size(group=intra_node_comm_group)
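
        # Two-phase gather: (1) a synchronous all-gather across nodes among ranks that
        # share the same local rank, then (2) an asynchronous all-gather within each
        # node that assembles every full parameter from the locally available pieces.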
        param_tensors = []
        for i, p in enumerate(params):
            param_size = p.ds_tensor.ds_numel * param_shard_size
            if params_buffers is not None and params_buffers[i] is not None:
                assert params_buffers[i].numel(
                ) == param_size, f'param_buffers[{i}] size {params_buffers[i].numel()} does not match with param_size {param_size}'
                param_tensor = params_buffers[i]
            else:
                param_tensor = torch.empty(param_size, dtype=p.dtype, device=self.local_device,
                                           requires_grad=False).view(-1)
            param_tensors.append(param_tensor)

        # inter node all-gather
        inter_outputs = []
        inter_inputs = []
        for i, p in enumerate(params):
            inter_size = p.ds_tensor.ds_numel * inter_node_size
            _out = param_tensors[i].narrow(0, local_rank * inter_size, inter_size)
            inter_outputs.append(_out)
            inter_inputs.append(p.ds_tensor.data.view(-1).to(self.local_device))
        # sync enqueue
        dist.all_gather_coalesced(inter_outputs, inter_inputs, group=inter_node_comm_group, async_op=False)

        # intra node all-gather
        intra_outputs = []
        intra_inputs = []
        for i, p in enumerate(params):
            # partition the param into multiple chunks for the all-gather,
            # because the inter-node all-gather outputs are in contiguous memory,
            # while in the full parameter those pieces live at different locations.
            # each chunk is an intra-node output
            param_chunk = param_tensors[i].view(
                (inter_node_size, intra_node_size, p.ds_tensor.ds_numel)).narrow(1, local_rank, 1)
            param_chunk.copy_(inter_outputs[i].detach().clone().view(param_chunk.size()))
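            # viewing the buffer as (inter_node_size, intra_node_size, shard_numel),
            # chunk j below corresponds to inter-node slot j; this rank contributes
            # its slice at position [j, local_rank, :] to each intra-node all-gather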
            output_chunks = torch.chunk(param_tensors[i], inter_node_size)
            for j, _out in enumerate(output_chunks):
                intra_chunk_size = intra_node_size * p.ds_tensor.ds_numel
                local_offset = local_rank * p.ds_tensor.ds_numel
                _in = param_tensors[i].narrow(0, j * intra_chunk_size + local_offset, p.ds_tensor.ds_numel)
                intra_outputs.append(_out)
                intra_inputs.append(_in)

        all_gather_handle = dist.all_gather_coalesced(intra_outputs,
                                                      intra_inputs,
                                                      group=intra_node_comm_group,
                                                      async_op=True)
        for i, param in enumerate(params):
            param.data = param_tensors[i].narrow(0, 0, param.ds_numel).view(param.ds_shape).data

        return MiCS_AllGatherCoalescedHandle(
            allgather_handle=all_gather_handle,
            params=params,
            partitions=[],
            world_size=param_shard_size,
        )

    def get_partition_dp_group(self, param):
        return param.comm.param_shard_group

    def get_partition_rank(self):
        return self.mics_comm_groups.param_shard_rank

    @property
    def num_partitions(self):
        return self.mics_comm_groups.param_shard_size


class MiCS_Offload(DeepSpeedZeRoOffload):
    """ Wrapper that changes the parameter sharding behavior.
    """

    def _convert_to_zero_parameters(self, ds_config, module, mpu):
        """ Overrides the parent class method to convert the parameters with MiCS sharding.
        """
        log_dist('Convert to zero parameters from MiCS Offload manager', ranks=[0])
        non_zero_params = [p for p in module.parameters() if not is_zero_param(p)]
        if non_zero_params:
            zero_params = [p for p in module.parameters() if is_zero_param(p)]
            if zero_params:
                zero_params[0].convert_to_zero_parameters(param_list=non_zero_params)
            else:
                group = None
                if mpu:
                    group = mpu.get_data_parallel_group()

                MiCS_Init(module=module,
                          data_parallel_group=group,
                          dtype=self.dtype,
                          config_dict_or_path=ds_config,
                          remote_device=self.offload_device,
                          pin_memory=self.offload_param_pin_memory,
                          mpu=mpu)


class MiCS_Optimizer(DeepSpeedZeroOptimizer_Stage3):
    """
    MiCS Optimizer
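
    Typically constructed indirectly through ``deepspeed.initialize`` when the model
    was built under ``MiCS_Init`` and the DeepSpeed config enables ZeRO stage 3 with
    ``mics_shard_size``. A minimal, illustrative sketch (``ds_config`` is assumed to
    also define an optimizer; ``MyLargeModel`` is a placeholder):

    .. code-block:: python

        with deepspeed.zero.MiCS_Init(config_dict_or_path=ds_config):
            model = MyLargeModel()

        model_engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                                             model_parameters=model.parameters(),
                                                             config=ds_config)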
  297. """
  298. def __init__(self,
  299. module,
  300. init_optimizer,
  301. timers,
  302. ds_config,
  303. static_loss_scale=1,
  304. dynamic_loss_scale=False,
  305. dynamic_loss_args=None,
  306. verbose=True,
  307. contiguous_gradients=True,
  308. reduce_bucket_size=500000000,
  309. prefetch_bucket_size=50000000,
  310. max_reuse_distance=1000000000,
  311. max_live_parameters=1000000000,
  312. param_persistence_threshold=100000,
  313. model_persistence_threshold=sys.maxsize,
  314. dp_process_group=None,
  315. reduce_scatter=True,
  316. overlap_comm=False,
  317. offload_optimizer_config=None,
  318. offload_param_config=None,
  319. sub_group_size=1000000000000,
  320. offload_ratio=0.0,
  321. mpu=None,
  322. clip_grad=0,
  323. gradient_accumulation_dtype=torch.float16,
  324. communication_data_type=torch.float16,
  325. postscale_gradients=True,
  326. gradient_predivide_factor=1,
  327. gradient_accumulation_steps=1,
  328. elastic_checkpoint=False,
  329. aio_config=None):

        log_dist("Init MiCS optimizer", ranks=[0])
        super().__init__(module, init_optimizer, timers, ds_config, static_loss_scale, dynamic_loss_scale,
                         dynamic_loss_args, verbose, contiguous_gradients, reduce_bucket_size, prefetch_bucket_size,
                         max_reuse_distance, max_live_parameters, param_persistence_threshold,
                         model_persistence_threshold, dp_process_group, reduce_scatter, overlap_comm,
                         offload_optimizer_config, offload_param_config, sub_group_size, offload_ratio, mpu, clip_grad,
                         gradient_accumulation_dtype, communication_data_type, postscale_gradients,
                         gradient_predivide_factor, gradient_accumulation_steps, elastic_checkpoint, aio_config)

        first_param = next(module.parameters())
        # overload the dp_process_group and partition_count
        assert hasattr(first_param, "comm"), " ".join([
            "Sharded parameters do not have the MiCS_CommGroups attached.",
            "This may be caused by using the deepspeed.zero.Init context to initialize the weights.",
            "To use MiCS sharding, please use deepspeed.zero.MiCS_Init instead to initialize the parameters."
        ])
        self.dp_process_group = first_param.comm.param_shard_group
        self.partition_count = first_param.comm.param_shard_size

    def initialize_ds_offload(
        self,
        *args,
        **kwargs,
    ):
        return MiCS_Offload(*args, **kwargs)

    def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None:
        grad_buffers = super().partition_grads(params_to_release, grad_partitions)
        # perform all-reduce among replication groups
        # the function will perform accumulation boundary check
        self.allreduce_mics_shard_grads(params_to_release, grad_buffers)

    @instrument_w_nvtx
    def allreduce_mics_shard_grads(self, params, partitioned_grads_buffers: List[Tensor]):
        """ Average the partitioned gradients across the parameter replication groups.
        """
        # TODO: improve the condition check
        if not self.is_gradient_accumulation_boundary or \
            len(partitioned_grads_buffers) == 0:
            return

        mics_comm_groups: MiCS_CommGroups = params[0].comm
        param_repli_group = mics_comm_groups.param_repli_group
        param_repli_size = mics_comm_groups.param_repli_size

        if param_repli_size is None or param_repli_size <= 1:
            return
        if not get_accelerator().on_accelerator(partitioned_grads_buffers[0]):
            raise RuntimeError("Local sharding has no support for CPU offloading")
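
        # each replica holds a full copy of this shard's gradients; pre-scale by the
        # replication degree and sum-reduce so every replica ends up with the average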
        if dist.has_all_reduce_coalesced():
            scale_tensors(partitioned_grads_buffers, param_repli_size)
            dist.all_reduce_coalesced(tensors=partitioned_grads_buffers, group=param_repli_group)
        else:
            # manually coalescing all-reduce
            aggregated_buffer: Tensor = torch.cat(partitioned_grads_buffers)
            aggregated_buffer.div_(param_repli_size)
            dist.all_reduce(aggregated_buffer, group=param_repli_group)
            offset = 0
            for grad_buff in partitioned_grads_buffers:
                grad_buff.view(-1).copy_(aggregated_buffer.narrow(0, offset, grad_buff.numel()))
                offset += grad_buff.numel()

    def load_state_dict(self,
                        state_dict_list,
                        load_optimizer_states=True,
                        load_from_fp32_weights=False,
                        checkpoint_folder=None,
                        load_serial=None):
        r""" Load ZeRO-3/MiCS partitioned checkpoints.
        Because self.dp_process_group has been replaced with the communicator of the
        partition group, the load_state_dict logic of ZeRO-3 can be reused directly.
        """
        super().load_state_dict(state_dict_list, load_optimizer_states, load_from_fp32_weights, checkpoint_folder)