linear_base.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from abc import abstractmethod
  5. from typing import Any, Dict, Optional, Type
  6. import torch
  7. from deepspeed.runtime.config_utils import DeepSpeedConfigModel
  8. from ..ds_module import DSModuleBase
  9. from ..module_registry import DSModuleRegistryBase
  10. from ..configs import DSLinearConfig
  11. from ...inference_parameter import InferenceParameter
  12. class DSLinearBase(DSModuleBase):
  13. """
  14. Base mixin for all Linear modules. The interface represented by this module
  15. is:
  16. hidden_out = activation(hidden_in * weight + bias)
  17. The format and dtype of the weight and bias tensors are not defined and implementations
  18. may compress as necessary. Must support a bias.
  19. """
  20. @staticmethod
  21. def config_class() -> Type[DeepSpeedConfigModel]:
  22. return DSLinearConfig
  23. def __init__(self, config: DSLinearConfig, implementation_config: Dict[str, Any]) -> None:
  24. super().__init__(config, implementation_config)
  25. @abstractmethod
  26. def transform_param(self, param: torch.Tensor) -> InferenceParameter:
  27. """
  28. Perform any necessary transformations of the parameters of this module.
  29. Parameters:
  30. param (torch.Tensor): Weight or bias tensor.
  31. """
  32. ...
  33. def forward(self, hidden_states: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor] = None) -> torch.Tensor:
  34. """
  35. Parameters:
  36. hidden_states (torch.Tensor): Hidden states tensor. Expected shape is either
  37. [batch, seq_len, in_channels] or [batch, in_channels].
  38. Returns:
  39. torch.Tensor: Output tensor. Tensor should have same number of dimensions as
  40. input tensor.
  41. """
  42. raise NotImplementedError()
  43. @property
  44. @abstractmethod
  45. def output(self) -> torch.Tensor:
  46. """
  47. Return the padded, pre-allocated output Tensor.
  48. """
  49. ...
  50. class DSLinearRegistry(DSModuleRegistryBase):
  51. registry: Dict = {}
  52. @staticmethod
  53. def associated_class() -> Type[DSModuleBase]:
  54. return DSLinearBase