ragged_embedding.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Any, Dict, Optional

import torch

from deepspeed.accelerator import get_accelerator

from ....allocator import empty_from
from ....inference_utils import DtypeEnum
from ....kernels.ragged_ops import RaggedEmbeddingKernel
from ....ragged import RaggedBatchWrapper
from ...interfaces import DSEmbeddingBase, DSEmbeddingRegistry
from ...configs import DSEmbeddingsConfig


@DSEmbeddingRegistry.register_module
class DSRaggedEmbedding(DSEmbeddingBase):
    """
    Ragged-batch embedding implementation. Gathers word (and optionally positional)
    embeddings for every token in a RaggedBatchWrapper into a preallocated output buffer.
    """

    @staticmethod
    def name():
        return 'ragged_embedding'

    @staticmethod
    def supports_config(config: DSEmbeddingsConfig) -> bool:
        # Only fp16, bf16, and fp32 residual dtypes are supported; token-type embeddings
        # and output normalization are not handled by this implementation.
        if DtypeEnum(config.residual_dtype) not in [DtypeEnum.fp16, DtypeEnum.bf16, DtypeEnum.fp32]:
            return False

        if config.use_token_type:
            return False

        if config.output_normalization is not None:
            return False

        # The kernel constructor raises ValueError if it cannot service the requested
        # dtype/embedding-dim combination.
        try:
            _ = RaggedEmbeddingKernel(config.residual_dtype, torch.int32, config.embedding_dim)
        except ValueError:
            return False

        return True
    def __init__(self, config: DSEmbeddingsConfig, implementation_config: Dict[str, Any]) -> None:
        super().__init__(config, implementation_config)

        self.embed_offset = self._config.positional_offset

        # TODO(cmikeh2): How do we want to avoid the int32 vs int64 issue?
        self._ragged_embed = RaggedEmbeddingKernel(self._config.residual_dtype, torch.int32,
                                                   self._config.embedding_dim)

        # Output buffer preallocated for the maximum token count; forward() carves a
        # correctly sized view out of it for each batch.
        self._output = torch.empty((self._config.max_tokens, self._config.embedding_dim),
                                   dtype=self._config.residual_dtype,
                                   device=get_accelerator().current_device())

    @property
    def output(self) -> torch.Tensor:
        return self._output
    def forward(self,
                ragged_batch: RaggedBatchWrapper,
                word_embeddings: torch.Tensor,
                position_embeddings: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Parameters:
            ragged_batch (RaggedBatchWrapper): The input ids and associated ragged batch metadata.
            word_embeddings (torch.Tensor): The word embedding table.
            position_embeddings (torch.Tensor, optional): The positional embedding table, for
                models that use learned positional embeddings.
        """
        output = empty_from(self._output, (ragged_batch.tensor_toks, self._config.embedding_dim))
        self._ragged_embed(output,
                           ragged_batch,
                           word_embeddings,
                           position_embed_weight=position_embeddings,
                           position_embed_offset=self.embed_offset)
        return output
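
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the upstream file). It assumes the
# inference engine has already built a valid `config` (DSEmbeddingsConfig) and
# a populated `ragged_batch` (RaggedBatchWrapper); `vocab_size` and the empty
# `implementation_config` dict are illustrative assumptions, not the actual
# engine wiring.
#
#     if DSRaggedEmbedding.supports_config(config):
#         embed = DSRaggedEmbedding(config, implementation_config={})
#
#         # Embedding tables are expected on the accelerator in the residual dtype.
#         word_embeddings = torch.empty((vocab_size, config.embedding_dim),
#                                       dtype=config.residual_dtype,
#                                       device=get_accelerator().current_device())
#
#         # Returns a [ragged_batch.tensor_toks, embedding_dim] view carved from
#         # the module's preallocated output buffer.
#         hidden = embed.forward(ragged_batch, word_embeddings)
# ---------------------------------------------------------------------------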