ds_bloom.py 750 B

1234567891011121314151617181920212223
  1. '''
  2. Copyright 2022 The Microsoft DeepSpeed Team
  3. '''
  4. from deepspeed.model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference
  class DeepSpeedBloomInference(DeepSpeedTransformerInference):
      """DeepSpeed inference transformer layer used for BLOOM models.

      This subclass adds no behavior of its own. Every constructor argument
      is passed through, in the same positional order, to
      ``DeepSpeedTransformerInference.__init__``.

      Args:
          config: Forwarded unchanged to the base class.
          mp_group: Forwarded unchanged (default ``None``); presumably the
              model-parallel process group -- confirm against the base class.
          quantize_scales: Forwarded unchanged (default ``None``).
          quantize_groups: Forwarded unchanged (default ``1``).
          merge_count: Forwarded unchanged (default ``1``).
          mlp_extra_grouping: Forwarded unchanged (default ``False``).
      """

      def __init__(self,
                   config,
                   mp_group=None,
                   quantize_scales=None,
                   quantize_groups=1,
                   merge_count=1,
                   mlp_extra_grouping=False):
          # The base constructor receives these positionally, so the tuple
          # must keep exactly the declaration order of the parameters above.
          base_args = (
              config,
              mp_group,
              quantize_scales,
              quantize_groups,
              merge_count,
              mlp_extra_grouping,
          )
          super(DeepSpeedBloomInference, self).__init__(*base_args)