# test_compile.py — DeepSpeed torch.compile graph-break smoke test.
# (Web-viewer caption and line-number gutter removed from this extracted copy.)
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
  4. import argparse
  5. import deepspeed
  6. from deepspeed.accelerator import get_accelerator
  7. from deepspeed import comm
  8. import torch
  9. import intel_extension_for_pytorch # noqa: F401 # type: ignore
  10. from torch.utils.data import Dataset, DataLoader
  11. torch._dynamo.config.cache_size_limit = 100
  12. def get_dynamo_stats():
  13. return torch._dynamo.utils.counters["graph_break"]
  14. class RandomDataset(Dataset):
  15. def __init__(self, size, length):
  16. self.len = length
  17. self.data = torch.randn(length, size).to(torch.bfloat16)
  18. def __getitem__(self, index):
  19. return self.data[index]
  20. def __len__(self):
  21. return self.len
  22. data_size = 1024
  23. data_length = 100
  24. rand_loader = DataLoader(dataset=RandomDataset(data_size, data_length), batch_size=1, shuffle=False)
  25. class MyModule(torch.nn.Module):
  26. def __init__(self, *args, **kwargs) -> None:
  27. super().__init__(*args, **kwargs)
  28. self.fc0 = torch.nn.Linear(1024, 256, bias=False)
  29. self.fc1 = torch.nn.Linear(256, 256, bias=False)
  30. self.dropout = torch.nn.Dropout(0.5)
  31. def forward(self, data, residual):
  32. output = residual + self.fc1(self.fc0(self.dropout(data))) * 0.5
  33. return output
# Build the model and hand its parameters to DeepSpeed, which wraps them in
# an engine + optimizer according to the JSON config.
model = MyModule()
params = model.parameters()

# CLI arguments expected by the distributed launcher; --deepspeed_config
# points at the ZeRO stage-3 config used by this test.
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1, help='local rank passed from distributed launcher')
parser.add_argument('--deepspeed_config',
                    type=str,
                    default='ds_config_z3.json',
                    help='path to DeepSpeed configuration file')
cmd_args = parser.parse_args()

# initialize the DeepSpeed engine
model_engine, optimizer, _, _ = deepspeed.initialize(args=cmd_args, model=model, model_parameters=params)
# Ask the engine to wrap its forward in torch.compile so Dynamo traces the
# training graph (graph breaks are what this test measures).
model_engine.compile()
# Fixed residual tensor reused on every step; float32, placed on the
# current accelerator device.
residual = torch.rand(256, 256, dtype=torch.float).to(get_accelerator().current_device_name())

# Snapshot the graph-break counters before training so breaks produced by
# engine setup are visible.
# NOTE(review): get_dynamo_stats() returns a live Counter, so start_stats and
# dynamo_stats below alias the same object — confirm a copy isn't intended.
start_stats = get_dynamo_stats()

if comm.get_rank() == 0:
    #print(dynamo_stats['graph_breaks'])
    for item in start_stats.items():
        print(item)

for step, batch in enumerate(rand_loader):
    if step % 10 == 0 and comm.get_rank() == 0:
        print(f'step={step}')
    # forward() method
    loss = model_engine(batch.to(get_accelerator().current_device_name()), residual).sum()
    # runs backpropagation
    model_engine.backward(loss)
    # weight update
    model_engine.step()

dynamo_stats = get_dynamo_stats()
if comm.get_rank() == 0:
    # print break down of graph break stats with markdown, print in table format, start with reason, then count
    # print a tag 'dynamo_output' before each line to allow post processing
    print("dynamo_output | Reason | Count |")
    print("dynamo_output | ------ | ----- |")
    for item in dynamo_stats.items():
        # escape any '|' inside the reason string so it cannot break the
        # markdown table columns
        item = (item[0].replace('|', r'\|'), item[1])
        print(f"dynamo_output | {item[0]} | {item[1]} |")
    print(f"dynamo_output | Total | {sum(dynamo_stats.values())} |")