moe_base.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from abc import abstractmethod
from typing import Any, Dict, Optional, Type

import torch

from deepspeed.runtime.config_utils import DeepSpeedConfigModel

from ..ds_module import DSModuleBase
from ..module_registry import DSModuleRegistryBase
from ..configs import DSMoEConfig
from ...inference_parameter import InferenceParameter


class DSMoEBase(DSModuleBase):
    """
    Base mixin for MoE modules. The interface represented by this module is:

    expert_assignments = gate(hidden_states)
    intermediate = ragged_linear(hidden_states, expert_assignments)
    output = ragged_linear(intermediate, expert_assignments)

    An illustrative sketch of a concrete subclass follows the class definition.
    """

    @staticmethod
    def config_class() -> Type[DeepSpeedConfigModel]:
        return DSMoEConfig

    def __init__(self, config: DSMoEConfig, implementation_config: Dict[str, Any]) -> None:
        super().__init__(config, implementation_config)

    @abstractmethod
    def transform_gate_param(self, param: torch.Tensor) -> InferenceParameter:
        """
        Perform any necessary transformations of the gate parameter.

        Args:
            param (torch.Tensor): gate_w (shape: [num_experts, model_dim])
        """
        ...

    @abstractmethod
    def transform_moe_mlp_1_param(self, param: torch.Tensor) -> InferenceParameter:
        """
        Perform any necessary transformations of the parameter. The specific component
        being transformed should be inferred from the shape of the parameter.

        Args:
            param (torch.Tensor): One of either mlp_1_w or mlp_1_b
        """
        ...

    @abstractmethod
    def transform_moe_mlp_2_param(self, param: torch.Tensor) -> InferenceParameter:
        """
        Perform any necessary transformations of the parameter. The specific component
        being transformed should be inferred from the shape of the parameter. This
        interface is separate from transform_moe_mlp_1_param because the two components
        may have identical shapes.

        Args:
            param (torch.Tensor): One of either mlp_2_w or mlp_2_b
        """
        ...

    def forward(self,
                hidden_states: torch.Tensor,
                gate_w: torch.Tensor,
                mlp_1_w: torch.Tensor,
                mlp_2_w: torch.Tensor,
                mlp_1_b: Optional[torch.Tensor] = None,
                mlp_2_b: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Concrete implementations override this with their gating + expert-MLP pipeline.
        raise NotImplementedError()

    @property
    @abstractmethod
    def output(self) -> torch.Tensor:
        """
        Returns the pre-allocated, padded output Tensor.
        """
        ...
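

# ---------------------------------------------------------------------------
# Illustrative sketch only, not part of the library: a minimal, hypothetical
# DSMoEBase subclass showing how the interface above fits together. The dense
# per-expert loop stands in for the fused gate + ragged GEMM kernels a real
# implementation would use. Assumptions: weights are laid out as
# gate_w [num_experts, model_dim], mlp_1_w [num_experts, intermediate, model_dim],
# mlp_2_w [num_experts, model_dim, intermediate]; InferenceParameter.initialize()
# wraps a plain tensor; and DSModuleBase declares name()/supports_config() as
# abstract (both are implemented defensively here).
# ---------------------------------------------------------------------------
class _NaiveMoE(DSMoEBase):

    @staticmethod
    def name() -> str:
        return 'naive_moe_sketch'

    @staticmethod
    def supports_config(config: DSMoEConfig) -> bool:
        return True

    def __init__(self, config: DSMoEConfig, implementation_config: Dict[str, Any]) -> None:
        super().__init__(config, implementation_config)
        self._output: Optional[torch.Tensor] = None

    def transform_gate_param(self, param: torch.Tensor) -> InferenceParameter:
        # A real implementation might reorder or quantize; this passes through.
        return InferenceParameter.initialize(param)

    def transform_moe_mlp_1_param(self, param: torch.Tensor) -> InferenceParameter:
        return InferenceParameter.initialize(param)

    def transform_moe_mlp_2_param(self, param: torch.Tensor) -> InferenceParameter:
        return InferenceParameter.initialize(param)

    def forward(self,
                hidden_states: torch.Tensor,
                gate_w: torch.Tensor,
                mlp_1_w: torch.Tensor,
                mlp_2_w: torch.Tensor,
                mlp_1_b: Optional[torch.Tensor] = None,
                mlp_2_b: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Top-1 gating: route each token to its highest-scoring expert.
        scores = hidden_states @ gate_w.t()  # [n_tokens, num_experts]
        expert_assignments = scores.argmax(dim=-1)  # [n_tokens]

        out = torch.zeros_like(hidden_states)
        for e in range(gate_w.shape[0]):
            mask = expert_assignments == e
            if not mask.any():
                continue
            tokens = hidden_states[mask]
            # The two "ragged_linear" steps from the docstring, done densely.
            inter = tokens @ mlp_1_w[e].t()
            if mlp_1_b is not None:
                inter = inter + mlp_1_b[e]
            inter = torch.nn.functional.gelu(inter)  # activation choice is arbitrary here
            proj = inter @ mlp_2_w[e].t()
            if mlp_2_b is not None:
                proj = proj + mlp_2_b[e]
            out[mask] = proj
        self._output = out
        return out

    @property
    def output(self) -> torch.Tensor:
        return self._output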


class DSMoERegistry(DSModuleRegistryBase):
    registry: Dict = {}

    @staticmethod
    def associated_class() -> Type[DSModuleBase]:
        return DSMoEBase
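

# Illustrative sketch only: how an implementation might be registered and
# discovered. This assumes DSModuleRegistryBase exposes a `register_module`
# decorator and keys entries by the module's name() string, as other DeepSpeed
# inference-v2 registries do; verify against your DeepSpeed version.
#
#     @DSMoERegistry.register_module
#     class DSMyCustomMoE(DSMoEBase):
#         @staticmethod
#         def name() -> str:
#             return 'my_custom_moe'
#         ...
#
#     moe_cls = DSMoERegistry.registry['my_custom_moe']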