# test.py

# Exercises deepspeed's ContiguousMemoryAllocator via allocate_tensor / release_tensor.
  1. import torch
  2. from deepspeed.runtime.zero.contiguous_memory_allocator import ContiguousMemoryAllocator
  def test1():
      """Exercise allocate/release on a 1024-element fp16 CPU pool.

      Allocates tensors of various sizes, each filled with a distinct
      constant, releases some of them and allocates more, ending with a
      512-element tensor (half the pool). ``print_allocation`` dumps the pool
      layout between steps. Finally, the sum of the norms of the tensors still
      held (a4, a7, a8, a10) is compared against a recorded reference value.

      Raises:
          AssertionError: if the summed norm differs from the reference value.
      """
      mem = ContiguousMemoryAllocator(1024, torch.half, 'cpu')
      mem.print_allocation(resolution=100)

      # fill_ writes the constant directly and returns the same tensor object.
      # The previous `.mul_(0.0).add_(v)` relied on the recycled pool contents
      # being finite: 0 * inf and 0 * nan are nan, which would silently
      # corrupt the norms checked below.
      a1 = mem.allocate_tensor(64).fill_(1.0)
      mem.print_allocation(resolution=100)
      mem.release_tensor(a1)
      mem.print_allocation(resolution=100)

      a2 = mem.allocate_tensor(64).fill_(2.0)
      a3 = mem.allocate_tensor(256).fill_(3.0)
      a4 = mem.allocate_tensor(128).fill_(4.0)
      mem.print_allocation(resolution=100)
      mem.release_tensor(a3)
      mem.print_allocation(resolution=100)

      a5 = mem.allocate_tensor(64).fill_(5.0)
      a6 = mem.allocate_tensor(256).fill_(6.0)
      a7 = mem.allocate_tensor(128).fill_(7.0)
      mem.print_allocation(resolution=100)

      a8 = mem.allocate_tensor(256).fill_(8.0)
      a9 = mem.allocate_tensor(128).fill_(9.0)
      mem.print_allocation(resolution=100)

      mem.release_tensor(a9)
      mem.release_tensor(a6)
      mem.release_tensor(a2)
      mem.release_tensor(a5)
      # Only a4, a7 and a8 (128 + 128 + 256 elements) are still held here,
      # so a10 asks for exactly the remaining 512 of the 1024 pool elements.
      a10 = mem.allocate_tensor(512).fill_(10.0)
      mem.print_allocation(resolution=100)

      total_norm = (a4.norm() + a7.norm() + a8.norm() + a10.norm()).item()
      # NOTE(review): exact equality on an fp16 reduction; the reference value
      # may depend on how torch accumulates half-precision norms on CPU.
      assert total_norm == 474.50, "Test failed"
  def test2():
      """Exercise allocate/release on a 512-element fp16 CPU pool.

      Fills the pool with eight 64-element tensors, releases every other one,
      then requests a 128-element tensor (a9) plus two 64-element tensors.
      It then releases a1 and a5 and requests another 128-element tensor (a12).
      ``print_allocation`` dumps the pool layout between steps. The surviving
      tensors are printed, and the sum of their norms (excluding a3) is compared
      against a recorded reference value.

      Raises:
          AssertionError: if the summed norm differs from the reference value.
      """
      mem = ContiguousMemoryAllocator(512, torch.half, 'cpu')

      # Eight 64-element tensors occupy all 512 pool elements. fill_ writes the
      # constant directly and returns the same tensor object; the previous
      # `.mul_(0.0).add_(v)` relied on the pool contents being finite.
      a1 = mem.allocate_tensor(64).fill_(1.0)
      a2 = mem.allocate_tensor(64).fill_(2.0)
      a3 = mem.allocate_tensor(64).fill_(3.0)
      a4 = mem.allocate_tensor(64).fill_(4.0)
      a5 = mem.allocate_tensor(64).fill_(5.0)
      a6 = mem.allocate_tensor(64).fill_(6.0)
      a7 = mem.allocate_tensor(64).fill_(7.0)
      a8 = mem.allocate_tensor(64).fill_(8.0)

      # Free 256 elements in total, but as four separate 64-element releases.
      mem.release_tensor(a2)
      mem.release_tensor(a4)
      mem.release_tensor(a6)
      mem.release_tensor(a8)
      mem.print_allocation(resolution=100)

      # a9 requests 128 elements in a single tensor.
      a9 = mem.allocate_tensor(128).fill_(9.0)
      a10 = mem.allocate_tensor(64).fill_(10.0)
      a11 = mem.allocate_tensor(64).fill_(11.0)

      mem.release_tensor(a1)
      mem.release_tensor(a5)
      mem.print_allocation(resolution=100)

      # a12 again requests 128 elements after two separate 64-element releases.
      a12 = mem.allocate_tensor(128).fill_(12.0)
      mem.print_allocation(resolution=100)

      print(f"a7:{a7}")
      print(f"a9:{a9}")
      print(f"a10:{a10}")
      print(f"a11:{a11}")
      print(f"a12:{a12}")

      # .item() matches test1's style; 460.75 is exactly representable in fp16,
      # so this is equivalent to the former 0-d tensor comparison.
      total_norm = (a7.norm() + a9.norm() + a10.norm() + a11.norm() + a12.norm()).item()
      # NOTE(review): exact equality on an fp16 reduction; the reference value
      # may depend on how torch accumulates half-precision norms on CPU.
      assert total_norm == 460.75, "TestFailed"
  # Run both tests only when this file is executed as a script, not whenever it
  # is imported. For example, pytest imports the module to collect test1/test2
  # and would otherwise run each of them twice.
  if __name__ == "__main__":
      test1()
      test2()