sparse_tensor.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. """
  5. Implementation of a compressed sparse tensor. Similar in
  6. functionality to TensorFlow's IndexedSlices implementation.
  7. """
  8. import torch
  class SparseTensor(object):
      """Compressed Sparse Tensor.

      Stores a 2-D tensor as only its non-zero rows:
        * ``indices`` is a 1-D tensor of row numbers,
        * ``values`` holds the corresponding rows (one per index),
        * ``dense_size`` is the full shape, so the dense tensor can be rebuilt.

      The methods are written for 2-D tensors: ``to_dense`` and
      ``sparse_size`` index ``dense_size[0]`` and ``dense_size[1]``.
      """

      def __init__(self, dense_tensor=None):
          """Build a row-compressed view of ``dense_tensor``.

          Args:
              dense_tensor: a ``torch.Tensor`` (strided or sparse COO), or
                  ``None``. ``None`` produces an empty instance whose
                  ``indices``, ``values``, ``dense_size`` and ``dtype`` are
                  all ``None`` and whose ``is_sparse`` is ``False``.
          """
          self.orig_dense_tensor = dense_tensor
          if dense_tensor is None:
              # dtype / is_sparse must not be read from None; doing so made
              # the default-argument constructor raise AttributeError.
              self.dtype = None
              self.is_sparse = False
              self.indices = None
              self.values = None
              self.dense_size = None
              return

          self.dtype = dense_tensor.dtype
          self.is_sparse = dense_tensor.is_sparse
          if dense_tensor.is_sparse:
              # coalesce() sums duplicate entries and sorts the indices.
              # NOTE(review): flattening indices() into a row list assumes
              # sparse_dim == 1 (row-sparse hybrid tensor) -- confirm callers.
              dense_tensor = dense_tensor.coalesce()
              self.indices = dense_tensor.indices().flatten()
              self.values = dense_tensor.values()
          else:
              # Keep every row that has at least one non-zero entry. Testing
              # the row *sum* instead would drop rows whose entries cancel
              # (e.g. [1, -1]), which to_dense() could then not reconstruct.
              row_has_nonzero = (dense_tensor != 0).any(dim=1)
              self.indices = row_has_nonzero.nonzero().flatten()
              self.values = dense_tensor[self.indices]
          self.dense_size = list(dense_tensor.size())

      def to_coo_tensor(self):
          """Return this tensor as a ``torch.sparse_coo_tensor`` with one sparse (row) dimension."""
          return torch.sparse_coo_tensor(self.indices.unsqueeze(0), self.values, self.dense_size)

      @staticmethod
      def type():
          """Return the type tag string identifying this class."""
          return "deepspeed.SparseTensor"

      def to_dense(self):
          """Rebuild and return the dense tensor of shape ``dense_size``.

          Rows that appear more than once in ``indices`` (for example after
          :meth:`add`) are summed, because the scatter accumulates.
          """
          # scatter_add_ needs an index with the same shape as ``values``, so
          # repeat each row number across every column.
          row_index = self.indices.unsqueeze(1).repeat(1, self.dense_size[1])
          return self.values.new_zeros(self.dense_size).scatter_add_(0, row_index, self.values)

      def sparse_size(self):
          """Return ``(compressed_numel, dense_numel)``.

          ``compressed_numel`` counts the elements stored in ``indices`` plus
          ``values``. ``dense_numel`` is the element count of the full 2-D
          tensor.
          """
          compressed_numel = self.indices.numel() + self.values.numel()
          dense_numel = self.dense_size[0] * self.dense_size[1]
          return compressed_numel, dense_numel

      def add(self, b):
          """Append the rows of ``b`` (another SparseTensor) to this one in place.

          Indices are concatenated without merging duplicates; ``to_dense``
          sums duplicate rows.

          Raises:
              AssertionError: if the two tensors have different dense sizes.
          """
          assert self.dense_size == b.dense_size, \
              "dense_size mismatch: {} vs {}".format(self.dense_size, b.dense_size)
          self.indices = torch.cat([self.indices, b.indices])
          self.values = torch.cat([self.values, b.values])

      def __str__(self):
          if self.indices is None:
              # Empty instance (constructed with dense_tensor=None).
              return ("DeepSpeed.SparseTensor(indices_size=None, values_size=None, "
                      "dense_size=None, device=None, reduction_factor=None)")
          sparse_size, dense_size = self.sparse_size()
          # An all-zero input stores nothing; avoid ZeroDivisionError.
          reduction_factor = dense_size / sparse_size if sparse_size else float("inf")
          return ("DeepSpeed.SparseTensor(indices_size={}, values_size={}, "
                  "dense_size={}, device={}, reduction_factor={})").format(
                      self.indices.size(), self.values.size(), self.dense_size,
                      self.indices.get_device(), reduction_factor)

      def __repr__(self):
          return self.__str__()