post_norm_base.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from abc import abstractmethod
  5. from typing import Any, Dict, Optional, Tuple, Type
  6. import torch
  7. from deepspeed.runtime.config_utils import DeepSpeedConfigModel
  8. from ..ds_module import DSModuleBase
  9. from ..configs.norm_config import DSNormConfig
  10. from ..module_registry import DSModuleRegistryBase
  11. from ...inference_parameter import InferenceParameter
  class DSPostNormBase(DSModuleBase):
      """
      Common base (mix-in) for every post-normalization module.

      Implementations expose the following interface:

          residual, hidden_out = norm(residual + hidden_in)

      ``residual`` should be updated in-place. When ``residual`` and
      ``hidden_out`` have the same data type, they may alias each other.
      """

      @staticmethod
      def config_class() -> Type[DeepSpeedConfigModel]:
          """Return the configuration class used by this module: ``DSNormConfig``."""
          return DSNormConfig

      def __init__(self, config: DSNormConfig, implementation_config: Dict[str, Any]) -> None:
          """
          Hand both configuration objects to the parent constructor.

          Parameters:
              config (DSNormConfig): Normalization module configuration.
              implementation_config (Dict[str, Any]): Implementation-specific
                  settings, passed through unchanged.
          """
          super().__init__(config, implementation_config)

      @abstractmethod
      def transform_param(self, param: torch.Tensor) -> InferenceParameter:
          """
          Transform a gamma/beta parameter.

          The same transformation is assumed to apply to both gamma and beta.

          Parameters:
              param (torch.Tensor): Gamma or beta parameter.

          Returns:
              InferenceParameter: The transformed parameter.
          """

      def forward(
          self,
          residual: torch.Tensor,
          hidden_states: torch.Tensor,
          gamma: torch.Tensor,
          beta: Optional[torch.Tensor] = None,
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """
          Interface definition for the post-norm forward pass.

          Parameters:
              residual (torch.Tensor): Residual tensor.
              hidden_states (torch.Tensor): Hidden states tensor.
              gamma (torch.Tensor): Gamma parameter of the normalization.
              beta (Optional[torch.Tensor]): Beta parameter of the
                  normalization; defaults to ``None``.

          Returns:
              (torch.Tensor, torch.Tensor): Tuple of residual and hidden states.
                  Hidden states may alias with residual.

          Raises:
              NotImplementedError: Always; this base implementation performs no
                  computation.
          """
          raise NotImplementedError
  class DSPostNormRegistry(DSModuleRegistryBase):
      """
      Registry type associated with ``DSPostNormBase``.
      """

      # Class-level mapping, empty at definition time.
      # NOTE(review): presumably filled in by the DSModuleRegistryBase
      # machinery -- confirm against that class.
      registry: Dict = {}

      @staticmethod
      def associated_class() -> Type[DSModuleBase]:
          """Return the module base class tied to this registry: ``DSPostNormBase``."""
          return DSPostNormBase