stochastic_transformer.py 565 B

12345678910111213141516171819202122
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from .transformer import TransformerBuilder
  5. class StochasticTransformerBuilder(TransformerBuilder):
  6. BUILD_VAR = "DS_BUILD_STOCHASTIC_TRANSFORMER"
  7. NAME = "stochastic_transformer"
  8. def __init__(self):
  9. super().__init__(name=self.NAME)
  10. def absolute_name(self):
  11. return f'deepspeed.ops.transformer.{self.NAME}_op'
  12. def nvcc_args(self):
  13. args = super().nvcc_args()
  14. args.append('-D__STOCHASTIC_MODE__')
  15. return args