# sparse_tensor.py
"""
Copyright 2020 The Microsoft DeepSpeed Team

Implementation of a compressed sparse tensor. Similar in
functionality to TensorFlow's IndexedSlices implementation.
"""
  6. import torch
  7. class SparseTensor(object):
  8. """ Compressed Sparse Tensor """
  9. def __init__(self, dense_tensor=None):
  10. self.orig_dense_tensor = dense_tensor
  11. self.is_sparse = dense_tensor.is_sparse
  12. if dense_tensor is not None:
  13. if dense_tensor.is_sparse:
  14. dense_tensor = dense_tensor.coalesce()
  15. self.indices = dense_tensor.indices().flatten()
  16. self.values = dense_tensor.values()
  17. else:
  18. result = torch.sum(dense_tensor, dim=1)
  19. self.indices = result.nonzero().flatten()
  20. self.values = dense_tensor[self.indices]
  21. self.dense_size = list(dense_tensor.size())
  22. else:
  23. self.indices = None
  24. self.values = None
  25. self.dense_size = None
  26. def to_coo_tensor(self):
  27. return torch.sparse_coo_tensor(self.indices.unsqueeze(0),
  28. self.values,
  29. self.dense_size)
  30. @staticmethod
  31. def type():
  32. return "deepspeed.SparseTensor"
  33. def to_dense(self):
  34. it = self.indices.unsqueeze(1)
  35. full_indices = torch.cat([it for _ in range(self.dense_size[1])], dim=1)
  36. return self.values.new_zeros(self.dense_size).scatter_add_(
  37. 0,
  38. full_indices,
  39. self.values)
  40. def sparse_size(self):
  41. index_size = list(self.indices.size())
  42. index_size = index_size[0]
  43. value_size = list(self.values.size())
  44. value_size = value_size[0] * value_size[1]
  45. dense_size = self.dense_size[0] * self.dense_size[1]
  46. return index_size + value_size, dense_size
  47. def add(self, b):
  48. assert self.dense_size == b.dense_size
  49. self.indices = torch.cat([self.indices, b.indices])
  50. self.values = torch.cat([self.values, b.values])
  51. def __str__(self):
  52. sparse_size, dense_size = self.sparse_size()
  53. return "DeepSpeed.SparseTensor(indices_size={}, values_size={}, " \
  54. "dense_size={}, device={}, reduction_factor={})".format(
  55. self.indices.size(), self.values.size(), self.dense_size,
  56. self.indices.get_device(), dense_size / sparse_size
  57. )
  58. def __repr__(self):
  59. return self.__str__()