test_csr.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. import torch
  2. import random
  3. from deepspeed.runtime.csr_tensor import CSRTensor
  def test_csr_addition_self():
      """Check that adding a CSRTensor to itself doubles its dense value.

      Builds a 10x5 matrix whose first row is all ones and whose remaining
      rows are each all-ones (probability 0.25) or all-zeros. The matrix is
      converted to a CSRTensor and the dense round-trip is verified. Then
      ``cx.add(cx)`` is called and the result is compared against
      ``dense_x + dense_x``.
      """
      row_count = 10
      col_count = 5
      # Use a private generator instead of random.seed(). Seeding the global
      # RNG would silently change the random stream seen by every test that
      # runs afterwards. Random(1234) yields the same sequence the module-level
      # functions would after random.seed(1234), so the test data is unchanged.
      rng = random.Random(1234)
      # Collect the rows and concatenate once, rather than re-allocating the
      # whole matrix with torch.cat on every iteration.
      rows = [torch.ones(1, col_count)]
      rows.extend(
          torch.ones(1, col_count)
          if rng.random() > 0.75
          else torch.zeros(1, col_count)
          for _ in range(row_count - 1)
      )
      x = torch.cat(rows)
      # Independent copy kept as the reference value.
      dense_x = x.clone()
      cx = CSRTensor(x)

      # The shape is checked explicitly because `a == b` broadcasts. Without
      # it, a wrongly-shaped to_dense() result could still pass.
      round_trip = cx.to_dense()
      assert round_trip.shape == dense_x.shape
      assert torch.all(dense_x == round_trip)

      cx.add(cx)
      expected = dense_x + dense_x
      summed = cx.to_dense()
      assert summed.shape == expected.shape
      assert torch.all(expected == summed)
  def test_csr_addition_different():
      """Check that adding one CSRTensor to another matches dense addition.

      Two 10x5 matrices are drawn from the same seeded generator, x first and
      then y. Each has an all-ones first row, and every later row is all-ones
      with probability 0.25, otherwise all-zeros. Both conversions must
      round-trip. After ``cx.add(cy)``, the dense value of ``cx`` must equal
      ``dense_x + dense_y``.
      """
      row_count = 10
      col_count = 5
      # Private generator: avoids reseeding the global RNG (which would leak
      # into other tests) while reproducing the same stream random.seed(1234)
      # would give.
      rng = random.Random(1234)

      def random_rows():
          # One all-ones row followed by row_count - 1 random rows.
          # The rows are collected first and concatenated once.
          rows = [torch.ones(1, col_count)]
          rows.extend(
              torch.ones(1, col_count)
              if rng.random() > 0.75
              else torch.zeros(1, col_count)
              for _ in range(row_count - 1)
          )
          return torch.cat(rows)

      x = random_rows()
      dense_x = x.clone()
      cx = CSRTensor(x)

      # y continues the same random stream rather than reseeding, exactly as
      # the draws were ordered before.
      y = random_rows()
      dense_y = y.clone()
      cy = CSRTensor(y)

      # Verify both conversions before testing addition. The shape check
      # guards against broadcasting in `==` hiding a wrongly-shaped result.
      for csr, dense in ((cx, dense_x), (cy, dense_y)):
          round_trip = csr.to_dense()
          assert round_trip.shape == dense.shape
          assert torch.all(dense == round_trip)

      dense_sum = dense_x + dense_y
      cx.add(cy)
      result = cx.to_dense()
      assert result.shape == dense_sum.shape
      assert torch.all(dense_sum == result)