# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

from deepspeed.pt.deepspeed_linear import LinearModuleForZeroStage3
from deepspeed.pt.log_utils import logger
from deepspeed.accelerator import get_accelerator
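
# Smoke test for LinearModuleForZeroStage3: run a half-precision forward on
# the accelerator, swap out the weight storage, then run backward, logging
# accelerator memory usage at each step.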

def see_memory_usage(message):
    # Report current and peak allocated/cached accelerator memory
    logger.info(message)
    logger.info(
        "Memory Allocated %s GigaBytes",
        get_accelerator().memory_allocated() / (1024 * 1024 * 1024),
    )
    logger.info(
        "Max Memory Allocated %s GigaBytes",
        get_accelerator().max_memory_allocated() / (1024 * 1024 * 1024),
    )
    logger.info(
        "Cache Allocated %s GigaBytes",
        get_accelerator().memory_cached() / (1024 * 1024 * 1024),
    )
    logger.info(
        "Max cache Allocated %s GigaBytes",
        get_accelerator().max_memory_cached() / (1024 * 1024 * 1024),
    )

tens = torch.rand(1024, 16384, dtype=torch.half, device=torch.device(get_accelerator().device_name()))
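# A detached clone of the input doubles as the upstream gradient for
# y.backward() below.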
tens_back = tens.detach().clone()

#linear_bk = torch.nn.functional.linear
#torch.nn.functional.linear = deepspeed.pt.deepspeed_linear.LinearFunctionForZeroStage3.apply

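# A 16384 x 16384 fp16 weight occupies 16384 * 16384 * 2 bytes = 0.5 GiB,
# large enough to stand out in the memory readings below.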
model = LinearModuleForZeroStage3(16384, 16384)

model.to(get_accelerator().device_name()).half()

see_memory_usage("Before forward")

y = model(tens)

see_memory_usage("After forward")

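# Replace the 0.5 GiB weight with a one-element tensor, mimicking ZeRO
# Stage 3 releasing its copy of the full weight once the forward pass is done.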
model.weight.data = torch.zeros(1, dtype=torch.half, device=torch.device(get_accelerator().device_name()))

see_memory_usage("After weight zero")

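# Backward still runs after the weight swap, presumably because the module's
# autograd function works from its saved tensors rather than model.weight.data.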
y.backward(tens_back)