sparse_tensor.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. """
  5. Implementation of a compressed sparse tensor. Similar in
  6. functionality to TensorFlow's IndexedSlices implementation.
  7. """
  8. import torch
  9. class SparseTensor(object):
  10. """ Compressed Sparse Tensor """
  11. def __init__(self, dense_tensor=None):
  12. self.orig_dense_tensor = dense_tensor
  13. self.is_sparse = dense_tensor.is_sparse
  14. if dense_tensor is not None:
  15. if dense_tensor.is_sparse:
  16. dense_tensor = dense_tensor.coalesce()
  17. self.indices = dense_tensor.indices().flatten()
  18. self.values = dense_tensor.values()
  19. else:
  20. result = torch.sum(dense_tensor, dim=1)
  21. self.indices = result.nonzero().flatten()
  22. self.values = dense_tensor[self.indices]
  23. self.dense_size = list(dense_tensor.size())
  24. else:
  25. self.indices = None
  26. self.values = None
  27. self.dense_size = None
  28. def to_coo_tensor(self):
  29. return torch.sparse_coo_tensor(self.indices.unsqueeze(0), self.values, self.dense_size)
  30. @staticmethod
  31. def type():
  32. return "deepspeed.SparseTensor"
  33. def to_dense(self):
  34. it = self.indices.unsqueeze(1)
  35. full_indices = torch.cat([it for _ in range(self.dense_size[1])], dim=1)
  36. return self.values.new_zeros(self.dense_size).scatter_add_(0, full_indices, self.values)
  37. def sparse_size(self):
  38. index_size = list(self.indices.size())
  39. index_size = index_size[0]
  40. value_size = list(self.values.size())
  41. value_size = value_size[0] * value_size[1]
  42. dense_size = self.dense_size[0] * self.dense_size[1]
  43. return index_size + value_size, dense_size
  44. def add(self, b):
  45. assert self.dense_size == b.dense_size
  46. self.indices = torch.cat([self.indices, b.indices])
  47. self.values = torch.cat([self.values, b.values])
  48. def __str__(self):
  49. sparse_size, dense_size = self.sparse_size()
  50. return "DeepSpeed.SparseTensor(indices_size={}, values_size={}, " \
  51. "dense_size={}, device={}, reduction_factor={})".format(
  52. self.indices.size(), self.values.size(), self.dense_size,
  53. self.indices.get_device(), dense_size / sparse_size
  54. )
  55. def __repr__(self):
  56. return self.__str__()