test.py

import torch
from deepspeed.pt.deepspeed_linear import LinearModuleForZeroStage3
from deepspeed.pt.log_utils import logger
import deepspeed


def see_memory_usage(message):
    # Print message except when distributed but not rank 0
    logger.info(message)
    logger.info(
        "Memory Allocated %s GigaBytes",
        torch.cuda.memory_allocated() / (1024 * 1024 * 1024),
    )
    logger.info(
        "Max Memory Allocated %s GigaBytes",
        torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),
    )
    logger.info(
        "Cache Allocated %s GigaBytes",
        torch.cuda.memory_cached() / (1024 * 1024 * 1024),
    )
    logger.info(
        "Max cache Allocated %s GigaBytes",
        torch.cuda.max_memory_cached() / (1024 * 1024 * 1024),
    )


tens = torch.rand(1024, 16384, dtype=torch.half, device=torch.device('cuda'))
tens_back = tens.detach().clone()

# linear_bk = torch.nn.functional.linear
# torch.nn.functional.linear = deepspeed.pt.deepspeed_linear.LinearFunctionForZeroStage3.apply

model = LinearModuleForZeroStage3(16384, 16384)
model.cuda().half()

see_memory_usage("Before forward")
y = model(tens)
see_memory_usage("After forward")

# Zero out the weight before backward: the ZeRO stage-3 linear is expected to
# work without holding on to the original weight storage between passes.
model.weight.data = torch.zeros(1, dtype=torch.half, device=torch.device('cuda'))
see_memory_usage("After weight zero")

y.backward(tens_back)
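
The two commented-out lines show an alternative way to exercise the same code path: instead of constructing a LinearModuleForZeroStage3 directly, patch torch.nn.functional.linear so that an ordinary nn.Linear routes through the ZeRO stage-3 autograd function. A minimal sketch of that approach, assuming LinearFunctionForZeroStage3.apply accepts (input, weight, bias) positionally like F.linear does, and restoring the original function afterwards:

import torch
import deepspeed.pt.deepspeed_linear as ds_linear

linear_bk = torch.nn.functional.linear  # keep a reference to the original op
torch.nn.functional.linear = ds_linear.LinearFunctionForZeroStage3.apply
try:
    # A plain nn.Linear now dispatches to the patched function internally.
    plain = torch.nn.Linear(16384, 16384).cuda().half()
    x = torch.rand(1024, 16384, dtype=torch.half, device='cuda')
    y = plain(x)
    y.backward(x.detach().clone())
finally:
    torch.nn.functional.linear = linear_bk  # always restore the original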