# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import re
import torch
import types
from typing import List, Tuple, Union
from dataclasses import dataclass

from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM, PARAM_N_SUB_PARAMS, SUB_PARAM_SHAPE)


@dataclass
class SubparamShape:
    patterns: List[str]
    shape: Tuple[Union[Tuple[int, ...], int], ...]
    partition_dim: int
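
# Illustrative (hypothetical) instance - the pattern and sizes below are made up:
#     SubparamShape(patterns=[r'mlp\.gate_up_proj\.weight'], shape=((4096, 4096), 11008), partition_dim=0)
# i.e. `shape` mirrors the full unpartitioned shape, except that the partitioned dimension may
# itself be a tuple giving the size of each fused sub-param along that dimension.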


def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
    hp_mapping = self._hp_mapping
    hp_mapping.optim_fragment = {}

    hp_keys = []
    for file in os.listdir(folder):
        # We expect files named something like "exp_avg.pt", "exp_avg_sq.pt", "fp32.pt"
        pattern = r'(.+)\.pt'
        match = re.search(pattern, file)
        if match:
            hp_keys.append(match.group(1))

    step = None
    for key in hp_keys:
        ckpt_file = os.path.join(folder, f"{key}.pt")
        ckpt_dict = torch.load(ckpt_file)

        if key == "step":
            step = ckpt_dict
            continue

        full_hp_param = ckpt_dict[PARAM]

        # need to deal with slices that were averaged.
        # the opposite of averaging here becomes an exact copy of the first slice
        # I thought of 2 ways:
        # implementation a. find a way for a client to pass a dict with patterns
        # if any(re.search(pattern, folder) for pattern in WEIGHTS_TO_AVERAGE_PATTERNS):
        #     tp_rank = 0
        #     tp_world_size = 1
        # implementation b. requires no additional data from the client: if the shapes already
        # match, the saved param must be one of those averaged (replicated) slices, so we
        # simply copy it whole - that is the approach taken here.
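        # e.g. (illustrative case) a LayerNorm weight that is replicated across TP ranks and
        # averaged into the universal checkpoint: its saved shape already equals this rank's
        # shape, so we load it whole by pretending tp_world_size == 1 below.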
        if full_hp_param.shape == self.shape:
            tp_rank = 0
            tp_world_size = 1

        # special case for word_embeddings weights which get padded differently depending on TP degree.
        # the converter to universal currently strips the original padding completely so the saved
        # weight is padding-free and we just need to add new padding depending on the target TP
        # degree
        is_vocab_tensor = ckpt_dict.get(VOCAB_TENSOR, False)
        if is_vocab_tensor:
            # in the absence of user-provided info on the padded vocab size for the new TP degree,
            # we can again derive that data by reverse-engineering the target shapes like so:
            padded_target_vocab_size = self.shape[0] * tp_world_size
            assert padded_target_vocab_size >= full_hp_param.shape[0], \
                f'Vocab tensor padded size {padded_target_vocab_size} < loaded universal size {full_hp_param.shape[0]}'
            if padded_target_vocab_size > full_hp_param.shape[0]:
                padding_size = padded_target_vocab_size - full_hp_param.shape[0]
                full_hp_param = torch.nn.functional.pad(full_hp_param, (0, 0, 0, padding_size), "constant", 0)
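                # e.g. (illustrative numbers): a stripped universal vocab of 50257 rows loaded at
                # TP=4 with per-rank slices of 12576 rows gives padded_target_vocab_size
                # = 12576 * 4 = 50304, so 50304 - 50257 = 47 rows of zeros are padded on.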

        full_param_numel = full_hp_param.numel()
        tp_slice_numel = self.numel()
        # if key == FP32_WEIGHT_KEY and 'word_embeddings.weight' in folder:
        #     print_rank_0(f'{full_hp_param[:10]=}', force=True)
        assert full_param_numel == tp_world_size * tp_slice_numel, \
            f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {tp_world_size}'

        # print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}")
        # print(f"{dst_tensor.shape=} {dst_tensor.numel()=} {folder=}")

        sub_param_shape = ckpt_dict.get(SUB_PARAM_SHAPE, None)
        # when merging many TP slices into 1 we sometimes cat on dim=0 and other times on dim=1,
        # so we have to do exactly the same split here in reverse.
        # a special case is when a single parameter is effectively a container for multiple sub
        # parameters (more details at the PARAM_N_SUB_PARAMS definition)
        chunk_dim = ckpt_dict.get(CAT_DIM, 0)
        n_sub_params = ckpt_dict.get(PARAM_N_SUB_PARAMS, 1)
        if sub_param_shape:
            partition_dim = sub_param_shape.partition_dim
            sub_dim_sizes = sub_param_shape.shape[partition_dim]
            if not isinstance(sub_dim_sizes, tuple):
                sub_dim_sizes = (sub_dim_sizes, )

            partition_shape = [sum(d) if isinstance(d, tuple) else d for d in sub_param_shape.shape]
            full_hp_param = full_hp_param.view(partition_shape)

            offset = 0
            merged_chunks = []
            for sub_dim_size in sub_dim_sizes:
                sub_params_tp_slice = full_hp_param.narrow(partition_dim, offset,
                                                           sub_dim_size).chunk(tp_world_size, dim=partition_dim)[tp_rank]
                merged_chunks.append(sub_params_tp_slice)
                offset += sub_dim_size
            tp_hp_slice = torch.cat(merged_chunks, dim=partition_dim)
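
        # e.g. (illustrative): a fused QKV weight saved as a single tensor that was cat'ed from
        # n_sub_params equal sub-params along chunk_dim; each sub-param is re-chunked across TP
        # ranks separately, then this rank's pieces are concatenated back together below.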
        elif n_sub_params > 1:
            sub_params = full_hp_param.chunk(n_sub_params, dim=chunk_dim)
            sub_params_tp_slice = [p.chunk(tp_world_size, dim=chunk_dim)[tp_rank] for p in sub_params]
            tp_hp_slice = torch.cat(sub_params_tp_slice, dim=chunk_dim)
        else:
            # this performs the opposite of cat when merging TP slices
            tp_hp_slice = full_hp_param.chunk(tp_world_size, chunk_dim)[tp_rank]

        tp_hp_slice = tp_hp_slice.flatten()

        lp_frag_address = hp_mapping.lp_fragment_address
        tp_hp_fragment = tp_hp_slice.narrow(0, lp_frag_address.start, lp_frag_address.numel)
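        # e.g. (illustrative numbers): if this rank's flattened TP slice holds 1_000_000 elements
        # and this parameter's fragment starts at offset 262_144 with numel 4_096, narrow() pulls
        # exactly those 4_096 contiguous elements that back the local lp/hp fragment pair.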
        # print(f"{key} SHAPE: {tp_hp_slice.shape=}")
        # print(f"{key} SHAPE: {dst_tensor.shape=}")
        # print(f"{key} SHAPE: {tp_hp_fragment.shape=}")

        if key == FP32_WEIGHT_KEY:
            dst_tensor = hp_mapping.get_hp_fragment()
            assert dst_tensor.numel() == lp_frag_address.numel, \
                f'Load checkpoint {key} dst numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}'
            dst_tensor.data.copy_(tp_hp_fragment.data)
        else:
            assert tp_hp_fragment.numel() == lp_frag_address.numel, \
                f'Load checkpoint {key} dst numel {tp_hp_fragment.numel()} != src numel {lp_frag_address.numel}'
            hp_mapping.optim_fragment[key] = tp_hp_fragment.clone().detach()

    return step


def enable_universal_checkpoint(param_list):
    for param in param_list:
        param.load_hp_checkpoint_state = types.MethodType(load_hp_checkpoint_state, param)
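
# Illustrative usage (a minimal sketch, not this module's actual driver code; `model`, `ckpt_dir`,
# `tp_rank` and `tp_world_size` are hypothetical caller-side names, and the one-folder-per-parameter
# layout under `zero/` is an assumption about the universal checkpoint on disk):
#
#     enable_universal_checkpoint(param_list=list(model.parameters()))
#     for name, param in model.named_parameters():
#         folder = os.path.join(ckpt_dir, 'zero', name)
#         step = param.load_hp_checkpoint_state(folder, tp_rank=tp_rank, tp_world_size=tp_world_size)
#
# In DeepSpeed these calls are driven by the ZeRO optimizer's checkpoint-loading path, which also
# sets up the `_hp_mapping` that load_hp_checkpoint_state relies on.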