base.py 536 B

1234567891011121314151617181920
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import torch
  5. from ..config import DeepSpeedInferenceConfig
  6. from deepspeed.ops.op_builder import InferenceBuilder
  7. class BaseOp(torch.nn.Module):
  8. inference_module = None
  9. def __init__(self, config: DeepSpeedInferenceConfig):
  10. super(BaseOp, self).__init__()
  11. self.config = config
  12. if BaseOp.inference_module is None:
  13. builder = InferenceBuilder()
  14. BaseOp.inference_module = builder.load()