bias_add.py 876 B

1234567891011121314151617181920212223242526
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from typing import Optional
  5. import torch
  6. from deepspeed.ops.op_builder import SpatialInferenceBuilder
  7. spatial_cuda_module = None
  8. def nhwc_bias_add(activation: torch.Tensor,
  9. bias: torch.Tensor,
  10. other: Optional[torch.Tensor] = None,
  11. other_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  12. global spatial_cuda_module
  13. if spatial_cuda_module is None:
  14. spatial_cuda_module = SpatialInferenceBuilder().load()
  15. if other is None:
  16. return spatial_cuda_module.nhwc_bias_add(activation, bias)
  17. elif other_bias is None:
  18. return spatial_cuda_module.nhwc_bias_add_add(activation, bias, other)
  19. else:
  20. return spatial_cuda_module.nhwc_bias_add_bias_add(activation, bias, other, other_bias)