test_csr.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import torch
  5. import random
  6. from deepspeed.runtime.sparse_tensor import SparseTensor
  7. def test_csr_addition_self():
  8. row_count = 10
  9. random.seed(1234)
  10. x = torch.ones(1, 5)
  11. for i in range(row_count - 1):
  12. if random.random() > 0.75:
  13. x = torch.cat([x, torch.ones(1, 5)])
  14. else:
  15. x = torch.cat([x, torch.zeros(1, 5)])
  16. dense_x = x.clone()
  17. cx = SparseTensor(x)
  18. assert torch.all(dense_x == cx.to_dense())
  19. cx.add(cx)
  20. assert torch.all(dense_x + dense_x == cx.to_dense())
  21. def test_csr_addition_different():
  22. row_count = 10
  23. random.seed(1234)
  24. x = torch.ones(1, 5)
  25. for i in range(row_count - 1):
  26. if random.random() > 0.75:
  27. x = torch.cat([x, torch.ones(1, 5)])
  28. else:
  29. x = torch.cat([x, torch.zeros(1, 5)])
  30. dense_x = x.clone()
  31. cx = SparseTensor(x)
  32. y = torch.ones(1, 5)
  33. for i in range(row_count - 1):
  34. if random.random() > 0.75:
  35. y = torch.cat([y, torch.ones(1, 5)])
  36. else:
  37. y = torch.cat([y, torch.zeros(1, 5)])
  38. dense_y = y.clone()
  39. cy = SparseTensor(y)
  40. dense_sum = dense_x + dense_y
  41. cx.add(cy)
  42. assert torch.all(dense_sum == cx.to_dense())