123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- import torch
- from dataclasses import dataclass
- from deepspeed import comm as dist
- from typing import Dict
@dataclass
class fragment_address:
    """Location of a contiguous fragment inside a flat 1-D buffer."""
    numel: int  # number of elements in the fragment
    start: int  # element offset of the fragment within the flat buffer
@dataclass
class tensor_fragment:
    """Pairs a low-precision (lp) parameter fragment with the matching slice of
    the rank-local high-precision (hp) flat partition, plus the slices of any
    optimizer state tensors and the bookkeeping needed to locate gradients.
    """
    lp_fragment: torch.Tensor  # view into the flattened lp parameter
    lp_fragment_address: fragment_address
    hp_fragment: torch.Tensor  # view into the flat fp32 partition
    hp_fragment_address: fragment_address
    optim_fragment: Dict  # optimizer-state name -> matching narrow view
    gradient_dict: Dict
    offload_gradient_dict: Dict
    use_offload: bool  # True when gradients live in the offloaded store
    param_group_index: int

    def update_hp(self):
        """Copy lp fragment values into the hp fragment."""
        source, target = self.lp_fragment, self.hp_fragment
        target.data.copy_(source.data)

    def update_lp(self):
        """Copy hp fragment values back into the lp fragment."""
        source, target = self.hp_fragment, self.lp_fragment
        target.data.copy_(source.data)

    def get_optim_state_fragment(self, key):
        """Return the optimizer-state fragment stored under ``key``.

        Raises:
            ValueError: if ``key`` is not a tracked optimizer state.
        """
        if key not in self.optim_fragment:
            raise ValueError(f'{key} not found in optimizer state fragment')
        return self.optim_fragment[key]

    def get_hp_fragment_address(self):
        """Return the hp fragment's address within the flat partition."""
        return self.hp_fragment_address

    def get_optim_state_keys(self):
        """Return the names of the tracked optimizer-state fragments."""
        return list(self.optim_fragment)

    def get_hp_fragment(self, optim_state_key=None):
        """Return the hp param fragment, or an optimizer-state fragment when
        ``optim_state_key`` is given."""
        return self.hp_fragment if optim_state_key is None else self.get_optim_state_fragment(optim_state_key)
def get_full_hp_param(self, optim_state_key=None):
    """Assemble the full fp32 parameter (or one of its optimizer states) by
    all-reducing each rank's fragment across the data-parallel group.

    Bound to parameters as a method; ``self`` is the lp parameter tensor,
    which carries ``_hp_mapping`` and ``_dp_group`` attributes.
    """
    # Ranks that own no fragment contribute zeros to the all-reduce.
    full_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
    mapping = self._hp_mapping
    if mapping is not None:
        addr = mapping.lp_fragment_address
        local_view = full_buffer.narrow(0, addr.start, addr.numel)
        local_view.data.copy_(mapping.get_hp_fragment(optim_state_key).data)
    dist.all_reduce(full_buffer, group=self._dp_group)
    return full_buffer.reshape_as(self)
def set_full_hp_param(self, value, optim_state_key=None):
    """Write this rank's slice of ``value`` into the local fp32 fragment (or
    into an optimizer-state fragment when ``optim_state_key`` is given).

    Bound to parameters as a method; no-op on ranks without a fragment.
    """
    mapping = self._hp_mapping
    if mapping is None:
        return
    addr = mapping.lp_fragment_address
    source_slice = value.flatten().narrow(0, addr.start, addr.numel)
    mapping.get_hp_fragment(optim_state_key).data.copy_(source_slice.data)
def get_full_hp_grad(self):
    """Assemble the full fp32 gradient of this lp parameter by all-reducing
    each rank's gradient fragment across the data-parallel group.

    Bound to parameters as a method; ``self`` is the lp parameter tensor,
    which carries ``_hp_mapping``, ``_dp_group`` and ``_index_in_param_group``
    attributes (attached elsewhere by the optimizer).

    Raises:
        ValueError: if gradients for this param group are not currently held,
            i.e. outside the window between backward and the engine step.
    """
    # Ranks that own no fragment contribute zeros to the all-reduce.
    reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
    if self._hp_mapping is not None:
        hp_mapping = self._hp_mapping
        # Gradients may live in an offloaded store instead of the device one.
        if hp_mapping.use_offload:
            gradient_dict = hp_mapping.offload_gradient_dict
        else:
            gradient_dict = hp_mapping.gradient_dict
        if hp_mapping.param_group_index not in gradient_dict or gradient_dict[hp_mapping.param_group_index] is None:
            raise ValueError("Gradients are only available immediately after backward and before engine step")
        # _index_in_param_group locates this param's grad inside its group's
        # gradient list — presumably assigned when params are linked; confirm
        # against the optimizer that sets up these attributes.
        lp_grad_fragment = gradient_dict[hp_mapping.param_group_index][self._index_in_param_group]
        hp_grad_fragment = lp_grad_fragment.to(torch.float32).flatten()
        lp_frag_address = self._hp_mapping.lp_fragment_address
        reduce_fragment = torch.narrow(reduce_buffer, 0, lp_frag_address.start, lp_frag_address.numel)
        # NOTE(review): when the stored grad already covers the whole param,
        # copy it into the full buffer; otherwise only fill this rank's
        # fragment window — looks like the former happens when the fragment
        # spans the entire param, but verify against the grad bookkeeping.
        if self.view(-1).shape == hp_grad_fragment.shape:
            reduce_buffer.data.copy_(hp_grad_fragment.data)
        else:
            reduce_fragment.data.copy_(hp_grad_fragment.data)
    dist.all_reduce(reduce_buffer, group=self._dp_group)
    return reduce_buffer.reshape_as(self)
def safe_get_full_fp32_param(param):
    """Assemble and return the fp32 parameter of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter

    Returns ``None`` when ``param`` is not managed by a supported optimizer.
    """
    # ZeRO stage 3 params expose a ds_id and their owning optimizer.
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_full_hp_param(param)
    # ZeRO stage 1/2 and bf16_optimizer params carry an hp mapping.
    if hasattr(param, '_hp_mapping'):
        return param.get_full_hp_param()
    return None
def safe_set_full_fp32_param(param, value):
    """Update the partitioned fp32 parameter of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
        value (``torch.Tensor``): New value

    Silently does nothing when ``param`` is not managed by a supported optimizer.
    """
    # ZeRO stage 3 params delegate to their owning optimizer.
    if hasattr(param, 'ds_id'):
        param._z3_optimizer.set_full_hp_param(value, param)
    # ZeRO stage 1/2 and bf16_optimizer params carry an hp mapping.
    if hasattr(param, '_hp_mapping'):
        param.set_full_hp_param(value)
def safe_get_full_optimizer_state(param, optim_state_key):
    """Assemble and return the fp32 optimizer state of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
        optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer)

    Returns ``None`` when ``param`` is not managed by a supported optimizer.
    """
    # ZeRO stage 3 params expose a ds_id and their owning optimizer.
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_full_hp_param(param, optim_state_key)
    # ZeRO stage 1/2 and bf16_optimizer params carry an hp mapping.
    if hasattr(param, '_hp_mapping'):
        return param.get_full_hp_param(optim_state_key)
    return None
def safe_set_full_optimizer_state(param, value, optim_state_key):
    """Update the partitioned fp32 optimizer state of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
        value (``torch.Tensor``): New value
        optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer)

    Silently does nothing when ``param`` is not managed by a supported optimizer.
    """
    # ZeRO stage 3 params delegate to their owning optimizer.
    if hasattr(param, 'ds_id'):
        param._z3_optimizer.set_full_hp_param(value, param, optim_state_key)
    # ZeRO stage 1/2 and bf16_optimizer params carry an hp mapping.
    if hasattr(param, '_hp_mapping'):
        param.set_full_hp_param(value, optim_state_key)
# TODO: Figure out the correct return dtype
def safe_get_full_grad(param):
    """Assemble and return the fp32 gradient of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter

    Returns ``None`` when no gradient is available for ``param``.
    """
    # A locally-materialized grad (no partitioning) wins outright.
    if param.grad is not None:
        return param.grad
    # ZeRO stage 3 params delegate to their owning optimizer.
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_fp32_grad_for_param(param)
    # ZeRO stage 1/2 and bf16_optimizer params carry an hp mapping.
    if hasattr(param, '_hp_mapping'):
        return param.get_full_hp_grad()
    return None
- # TODO: Implement API for setting ZeRO partitioned gradients
def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
                            param_group_index, partition_start, partition_size, optimizer_state_dict):
    """Build the tensor_fragment linking one lp parameter to this rank's hp partition.

    The lp param occupies ``[lp_start, lp_start + lp_param.numel())`` in the
    global flat parameter space; the rank's hp partition covers
    ``[partition_start, partition_start + partition_size)``. The fragment is
    their (required non-empty) intersection, expressed once as an offset into
    the flattened lp param and once as an offset into the hp partition.
    """
    lp_end = lp_start + lp_param.numel()
    hp_start = partition_start
    hp_end = partition_start + partition_size

    # Intersection of the lp param span with the hp partition span.
    fragment_start = max(lp_start, hp_start)
    fragment_end = min(lp_end, hp_end)
    assert fragment_start < fragment_end, \
        f'fragment start {fragment_start} should be < fragment_end {fragment_end}'
    fragment_numel = fragment_end - fragment_start

    hp_frag_address = fragment_address(start=fragment_start - hp_start, numel=fragment_numel)
    hp_fragment_tensor = flat_hp_partition.narrow(0, hp_frag_address.start, hp_frag_address.numel)

    # Slice the same window out of every optimizer-state tensor that is laid
    # out like the flat partition (skip scalars / mismatched entries).
    optim_fragment = {}
    for state_name, state_value in optimizer_state_dict.items():
        if torch.is_tensor(state_value) and state_value.shape == flat_hp_partition.shape:
            optim_fragment[state_name] = state_value.narrow(0, hp_frag_address.start, hp_frag_address.numel)

    lp_frag_address = fragment_address(start=fragment_start - lp_start, numel=fragment_numel)
    lp_fragment_tensor = lp_param.flatten().narrow(0, lp_frag_address.start, lp_frag_address.numel)

    return tensor_fragment(lp_fragment=lp_fragment_tensor,
                           lp_fragment_address=lp_frag_address,
                           hp_fragment=hp_fragment_tensor,
                           hp_fragment_address=hp_frag_address,
                           optim_fragment=optim_fragment,
                           gradient_dict=gradient_dict,
                           offload_gradient_dict=offload_gradient_dict,
                           use_offload=use_offload,
                           param_group_index=param_group_index)
- '''
- Logic for lp_param to hp_param mapping
- lp lp0 lp1 lp2 lp3 lp4 <------- indices/names
- lp [ ][ ][ ][ ][ ] <-------- tensors
- flat_lp [ ] <-------- flat lp params
- flat_hp [ ] <------------------ flat hp partition on current rank
- full_hp [ ] <------- full flat hp params
- lp2
- full numel = 16
- lp_frag
- numel = 12
- frag_start = 3
- frag_end = 15
- hp_frag
- numel = 12
- frag_start = 0
- frag_end = 11
- hp_frag.copy_(lp_frag)
- lp3:
- full numel = 4
- lp_frag
- numel = 4
- start = 0
- end = 3
- hp_frag
- numel = 4
- start = 12
- end = 15
- lp4:
- full numel = 12
- lp_frag
- numel = 4
- start = 0
- end = 3
- hp_frag
- numel = 4
- start = 16
- end = 19
- Visual depiction of above
- lp { }
- flat_lp [ ]
- flat_hp ( )
- flat_lp [ { ( } ) ]
- lx hx ly hy
- ly-hx
- lp { }
- flat_lp [ ]
- flat_hp ( )
- flat_lp [ ( { ) } ]
- hx lx hy ly
- hy-lx
- lp { }
- flat_lp [ ]
- flat_hp ( )
- flat_lp [ ( { } ) ]
- hx lx ly hy
- ly-lx
- lp -> (lx, hy)
- flat_hp -> (hx, hy)
- '''
|