ds_base.py 388 B

123456789101112131415
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import torch.nn as nn
  5. class DeepSpeedTransformerBase(nn.module):
  6. def __init__(self):
  7. pass
  8. # This will be the new, clean base class that replaces DeepSpeedTransformerInference.
  9. # We don't yet know what it will look like, so it is kept here as a placeholder.