# test_flops_profiler.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import pytest
import deepspeed
from deepspeed.profiling.flops_profiler import get_model_profile
from unit.simple_model import SimpleModel, random_dataloader
from unit.common import DistributedTest
from deepspeed.utils.torch import required_torch_version
from deepspeed.accelerator import get_accelerator

# The training test below runs with "fp16": {"enabled": True}; skip the whole
# module if the current accelerator cannot compute in half precision.
if torch.half not in get_accelerator().supported_dtypes():
    pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True)

# Applied to every test in this module: the flops profiler relies on hook APIs
# that require PyTorch >= 1.3.
pytestmark = pytest.mark.skipif(not required_torch_version(min_version=1.3),
                                reason='requires Pytorch version 1.3 or above')
  16. def within_range(val, target, tolerance):
  17. return abs(val - target) / target < tolerance
  18. TOLERANCE = 0.05
  19. class LeNet5(torch.nn.Module):
  20. def __init__(self, n_classes):
  21. super(LeNet5, self).__init__()
  22. self.feature_extractor = torch.nn.Sequential(
  23. torch.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1),
  24. torch.nn.Tanh(),
  25. torch.nn.AvgPool2d(kernel_size=2),
  26. torch.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
  27. torch.nn.Tanh(),
  28. torch.nn.AvgPool2d(kernel_size=2),
  29. torch.nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1),
  30. torch.nn.Tanh(),
  31. )
  32. self.classifier = torch.nn.Sequential(
  33. torch.nn.Linear(in_features=120, out_features=84),
  34. torch.nn.Tanh(),
  35. torch.nn.Linear(in_features=84, out_features=n_classes),
  36. )
  37. def forward(self, x):
  38. x = self.feature_extractor(x)
  39. x = torch.flatten(x, 1)
  40. logits = self.classifier(x)
  41. probs = torch.nn.functional.softmax(logits, dim=1)
  42. return logits, probs
  43. class TestFlopsProfiler(DistributedTest):
  44. world_size = 1
  45. def test(self):
  46. config_dict = {
  47. "train_batch_size": 1,
  48. "steps_per_print": 1,
  49. "optimizer": {
  50. "type": "Adam",
  51. "params": {
  52. "lr": 0.001,
  53. }
  54. },
  55. "zero_optimization": {
  56. "stage": 0
  57. },
  58. "fp16": {
  59. "enabled": True,
  60. },
  61. "flops_profiler": {
  62. "enabled": True,
  63. "step": 1,
  64. "module_depth": -1,
  65. "top_modules": 3,
  66. },
  67. }
  68. hidden_dim = 10
  69. model = SimpleModel(hidden_dim, empty_grad=False)
  70. model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters())
  71. data_loader = random_dataloader(model=model,
  72. total_samples=50,
  73. hidden_dim=hidden_dim,
  74. device=model.device,
  75. dtype=torch.half)
  76. for n, batch in enumerate(data_loader):
  77. loss = model(batch[0], batch[1])
  78. model.backward(loss)
  79. model.step()
  80. if n == 3: break
  81. assert within_range(model.flops_profiler.flops, 200, tolerance=TOLERANCE)
  82. assert model.flops_profiler.params == 110
  83. def test_flops_profiler_in_inference(self):
  84. mod = LeNet5(10)
  85. batch_size = 1024
  86. input = torch.randn(batch_size, 1, 32, 32)
  87. flops, macs, params = get_model_profile(
  88. mod,
  89. tuple(input.shape),
  90. print_profile=True,
  91. detailed=True,
  92. module_depth=-1,
  93. top_modules=3,
  94. warm_up=1,
  95. as_string=False,
  96. ignore_modules=None,
  97. )
  98. print(flops, macs, params)
  99. assert within_range(flops, 866076672, TOLERANCE)
  100. assert within_range(macs, 426516480, TOLERANCE)
  101. assert params == 61706