test_nhwc_bias_add.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import torch
import deepspeed
from deepspeed.ops.op_builder import SpatialInferenceBuilder
from deepspeed.ops.transformer.inference.bias_add import nhwc_bias_add
from deepspeed.accelerator import get_accelerator

if not deepspeed.ops.__compatible_ops__[SpatialInferenceBuilder.NAME]:
    pytest.skip("Inference ops are not available on this system", allow_module_level=True)
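

# Comparison helper with per-dtype tolerances: looser bounds for fp16 than fp32, and unit tolerance for int8.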
def allclose(x, y):
    assert x.dtype == y.dtype
    rtol, atol = {torch.float32: (5e-3, 5e-4), torch.float16: (3e-2, 2e-3), torch.int8: (1, 1)}[x.dtype]
    return torch.allclose(x, y, rtol=rtol, atol=atol)


def ref_bias_add(activations, bias):
    return activations + bias.reshape(1, -1, 1, 1)


channels_list = [192, 384, 320, 576, 640, 768, 960, 1152, 1280, 1536, 1600, 1920, 2240, 2560]
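

# Fused NHWC bias add: out = activations + bias, computed on fp16 channels-last activations.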
@pytest.mark.inference_ops
@pytest.mark.parametrize("batch", [1, 2, 10])
@pytest.mark.parametrize("image_size", [16, 32, 64])
@pytest.mark.parametrize("channels", channels_list)
def test_bias_add(batch, image_size, channels):
    activations = torch.randn((batch, channels, image_size, image_size),
                              dtype=torch.float16,
                              device=get_accelerator().device_name()).to(memory_format=torch.channels_last)
    bias = torch.randn((channels), dtype=torch.float16, device=get_accelerator().device_name())

    ref_vals = ref_bias_add(activations.clone().detach(), bias)
    ds_vals = nhwc_bias_add(activations, bias)

    assert allclose(ds_vals, ref_vals)


def ref_bias_add_add(activations, bias, other):
    return (activations + bias.reshape(1, -1, 1, 1)) + other
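

# Fused bias add with an elementwise residual: out = (activations + bias) + other.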
@pytest.mark.inference_ops
@pytest.mark.parametrize("batch", [1, 2, 10])
@pytest.mark.parametrize("image_size", [16, 32, 64])
@pytest.mark.parametrize("channels", channels_list)
def test_bias_add_add(batch, image_size, channels):
    activations = torch.randn((batch, channels, image_size, image_size),
                              dtype=torch.float16,
                              device=get_accelerator().device_name()).to(memory_format=torch.channels_last)
    other = torch.randn((batch, channels, image_size, image_size),
                        dtype=torch.float16,
                        device=get_accelerator().device_name()).to(memory_format=torch.channels_last)
    bias = torch.randn((channels), dtype=torch.float16, device=get_accelerator().device_name())

    ref_vals = ref_bias_add_add(activations.clone().detach(), bias, other)
    ds_vals = nhwc_bias_add(activations, bias, other=other)

    assert allclose(ds_vals, ref_vals)


def ref_bias_add_bias_add(activations, bias, other, other_bias):
    return (activations + bias.reshape(1, -1, 1, 1)) + (other + other_bias.reshape(1, -1, 1, 1))
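

# Fused double bias add: out = (activations + bias) + (other + other_bias), with both biases broadcast per channel.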
@pytest.mark.inference_ops
@pytest.mark.parametrize("batch", [1, 2, 10])
@pytest.mark.parametrize("image_size", [16, 32, 64])
@pytest.mark.parametrize("channels", channels_list)
def test_bias_add_bias_add(batch, image_size, channels):
    activations = torch.randn((batch, channels, image_size, image_size),
                              dtype=torch.float16,
                              device=get_accelerator().device_name()).to(memory_format=torch.channels_last)
    other = torch.randn((batch, channels, image_size, image_size),
                        dtype=torch.float16,
                        device=get_accelerator().device_name()).to(memory_format=torch.channels_last)
    bias = torch.randn((channels), dtype=torch.float16, device=get_accelerator().device_name())
    other_bias = torch.randn((channels), dtype=torch.float16, device=get_accelerator().device_name())

    ref_vals = ref_bias_add_bias_add(activations.clone().detach(), bias, other, other_bias)
    ds_vals = nhwc_bias_add(activations, bias, other=other, other_bias=other_bias)

    assert allclose(ds_vals, ref_vals)