ds_megatron_gpt.py 682 B

1234567891011121314151617181920
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from deepspeed.model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference
  class DeepSpeedMegatronGPTInference(DeepSpeedTransformerInference):
      """DeepSpeed inference transformer layer for Megatron GPT models.

      This subclass adds no behavior of its own. Its constructor forwards every
      argument, unchanged and in the same order, to
      ``DeepSpeedTransformerInference.__init__``.

      NOTE(review): it is presumably kept as a separate type so Megatron GPT
      layers can be distinguished from other DeepSpeed inference layers.
      Confirm this against callers before relying on it.
      """

      def __init__(self,
                   config,
                   mp_group=None,
                   quantize_scales=None,
                   quantize_groups=1,
                   merge_count=1,
                   mlp_extra_grouping=False):
          """Initialize the DeepSpeed Megatron GPT transformer layer.

          All arguments are handed to ``DeepSpeedTransformerInference.__init__``
          as-is. See the base class for their exact meaning.

          Args:
              config: Layer configuration object, consumed by the base class.
              mp_group: Defaults to ``None``. Presumably the model-parallel
                  process group (inferred from the name; confirm with the base
                  class).
              quantize_scales: Defaults to ``None``.
              quantize_groups: Defaults to ``1``.
              merge_count: Defaults to ``1``.
              mlp_extra_grouping: Defaults to ``False``.
          """
          # Forwarded positionally: this order must stay in sync with the base
          # class constructor's parameter order.
          super().__init__(config, mp_group, quantize_scales, quantize_groups, merge_count, mlp_extra_grouping)