
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import sys
from typing import List

import deepspeed
import torch
from deepspeed import comm as dist
from deepspeed.runtime.zero.mics_utils import (MiCS_CommGroups, create_mics_comm_groups, scale_tensors)
from deepspeed.runtime.zero.parameter_offload import (DeepSpeedZeRoOffload, is_zero_param)
from deepspeed.runtime.zero.partition_parameters import Init, AllGatherCoalescedHandle, ZeroParamStatus
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
from deepspeed.utils import instrument_w_nvtx, log_dist
from deepspeed.accelerator import get_accelerator
from torch import Tensor
from torch.nn import Parameter
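
# MiCS partitions the model states across a small shard group of ranks and
# replicates those shards across the remaining data-parallel ranks: parameters
# are all-gathered within the shard group, and gradients are additionally
# all-reduced across the replication group (see MiCS_Optimizer below).
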
def has_hierarchical_all_gather_groups(comm_groups: MiCS_CommGroups):
    result = False
    if comm_groups.param_intra_node_group is not None and comm_groups.param_inter_node_shard_group is not None:
        result = True
    return result

class MiCS_AllGatherCoalescedHandle(AllGatherCoalescedHandle):
    """ This handle assumes that there is no need to
    copy data out of a contiguous tensor.
    """

    def __init__(self, allgather_handle, params: List[Parameter], partitions: List[Tensor], world_size: int) -> None:
        super().__init__(allgather_handle, params, partitions, world_size)

    def wait(self) -> None:
        """ Wait for the all-gather to complete and mark the params as available. """
        # make the current stream wait on the all-gather op
        instrument_w_nvtx(self.allgather_handle.wait)()
        if self.complete:
            return

        for _, param in enumerate(self.params):
            assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight"
            param.ds_status = ZeroParamStatus.AVAILABLE

        self.complete = True

class MiCS_Init(Init):

    def __init__(self,
                 module=None,
                 data_parallel_group=None,
                 mem_efficient_linear=True,
                 remote_device=None,
                 pin_memory=False,
                 config_dict_or_path=None,
                 config=None,
                 enabled=True,
                 dtype=None,
                 mpu=None):
        """A context manager to partition the model parameters during model
        construction with the MiCS partition strategy. Model states are partitioned
        across the number of devices specified via the ``mics_shard_size`` field in the
        DeepSpeed config json file. The context manager also introduces a
        hierarchical communication method to reduce the cost of inter-node
        communication, which can be enabled with the
        ``mics_hierarchical_params_gather`` field in the DeepSpeed config.

        Args:
            module (``torch.nn.Module``, optional): If provided, partition the model as
                if it was constructed in the context.
            data_parallel_group (``deepspeed.comm`` process group, optional):
                The group of processes to partition among. Defaults to all processes.
            mem_efficient_linear (bool, optional): Replace
                torch.nn.functional.linear with an implementation that allows
                DeepSpeed to partition parameters. Defaults to ``True``.
            remote_device (string, optional): The initial device to store model
                weights e.g., ``cpu``, ``nvme``. Passing ``"cpu"`` will create the model in CPU
                memory. The model may still be moved to GPU based on the
                offload settings for training. Defaults to the param offload device if a config is
                defined, otherwise GPU.
            pin_memory (bool, optional): Potentially increase performance by
                using pinned memory for model weights. ``remote_device`` must be
                ``"cpu"``. Defaults to the pin_memory value in the config, otherwise ``False``.
            config_dict_or_path (dict or ``json file``, optional): If provided, provides the
                configuration for swapping fp16 params to NVMe.
            config (dict or ``json file``, optional): Deprecated, use config_dict_or_path instead.
            enabled (bool, optional): If ``False``, this context has no
                effect. Defaults to ``True``.
            dtype (``dtype``, optional): Can be used to change the data type of the parameters.
                Supported options are ``torch.half`` and ``torch.float``. Defaults to ``None``.
            mpu (``object``, optional): A model parallelism unit object that implements
                get_{model,data}_parallel_{rank,group,world_size}.

        This context follows the same logic as ``deepspeed.zero.Init()``, but
        modifies the partition size of each parameter.

        Examples
        --------

        #. Allocate a model and partition it among all processes:

            .. code-block:: python

                # the config_dict_or_path is required to let the context manager know
                # how to partition the parameters.
                # The configuration has to include the field ``mics_shard_size``
                with deepspeed.zero.MiCS_Init(config_dict_or_path=ds_config):
                    model = MyLargeModel()

        #. Allocate a model in pinned CPU memory and partition it among a subgroup of processes:

            .. code-block:: python

                with deepspeed.zero.MiCS_Init(data_parallel_group=mpu.get_data_parallel_group(),
                                              remote_device="cpu",
                                              pin_memory=True,
                                              config_dict_or_path=ds_config):
                    model = MyLargeModel()

        #. Partition an already-allocated model in CPU memory:

            .. code-block:: python

                model = deepspeed.zero.MiCS_Init(module=model,
                                                 config_dict_or_path=ds_config)
        """
        assert config_dict_or_path is not None, "Must provide configuration for MiCS Initialization"
        _ds_config = deepspeed.runtime.config.DeepSpeedConfig(config_dict_or_path, mpu)
        if not dist.is_initialized():
            dist.init_distributed()
            assert dist.is_initialized(), "Parameters cannot be scattered without initializing deepspeed.comm"
        self.mics_comm_groups = create_mics_comm_groups(
            _ds_config.mics_shard_size,
            data_parallel_group,
            hierarchical_allgather=_ds_config.mics_hierarchial_params_gather,
            mpu=mpu)

        super().__init__(module, data_parallel_group, mem_efficient_linear, remote_device, pin_memory,
                         config_dict_or_path, config, enabled, dtype, mpu)

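    # A minimal configuration sketch (values and the placement of the MiCS
    # fields under "zero_optimization" are assumptions; only ``mics_shard_size``
    # is strictly needed by this context manager, and
    # ``mics_hierarchical_params_gather`` turns on the hierarchical all-gather):
    #
    #   ds_config = {
    #       "train_micro_batch_size_per_gpu": 1,
    #       "zero_optimization": {
    #           "stage": 3,
    #           "mics_shard_size": 8,
    #           "mics_hierarchical_params_gather": True,
    #       },
    #   }
    #   with deepspeed.zero.MiCS_Init(config_dict_or_path=ds_config):
    #       model = MyLargeModel()  # hypothetical model class
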
    def _convert_to_deepspeed_param(self, param):
        super()._convert_to_deepspeed_param(param)

        # attach communication groups to every param
        param.comm = self.mics_comm_groups

        # record the existing all_gather_coalesced implementation
        # so that we can fall back to it later
        old_all_gather_coalesced = param.all_gather_coalesced

        def _param_all_gather_coalesced(params, safe_mode=False, param_buffers=None):
            """Dispatch to the hierarchical or flat all-gather based on the available communication groups."""
            mics_comm_groups: MiCS_CommGroups = params[0].comm
            hierarchical_all_gather = has_hierarchical_all_gather_groups(mics_comm_groups)
            if dist.has_coalescing_manager() and hierarchical_all_gather:
                return self._hierarchical_all_gather_params(params, param_buffers)
            elif dist.has_coalescing_manager():
                return self._flat_all_gather_with_coalescing_manager(params, param_buffers)
            else:
                return old_all_gather_coalesced(params, safe_mode)

        # change the all_gather_coalesced method
        param.all_gather_coalesced = _param_all_gather_coalesced

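    # Note on the dispatch in _param_all_gather_coalesced above: the hierarchical
    # all-gather is used only when the backend exposes a coalescing manager and
    # both intra-node and inter-node shard groups exist; otherwise a flat
    # coalesced all-gather over the whole shard group is used, and without a
    # coalescing manager the original ZeRO implementation is the fallback.
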
    def _pre_all_gather(self, params, params_buffers=None):
        # fetch the partitions from NVMe if they are offloaded and not yet available
        self._ensure_availability_of_partitioned_params(params)

        for param in params:
            if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
                raise RuntimeError(param.ds_summary())
            param.ds_status = ZeroParamStatus.INFLIGHT

        # ensure that each rank has params in the same order. The allgather
        # is done by flattening the parameter list into a single tensor that
        # can be allgathered in a single call - this means that if each rank
        # gives a list of the same parameters in a different order we will
        # silently get incorrect parameter values, and have very
        # difficult-to-debug correctness issues.
        params = sorted(params, key=lambda p: p.ds_id)
        return params, params_buffers

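    # Illustrative example of the ordering requirement above: if rank 0 passed
    # [p1, p0] while rank 1 passed [p0, p1], the flattened all-gather would
    # silently interleave shards of different parameters; sorting by ds_id
    # gives every rank the identical [p0, p1] order.
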
    def _flat_all_gather_with_coalescing_manager(self, params, params_buffers=None):
        """All-gather the params in a single coalesced call across the whole shard group."""
        # update the status of the params and ensure they are on the device
        params, params_buffers = self._pre_all_gather(params, params_buffers)

        mics_comm_groups: MiCS_CommGroups = params[0].comm
        param_shard_size = mics_comm_groups.param_shard_size

        output_tensors = []
        input_tensors = []
        for i, p in enumerate(params):
            t_size = p.ds_tensor.ds_numel * param_shard_size
            if params_buffers is not None and params_buffers[i] is not None:
                assert params_buffers[i].numel(
                ) == t_size, f'params_to_gather_buffers[{i}] size {params_buffers[i].numel()} does not match with t_size {t_size}'
                flat_out = params_buffers[i]
            else:
                flat_out = torch.empty(t_size, dtype=p.dtype, device=self.local_device, requires_grad=False).view(-1)
            output_tensors.append(flat_out)
            _flat_input = p.ds_tensor.data.view(-1)
            input_tensors.append(_flat_input)

        all_gather_handle = dist.all_gather_coalesced(output_tensors,
                                                      input_tensors,
                                                      group=mics_comm_groups.param_shard_group,
                                                      async_op=True)

        for idx, param in enumerate(params):
            param.data = output_tensors[idx].narrow(0, 0, param.ds_numel).view(param.ds_shape).data

        return MiCS_AllGatherCoalescedHandle(allgather_handle=all_gather_handle,
                                             params=params,
                                             partitions=[],
                                             world_size=param_shard_size)

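    # Illustrative sizes (assumed for the example): with param_shard_size=4 and
    # a parameter of ds_numel=10 padded to ds_tensor.ds_numel=3 per shard, each
    # flat_out buffer holds 4 * 3 = 12 gathered elements and param.data becomes
    # the first 10 elements viewed back to ds_shape.
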
    def _hierarchical_all_gather_params(self, params, params_buffers=None):
        """All-gather the params in two stages: inter-node first, then intra-node."""
        params, params_buffers = self._pre_all_gather(params, params_buffers)

        mics_comm_groups: MiCS_CommGroups = params[0].comm
        local_rank = dist.get_rank(group=mics_comm_groups.param_intra_node_group)
        inter_node_comm_group = mics_comm_groups.param_inter_node_shard_group
        intra_node_comm_group = mics_comm_groups.param_intra_node_group
        param_shard_size = mics_comm_groups.param_shard_size

        inter_node_size = dist.get_world_size(group=inter_node_comm_group)
        intra_node_size = dist.get_world_size(group=intra_node_comm_group)
        param_tensors = []
        for i, p in enumerate(params):
            param_size = p.ds_tensor.ds_numel * param_shard_size
            if params_buffers is not None and params_buffers[i] is not None:
                assert params_buffers[i].numel(
                ) == param_size, f'param_buffers[{i}] size {params_buffers[i].numel()} does not match with param_size {param_size}'
                param_tensor = params_buffers[i]
            else:
                param_tensor = torch.empty(param_size, dtype=p.dtype, device=self.local_device,
                                           requires_grad=False).view(-1)
            param_tensors.append(param_tensor)

        # inter-node all-gather
        inter_outputs = []
        inter_inputs = []
        for i, p in enumerate(params):
            inter_size = p.ds_tensor.ds_numel * inter_node_size
            _out = param_tensors[i].narrow(0, local_rank * inter_size, inter_size)
            inter_outputs.append(_out)
            inter_inputs.append(p.ds_tensor.data.view(-1).to(self.local_device))
        # synchronous enqueue
        dist.all_gather_coalesced(inter_outputs, inter_inputs, group=inter_node_comm_group, async_op=False)

        # intra-node all-gather
        intra_outputs = []
        intra_inputs = []
        for i, p in enumerate(params):
            # partition the param into multiple chunks for the allgather,
            # because the inter-node all-gather outputs are contiguous in memory
            # while in the full parameter buffer those inter-node shards live at
            # different locations. Each chunk is one intra-node output.
            param_chunk = param_tensors[i].view(
                (inter_node_size, intra_node_size, p.ds_tensor.ds_numel)).narrow(1, local_rank, 1)
            param_chunk.copy_(inter_outputs[i].detach().clone().view(param_chunk.size()))
            output_chunks = torch.chunk(param_tensors[i], inter_node_size)
            for j, _out in enumerate(output_chunks):
                intra_chunk_size = intra_node_size * p.ds_tensor.ds_numel
                local_offset = local_rank * p.ds_tensor.ds_numel
                _in = param_tensors[i].narrow(0, j * intra_chunk_size + local_offset, p.ds_tensor.ds_numel)
                intra_outputs.append(_out)
                intra_inputs.append(_in)

        all_gather_handle = dist.all_gather_coalesced(intra_outputs,
                                                      intra_inputs,
                                                      group=intra_node_comm_group,
                                                      async_op=True)
        for i, param in enumerate(params):
            param.data = param_tensors[i].narrow(0, 0, param.ds_numel).view(param.ds_shape).data

        return MiCS_AllGatherCoalescedHandle(
            allgather_handle=all_gather_handle,
            params=params,
            partitions=[],
            world_size=param_shard_size,
        )

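    # Illustrative walk-through (topology assumed for the example): with 2 nodes
    # of 4 GPUs each (inter_node_size=2, intra_node_size=4, param_shard_size=8),
    # stage 1 all-gathers each rank's local shard across the 2 nodes into this
    # local_rank's slice of param_tensors, and stage 2 all-gathers within the
    # node so that every GPU ends up with all 8 shards. Only the inter-node
    # stage crosses the node boundary, so each GPU receives one remote shard
    # instead of the four it would pull in a flat all-gather over all 8 ranks.
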
    def get_partition_dp_group(self, param):
        return param.comm.param_shard_group

    def get_partition_rank(self):
        return self.mics_comm_groups.param_shard_rank

    @property
    def num_partitions(self):
        return self.mics_comm_groups.param_shard_size

class MiCS_Offload(DeepSpeedZeRoOffload):
    """ Wrapper that changes the behavior for parameter sharding.
    """

    def __init__(self,
                 module,
                 timers,
                 ds_config,
                 overlap_comm=True,
                 prefetch_bucket_size=50000000,
                 max_reuse_distance=1000000000,
                 max_live_parameters=1000000000,
                 param_persistence_threshold=100000,
                 model_persistence_threshold=sys.maxsize,
                 offload_param_config=None,
                 mpu=None):
        super().__init__(module, timers, ds_config, overlap_comm, prefetch_bucket_size, max_reuse_distance,
                         max_live_parameters, param_persistence_threshold, model_persistence_threshold,
                         offload_param_config, mpu)

    def _convert_to_zero_parameters(self, ds_config, module, mpu):
        """ Overrides the parent class method to convert the parameters with MiCS sharding.
        """
        log_dist('Convert to zero parameters from MiCS Offload manager', ranks=[0])
        non_zero_params = [p for p in module.parameters() if not is_zero_param(p)]
        if non_zero_params:
            zero_params = [p for p in module.parameters() if is_zero_param(p)]
            if zero_params:
                zero_params[0].convert_to_zero_parameters(param_list=non_zero_params)
            else:
                group = None
                if mpu:
                    group = mpu.get_data_parallel_group()

                MiCS_Init(module=module,
                          data_parallel_group=group,
                          dtype=self.dtype,
                          config_dict_or_path=ds_config,
                          remote_device=self.offload_device,
                          pin_memory=self.offload_param_pin_memory,
                          mpu=mpu)

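    # Note: _convert_to_zero_parameters above converts any remaining plain
    # parameters in place when the module already contains ZeRO parameters;
    # otherwise it wraps the whole module with MiCS_Init using this manager's
    # offload device and pin-memory settings.
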
class MiCS_Optimizer(DeepSpeedZeroOptimizer_Stage3):
    """
    MiCS Optimizer
    """

    def __init__(self,
                 module,
                 init_optimizer,
                 timers,
                 ds_config,
                 static_loss_scale=1,
                 dynamic_loss_scale=False,
                 dynamic_loss_args=None,
                 verbose=True,
                 contiguous_gradients=True,
                 reduce_bucket_size=500000000,
                 prefetch_bucket_size=50000000,
                 max_reuse_distance=1000000000,
                 max_live_parameters=1000000000,
                 param_persistence_threshold=100000,
                 model_persistence_threshold=sys.maxsize,
                 dp_process_group=None,
                 reduce_scatter=True,
                 overlap_comm=False,
                 offload_optimizer_config=None,
                 offload_param_config=None,
                 sub_group_size=1000000000000,
                 mpu=None,
                 clip_grad=0,
                 gradient_accumulation_dtype=torch.float16,
                 communication_data_type=torch.float16,
                 postscale_gradients=True,
                 gradient_predivide_factor=1,
                 gradient_accumulation_steps=1,
                 elastic_checkpoint=False,
                 aio_config=None):

        log_dist("Init MiCS optimizer", ranks=[0])
        super().__init__(module, init_optimizer, timers, ds_config, static_loss_scale, dynamic_loss_scale,
                         dynamic_loss_args, verbose, contiguous_gradients, reduce_bucket_size, prefetch_bucket_size,
                         max_reuse_distance, max_live_parameters, param_persistence_threshold,
                         model_persistence_threshold, dp_process_group, reduce_scatter, overlap_comm,
                         offload_optimizer_config, offload_param_config, sub_group_size, mpu, clip_grad,
                         gradient_accumulation_dtype, communication_data_type, postscale_gradients,
                         gradient_predivide_factor, gradient_accumulation_steps, elastic_checkpoint, aio_config)
        first_param = next(module.parameters())
        # overload the dp_process_group and partition_count
        assert hasattr(first_param, "comm"), " ".join([
            "Sharded parameters don't have the MiCS_CommGroups attached.",
            "This might be caused by initializing the weights with the deepspeed.zero.Init context.",
            "To use MiCS sharding, please use deepspeed.zero.MiCS_Init to initialize the parameters instead."
        ])
        self.dp_process_group = first_param.comm.param_shard_group
        self.partition_count = first_param.comm.param_shard_size

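    # With dp_process_group and partition_count pointing at the shard group, the
    # inherited ZeRO-3 machinery partitions optimizer states over the
    # mics_shard_size ranks only; gradient synchronization across the replica
    # groups is handled separately in allreduce_mics_shard_grads below.
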
    def initialize_ds_offload(
        self,
        module,
        timers,
        ds_config,
        overlap_comm,
        prefetch_bucket_size,
        max_reuse_distance,
        max_live_parameters,
        param_persistence_threshold,
        model_persistence_threshold,
        offload_param_config,
        mpu,
        zpg=None,
        zero_quantized_weights=False,
    ):
        assert not zero_quantized_weights and zpg is None, "MiCS is mutually exclusive with ZeRO++"
        return MiCS_Offload(module, timers, ds_config, overlap_comm, prefetch_bucket_size, max_reuse_distance,
                            max_live_parameters, param_persistence_threshold, model_persistence_threshold,
                            offload_param_config, mpu)

    def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None:
        grad_buffers = super().partition_grads(params_to_release, grad_partitions)
        # perform all-reduce among the replication groups;
        # the callee performs the accumulation-boundary check
        self.allreduce_mics_shard_grads(params_to_release, grad_buffers)

    @instrument_w_nvtx
    def allreduce_mics_shard_grads(self, params, partitioned_grads_buffers: List[Tensor]):
        """Average the partitioned gradient buffers across the parameter replication group."""
        # TODO: improve the condition check
        if not self.is_gradient_accumulation_boundary or \
                len(partitioned_grads_buffers) == 0:
            return

        mics_comm_groups: MiCS_CommGroups = params[0].comm
        param_repli_group = mics_comm_groups.param_repli_group
        param_repli_size = mics_comm_groups.param_repli_size

        if param_repli_size is None or param_repli_size <= 1:
            return
        if not get_accelerator().on_accelerator(partitioned_grads_buffers[0]):
            raise RuntimeError("Local sharding has no support for CPU offloading")

        if dist.has_all_reduce_coalesced():
            scale_tensors(partitioned_grads_buffers, param_repli_size)
            dist.all_reduce_coalesced(tensors=partitioned_grads_buffers, group=param_repli_group)
        else:
            # manually coalesced all-reduce
            aggregated_buffer: Tensor = torch.cat(partitioned_grads_buffers)
            aggregated_buffer.div_(param_repli_size)
            dist.all_reduce(aggregated_buffer, group=param_repli_group)
            offset = 0
            for grad_buff in partitioned_grads_buffers:
                grad_buff.view(-1).copy_(aggregated_buffer.narrow(0, offset, grad_buff.numel()))
                offset += grad_buff.numel()

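    # Illustrative example (replication factor assumed): with param_repli_size=2
    # there are two replicas of every shard, each holding gradients accumulated
    # on its own slice of the data; dividing by 2 and sum-all-reducing across
    # param_repli_group leaves both replicas with the averaged gradient before
    # the optimizer step.
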
    def load_state_dict(self,
                        state_dict_list,
                        load_optimizer_states=True,
                        load_from_fp32_weights=False,
                        checkpoint_folder=None,
                        load_serial=None):
        r"""Load ZeRO-3/MiCS partitioned checkpoints.
        Because self.dp_process_group has been replaced with the communicator of the
        partition (shard) group, the load_state_dict logic from ZeRO-3 can be reused.
        """
        super().load_state_dict(state_dict_list, load_optimizer_states, load_from_fp32_weights, checkpoint_folder)