1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- """
- Implementation of a compressed sparse tensor. Similar in
- functionality to TensorFlow's IndexedSlices implementation.
- """
- import torch
class SparseTensor(object):
    """ Compressed Sparse Tensor

    Row-compressed view of a tensor: ``indices`` holds the row numbers that
    are kept, ``values`` holds the matching rows and ``dense_size`` records
    the shape of the original tensor as a list.

    ``to_dense`` and ``sparse_size`` index ``dense_size[1]``, so the
    wrapped tensor is expected to be 2-D.
    """

    def __init__(self, dense_tensor=None):
        """Compress ``dense_tensor``, or build an empty placeholder.

        Args:
            dense_tensor: a dense tensor or a ``torch.sparse_coo`` tensor.
                ``None`` (the default) creates an empty instance whose
                ``indices``, ``values`` and ``dense_size`` are all ``None``.
        """
        self.orig_dense_tensor = dense_tensor
        if dense_tensor is None:
            # Must be checked before touching any tensor attribute,
            # otherwise the default constructor raises AttributeError.
            self.is_sparse = False
            self.indices = None
            self.values = None
            self.dense_size = None
            return

        self.is_sparse = dense_tensor.is_sparse
        if dense_tensor.is_sparse:
            # coalesce() merges duplicate entries and is required before
            # indices()/values() may be called.
            dense_tensor = dense_tensor.coalesce()
            # NOTE(review): flattening the index matrix presumes a single
            # sparse dimension (row-sparse tensor) - confirm with callers.
            self.indices = dense_tensor.indices().flatten()
            self.values = dense_tensor.values()
        else:
            # Keep every row holding at least one non-zero element. Testing
            # the elements (rather than the row sum) also keeps rows whose
            # entries cancel out, e.g. [1, -1].
            row_has_data = dense_tensor.ne(0).any(dim=1)
            self.indices = row_has_data.nonzero().flatten()
            self.values = dense_tensor[self.indices]
        self.dense_size = list(dense_tensor.size())

    def to_coo_tensor(self):
        """Return the data as a ``torch.sparse_coo_tensor`` of shape ``dense_size``."""
        return torch.sparse_coo_tensor(self.indices.unsqueeze(0), self.values, self.dense_size)

    @staticmethod
    def type():
        """Return the string tag identifying this tensor type."""
        return "deepspeed.SparseTensor"

    def to_dense(self):
        """Rebuild the dense tensor; rows with a repeated index are summed."""
        # Broadcast each row index across all columns so scatter_add_ gets
        # one target row per element of ``values``; expand() creates a view
        # instead of concatenating ``dense_size[1]`` copies.
        row_index = self.indices.unsqueeze(1)
        full_indices = row_index.expand(-1, self.dense_size[1])
        return self.values.new_zeros(self.dense_size).scatter_add_(0, full_indices, self.values)

    def sparse_size(self):
        """Return ``(sparse_elements, dense_elements)``.

        ``sparse_elements`` counts the stored indices plus the stored
        values; ``dense_elements`` is the element count of the dense shape.
        """
        index_size = list(self.indices.size())
        index_size = index_size[0]

        value_size = list(self.values.size())
        value_size = value_size[0] * value_size[1]

        dense_size = self.dense_size[0] * self.dense_size[1]

        return index_size + value_size, dense_size

    def add(self, b):
        """Accumulate ``b`` into this tensor in place.

        Indices and values are concatenated, not merged: duplicate row
        indices are kept and only summed when ``to_dense`` is called.

        Args:
            b: another ``SparseTensor`` with the same ``dense_size``.
        """
        assert self.dense_size == b.dense_size
        self.indices = torch.cat([self.indices, b.indices])
        self.values = torch.cat([self.values, b.values])

    def __str__(self):
        template = "DeepSpeed.SparseTensor(indices_size={}, values_size={}, " \
                   "dense_size={}, device={}, reduction_factor={})"
        if self.indices is None or self.values is None:
            # Empty placeholder: there are no sizes or device to report.
            return template.format(None, None, self.dense_size, None, None)

        sparse_size, dense_size = self.sparse_size()
        # Nothing stored (all-zero input) would otherwise divide by zero.
        reduction_factor = dense_size / sparse_size if sparse_size else float("inf")
        return template.format(self.indices.size(), self.values.size(), self.dense_size,
                               self.indices.get_device(), reduction_factor)

    def __repr__(self):
        return self.__str__()
|