meta_tensor.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from abc import ABC, abstractmethod
  5. from packaging import version as pkg_version
  6. import torch
  7. class MetaTensorContainer(ABC):
  8. """
  9. NOTE: If you are using this feature with a container that
  10. also inherits from `HybridEngineContainer`, ensure that `MetaTensorContainer`
  11. is inherited before `HybridEngineContainer` in the class definition.
  12. """
  13. def __init__(self, **kwargs):
  14. if pkg_version.parse('1.10') > pkg_version.parse(torch.__version__):
  15. raise NotImplementedError("Meta tensor support is not available, please upgrade to torch 1.10+")
  16. super().__init__(**kwargs)
  17. self.is_meta = False
  18. self.ckpt_load_enabled = True
  19. def initialize_tensors(self, enable_training=False):
  20. super().initialize_tensors(enable_training=enable_training)
  21. self.is_meta = self.qkvw.is_meta
  22. def apply_tensor_parallelism(self, mp_replace, **kwargs):
  23. if self.is_meta:
  24. if self.qkvb is None:
  25. self.module.attention.attn_qkvb = None
  26. if self.dense_b is None:
  27. self.module.attention.attn_ob = None
  28. else:
  29. super().apply_tensor_parallelism(mp_replace, **kwargs)
  30. def copy_data_to_new_module(self):
  31. if self.is_meta:
  32. if self.attn_nw is None:
  33. self.module.mlp.attn_nw = self.attn_nw
  34. self.module.mlp.attn_nb = self.attn_nb
  35. else:
  36. super().copy_data_to_new_module()
  37. def transpose(self):
  38. if not self.is_meta:
  39. super().transpose()
  40. @abstractmethod
  41. def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
  42. """
  43. Load all the transformer parameter from the checkpoint file (sd).
  44. In addition to the parameter names, we require two
  45. more parameters to help read the data correctly
  46. from the checkpoint and split the qkv heads in the
  47. right order:
  48. 1. `use_load_prefix` (Default: False): this specifies
  49. whether we need to use the name of first abstraction
  50. layer of the model for searching the parameter's name
  51. in a checkpoint file. For more information of how this
  52. is used please see
  53. https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/load_checkpoint.py
  54. 2. `split_qkv` (Default: True): we use this flag when splitting
  55. the qkv parameter into heads. If it is False, it means the heads
  56. of q, k, and v are stored together and needs to split in the
  57. DeepSpeed-Inference API.
  58. """
  59. raise NotImplementedError("A load_params() function must be defined in the model container \
  60. when inheriting the MetaTensorContainer feature")