stochastic_transformer.py 521 B

1234567891011121314151617181920
"""
Copyright 2020 The Microsoft DeepSpeed Team
"""
  4. from .transformer import TransformerBuilder
  5. class StochasticTransformerBuilder(TransformerBuilder):
  6. BUILD_VAR = "DS_BUILD_STOCHASTIC_TRANSFORMER"
  7. NAME = "stochastic_transformer"
  8. def __init__(self):
  9. super().__init__(name=self.NAME)
  10. def absolute_name(self):
  11. return f'deepspeed.ops.transformer.{self.NAME}_op'
  12. def nvcc_args(self):
  13. args = super().nvcc_args()
  14. args.append('-D__STOCHASTIC_MODE__')
  15. return args