123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- from abc import abstractmethod
- from typing import Any, Dict, Optional, Tuple, Type
- import torch
- from deepspeed.runtime.config_utils import DeepSpeedConfigModel
- from ..ds_module import DSModuleBase
- from ..configs.norm_config import DSNormConfig
- from ..module_registry import DSModuleRegistryBase
- from ...inference_parameter import InferenceParameter
- class DSPostNormBase(DSModuleBase):
- """
- Base MixIn for all Post-Normalization modules. The interface represented by this
- module is:
- residual, hidden_out = norm(residual + hidden_in)
- If residual and hidden_out are the same data type, then they may alias each other.
- Furthermore, residual should be updated in-place.
- """
- @staticmethod
- def config_class() -> Type[DeepSpeedConfigModel]:
- return DSNormConfig
- def __init__(self, config: DSNormConfig, implementation_config: Dict[str, Any]) -> None:
- super().__init__(config, implementation_config)
- @abstractmethod
- def transform_param(self, param: torch.Tensor) -> InferenceParameter:
- """
- Transform a gamma/beta parameter. It is assumed that both transformations are
- the same.
- Parameters:
- param (torch.Tensor): Gamma or beta parameter.
- """
- ...
- def forward(self,
- residual: torch.Tensor,
- hidden_states: torch.Tensor,
- gamma: torch.Tensor,
- beta: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Parameters:
- residual (torch.Tensor): Residual tensor.
- hidden_states (torch.Tensor): Hidden states tensor.
- Returns:
- (torch.Tensor, torch.Tensor): Tuple of residual and hidden states.
- Hidden states may alias with residual.
- """
- raise NotImplementedError()
- class DSPostNormRegistry(DSModuleRegistryBase):
- registry: Dict = {}
- @staticmethod
- def associated_class() -> Type[DSModuleBase]:
- return DSPostNormBase
|