123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284 |
- """
- Copyright 2022 The Microsoft DeepSpeed Team
- """
- import torch
- from dataclasses import dataclass
- from deepspeed import comm as dist
@dataclass
class fragment_address:
    """Location of a contiguous fragment inside a flattened (1-D) tensor.

    Instances are consumed as ``tensor.narrow(0, start, numel)`` elsewhere in
    this module.

    Attributes:
        numel: Number of elements in the fragment.
        start: Offset of the fragment's first element.
    """
    # NOTE: field order defines the positional constructor signature
    # (numel first, then start); do not reorder.
    numel: int
    start: int
@dataclass
class tensor_fragment:
    """Pairs a low-precision (lp) tensor fragment with its high-precision (hp) twin.

    Attributes:
        lp_fragment: View holding the low-precision values of the fragment.
        lp_fragment_address: Position of the fragment inside the lp tensor.
        hp_fragment: View holding the high-precision values of the fragment.
        hp_fragment_address: Position of the fragment inside the hp partition.
        optim_fragment: Optimizer-state tensors for this fragment, keyed by
            state name.
        gradient_dict: Gradient container, indexed by ``param_group_index``
            when ``use_offload`` is false.
        offload_gradient_dict: Gradient container, indexed by
            ``param_group_index`` when ``use_offload`` is true.
        use_offload: Selects which of the two gradient containers is read.
        param_group_index: Index of the parameter group this fragment
            belongs to.
    """
    lp_fragment: torch.Tensor
    lp_fragment_address: fragment_address
    hp_fragment: torch.Tensor
    hp_fragment_address: fragment_address
    # Annotated with the ``dict`` type rather than a ``{}`` literal: an
    # annotation should name a type, not hold a dict instance.
    optim_fragment: dict
    gradient_dict: dict
    offload_gradient_dict: dict
    use_offload: bool
    param_group_index: int

    def update_hp(self):
        """Copy the low-precision values into the high-precision fragment (in place)."""
        self.hp_fragment.data.copy_(self.lp_fragment.data)

    def update_lp(self):
        """Copy the high-precision values into the low-precision fragment (in place)."""
        self.lp_fragment.data.copy_(self.hp_fragment.data)

    def get_optim_state_fragment(self, key):
        """Return the optimizer-state fragment stored under ``key``.

        Raises:
            ValueError: If ``key`` is not present in ``optim_fragment``.
        """
        if key not in self.optim_fragment:
            raise ValueError(f'{key} not found in optimizer state fragment')
        return self.optim_fragment[key]

    def get_hp_fragment_address(self):
        """Return the address of this fragment inside the hp partition."""
        return self.hp_fragment_address

    def get_optim_state_keys(self):
        """Return the optimizer-state names available for this fragment as a list."""
        return list(self.optim_fragment.keys())
def get_full_hp_param(self, optim_state_key=None):
    """Assemble the full fp32 tensor for this parameter from per-rank fragments.

    ``self`` is the low-precision parameter; callers in this module invoke
    this as ``param.get_full_hp_param(...)``.

    Args:
        optim_state_key: When ``None`` the high-precision parameter fragment
            is gathered; otherwise the optimizer-state fragment stored under
            this key is gathered instead.

    Returns:
        An fp32 tensor shaped like ``self``.

    Raises:
        ValueError: Propagated from ``get_optim_state_fragment`` when
            ``optim_state_key`` is unknown to the local fragment.
    """
    # Flat fp32 scratch buffer covering the whole parameter, zero-initialised
    # so that ranks without a local fragment contribute nothing.
    full_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()

    mapping = self._hp_mapping
    if mapping is not None:
        address = mapping.lp_fragment_address
        local_slot = full_buffer.narrow(0, address.start, address.numel)
        source = (mapping.hp_fragment if optim_state_key is None else
                  mapping.get_optim_state_fragment(optim_state_key))
        local_slot.data.copy_(source.data)

    # Combine every rank's partial buffer in place.
    # NOTE(review): relies on all_reduce's default reduction op (presumably
    # a sum) - confirm against the comm backend.
    dist.all_reduce(full_buffer, group=self._dp_group)
    return full_buffer.reshape_as(self)
def get_full_hp_grad(self):
    """Assemble the full fp32 gradient for this parameter from per-rank pieces.

    ``self`` is the low-precision parameter; callers in this module invoke
    this as ``param.get_full_hp_grad()``.

    Returns:
        An fp32 tensor shaped like ``self``.

    Raises:
        ValueError: If the selected gradient container has no (or a ``None``)
            entry for this parameter's group. This is checked only when the
            rank holds a fragment mapping, and before the collective call.
    """
    # Zero-initialised flat fp32 buffer; ranks without a mapping leave it empty.
    full_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()

    mapping = self._hp_mapping
    if mapping is not None:
        grads_by_group = (mapping.offload_gradient_dict
                          if mapping.use_offload else mapping.gradient_dict)
        group_index = mapping.param_group_index

        if group_index not in grads_by_group or grads_by_group[group_index] is None:
            raise ValueError(
                "Gradients are only available immediately after backward and before engine step"
            )

        local_grad = grads_by_group[group_index][self._index_in_param_group]
        local_grad_fp32 = local_grad.to(torch.float32).flatten()

        address = mapping.lp_fragment_address
        fragment_slot = full_buffer.narrow(0, address.start, address.numel)

        # If the stored gradient already spans the whole parameter, fill the
        # entire buffer; otherwise write only this rank's fragment window.
        covers_whole_param = self.view(-1).shape == local_grad_fp32.shape
        destination = full_buffer if covers_whole_param else fragment_slot
        destination.data.copy_(local_grad_fp32.data)

    # Combine every rank's partial buffer in place.
    # NOTE(review): relies on all_reduce's default reduction op (presumably
    # a sum) - confirm against the comm backend.
    dist.all_reduce(full_buffer, group=self._dp_group)
    return full_buffer.reshape_as(self)
def safe_get_full_fp32_param(param):
    """Assemble and return the fp32 parameter of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter

    Returns:
        The value produced by the matching ``get_full_hp_param`` hook, or
        ``None`` when ``param`` carries neither a ``ds_id`` nor an
        ``_hp_mapping`` attribute.
    """
    # ZeRO stage 3 params are recognised by their ``ds_id`` attribute and are
    # resolved through the optimizer attached to them.
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_full_hp_param(param)

    # Anything without a fragment mapping is not managed here.
    if not hasattr(param, '_hp_mapping'):
        return None

    # ZeRO stage 1, 2, and bf16_optimizer params
    return param.get_full_hp_param()
def safe_get_full_optimizer_state(param, optim_state_key):
    """Assemble and return the fp32 optimizer state of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
        optim_state_key: Name of the optimizer state to fetch; forwarded
            unchanged to the matching ``get_full_hp_param`` hook.

    Returns:
        The value produced by the matching ``get_full_hp_param`` hook, or
        ``None`` when ``param`` carries neither a ``ds_id`` nor an
        ``_hp_mapping`` attribute.
    """
    # ZeRO stage 3 params are recognised by their ``ds_id`` attribute and are
    # resolved through the optimizer attached to them.
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_full_hp_param(param, optim_state_key)

    # Anything without a fragment mapping is not managed here.
    if not hasattr(param, '_hp_mapping'):
        return None

    # ZeRO stage 1, 2, and bf16_optimizer params
    return param.get_full_hp_param(optim_state_key)
# TODO: Figure out the correct return dtype
def safe_get_full_grad(param):
    """Assemble and return the fp32 gradient of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter

    Returns:
        ``param.grad`` unchanged when it is populated (its dtype is not
        converted here); otherwise the value produced by the matching
        ZeRO hook, or ``None`` when ``param`` carries neither a ``ds_id``
        nor an ``_hp_mapping`` attribute.
    """
    # A gradient that is already attached to the parameter wins outright.
    if param.grad is not None:
        return param.grad

    # ZeRO stage 3 params are recognised by their ``ds_id`` attribute and are
    # resolved through the optimizer attached to them.
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_fp32_grad_for_param(param)

    # Anything without a fragment mapping is not managed here.
    if not hasattr(param, '_hp_mapping'):
        return None

    # ZeRO stage 1, 2, and bf16_optimizer params
    return param.get_full_hp_grad()
def get_hp_fragment_mapping(lp_param,
                            lp_start,
                            flat_hp_partition,
                            gradient_dict,
                            offload_gradient_dict,
                            use_offload,
                            param_group_index,
                            partition_start,
                            partition_size,
                            optimizer_state_dict):
    """Build the ``tensor_fragment`` describing where ``lp_param`` overlaps the hp partition.

    The overlap is the intersection of ``[lp_start, lp_start + lp_param.numel())``
    with ``[partition_start, partition_start + partition_size)``; both ranges
    are compared directly, so they must be expressed in the same index space.

    Args:
        lp_param: Low-precision parameter tensor.
        lp_start: Offset at which ``lp_param`` begins.
        flat_hp_partition: 1-D high-precision partition that is narrowed to
            the overlap.
        gradient_dict: Stored as-is on the returned fragment.
        offload_gradient_dict: Stored as-is on the returned fragment.
        use_offload: Stored as-is on the returned fragment.
        param_group_index: Stored as-is on the returned fragment.
        partition_start: Offset at which the hp partition begins.
        partition_size: Number of elements in the hp partition.
        optimizer_state_dict: Mapping of optimizer-state name to value; only
            tensor values whose shape equals ``flat_hp_partition.shape`` are
            narrowed and kept.

    Returns:
        A ``tensor_fragment`` whose lp/hp fragments are narrowed views of the
        flattened ``lp_param`` and of ``flat_hp_partition``.

    Raises:
        AssertionError: If the two ranges do not overlap.
    """
    lp_end = lp_start + lp_param.numel()
    hp_start, hp_end = partition_start, partition_start + partition_size

    overlap_begin = max(lp_start, hp_start)
    overlap_end = min(lp_end, hp_end)
    assert overlap_begin < overlap_end, \
        f'fragment start {overlap_begin} should be < fragment_end {overlap_end}'
    overlap_numel = overlap_end - overlap_begin

    # Address and view of the overlap, relative to the start of the hp partition.
    hp_address = fragment_address(start=overlap_begin - hp_start, numel=overlap_numel)
    hp_view = flat_hp_partition.narrow(0, hp_address.start, hp_address.numel)

    # Slice every optimizer-state tensor that is laid out like the hp
    # partition; non-tensor or differently shaped entries are skipped.
    state_views = {}
    for state_name, state_value in optimizer_state_dict.items():
        if not torch.is_tensor(state_value):
            continue
        if state_value.shape != flat_hp_partition.shape:
            continue
        state_views[state_name] = state_value.narrow(0, hp_address.start, hp_address.numel)

    # Address and view of the overlap, relative to the start of lp_param.
    lp_address = fragment_address(start=overlap_begin - lp_start, numel=overlap_numel)
    lp_flat = lp_param.flatten()
    lp_view = lp_flat.narrow(0, lp_address.start, lp_address.numel)

    return tensor_fragment(lp_fragment=lp_view,
                           lp_fragment_address=lp_address,
                           hp_fragment=hp_view,
                           hp_fragment_address=hp_address,
                           optim_fragment=state_views,
                           gradient_dict=gradient_dict,
                           offload_gradient_dict=offload_gradient_dict,
                           use_offload=use_offload,
                           param_group_index=param_group_index)
'''
Logic for lp_param to hp_param mapping

lp      lp0 lp1 lp2         lp3  lp4            <-------- indices/names
lp      [  ][  ][          ][   ][         ]    <-------- tensors
flat_lp [                                  ]    <-------- flat lp params
flat_hp            [                 ]          <-------- flat hp partition on current rank
full_hp [                                  ]    <-------- full flat hp params

(All start/end values below are inclusive element indices.)

lp2:
  full numel = 16
  lp_frag
    numel = 12
    frag_start = 4
    frag_end = 15
  hp_frag
    numel = 12
    frag_start = 0
    frag_end = 11

  hp_frag.copy_(lp_frag)

lp3:
  full numel = 4
  lp_frag
    numel = 4
    start = 0
    end = 3
  hp_frag
    numel = 4
    start = 12
    end = 15

lp4:
  full numel = 12
  lp_frag
    numel = 4
    start = 0
    end = 3
  hp_frag
    numel = 4
    start = 16
    end = 19


Visual depiction of above

Case 1: lp starts before the hp partition and ends inside it
lp              {         }
flat_lp [                                 ]
flat_hp            (                 )

flat_lp [       {  (      }          )    ]
                lx hx     ly         hy
                    ly-hx

Case 2: lp starts inside the hp partition and ends after it
lp                            {         }
flat_lp [                                 ]
flat_hp            (                 )

flat_lp [          (          {      )  } ]
                   hx         lx     hy ly
                               hy-lx

Case 3: lp lies entirely inside the hp partition
lp                        {   }
flat_lp [                                 ]
flat_hp            (                 )

flat_lp [          (      {   }      )    ]
                   hx     lx  ly     hy
                          ly-lx

lp -> (lx, ly)
flat_hp -> (hx, hy)
'''
|