# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from abc import ABC, abstractmethod
from packaging import version as pkg_version
import torch


class MetaTensorContainer(ABC):
- """
- NOTE: If you are using this feature with a container that
- also inherits from `HybridEngineContainer`, ensure that `MetaTensorContainer`
- is inherited before `HybridEngineContainer` in the class definition.
- """

    def __init__(self, **kwargs):
        # Meta tensors were introduced in torch 1.10; fail fast on older versions.
        if pkg_version.parse('1.10') > pkg_version.parse(torch.__version__):
            raise NotImplementedError("Meta tensor support is not available, please upgrade to torch 1.10+")
        super().__init__(**kwargs)
        self.is_meta = False
        self.ckpt_load_enabled = True
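
    # Background (illustrative, not part of the original file): a "meta" tensor
    # carries only shape/dtype metadata and no storage, so large models can be
    # constructed without allocating memory, e.g.:
    #
    #   t = torch.empty(2, 2, device='meta')  # requires torch >= 1.10
    #   assert t.is_meta and t.numel() == 4   # metadata only, no data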

    def initialize_tensors(self, enable_training=False):
        super().initialize_tensors(enable_training=enable_training)
        # Record whether the parameters live on the meta device; the QKV
        # weight stands in for the rest of the container's tensors.
        self.is_meta = self.qkvw.is_meta

    def apply_tensor_parallelism(self, mp_replace, **kwargs):
        if self.is_meta:
            # Meta tensors hold no data to shard; only mirror absent biases
            # onto the injected module so downstream code sees them as None.
            if self.qkvb is None:
                self.module.attention.attn_qkvb = None
            if self.dense_b is None:
                self.module.attention.attn_ob = None
        else:
            super().apply_tensor_parallelism(mp_replace, **kwargs)

    def copy_data_to_new_module(self):
        if self.is_meta:
            # Nothing to copy from meta tensors; just propagate the (absent)
            # attention layer-norm parameters so the new module matches.
            if self.attn_nw is None:
                self.module.mlp.attn_nw = self.attn_nw
                self.module.mlp.attn_nb = self.attn_nb
        else:
            super().copy_data_to_new_module()

    def transpose(self):
        # Transposing is deferred until real data has been loaded.
        if not self.is_meta:
            super().transpose()

    @abstractmethod
    def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
        """
        Load all the transformer parameters from the checkpoint file (sd).
        In addition to the parameter names, we require two
        more parameters to help read the data correctly
        from the checkpoint and split the qkv heads in the
        right order:
            1. `use_load_prefix` (Default: False): specifies whether
               the name of the model's first abstraction layer must be
               used when searching for a parameter's name in the
               checkpoint file. For more information on how this is
               used, please see
               https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/load_checkpoint.py
            2. `split_qkv` (Default: True): used when splitting the qkv
               parameter into heads. If it is False, the heads of q, k,
               and v are stored together and need to be split by the
               DeepSpeed-Inference API.
        """
        raise NotImplementedError("A load_params() function must be defined in the model container "
                                  "when inheriting the MetaTensorContainer feature")
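

# Illustrative sketch (not part of the original file): a minimal subclass
# satisfying the abstract `load_params` contract. The checkpoint key
# ('attn.qkv.weight') and the `mp_replace.copy(dst, src)` helper are assumed
# for illustration (the latter matching DeepSpeed's ReplaceWithTensorSlicing),
# and do not reflect any model's actual parameter layout:
#
#   class ExampleContainer(MetaTensorContainer):
#       def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
#           # `prefix` scopes the lookup to this layer's entries in `sd`.
#           qkvw = sd.get(prefix + 'attn.qkv.weight')
#           if qkvw is not None:
#               module.attention.attn_qkvw = mp_replace.copy(module.attention.attn_qkvw, qkvw)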