cupy.py 701 B

1234567891011121314151617181920212223242526
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import cupy
  5. from torch.utils.dlpack import to_dlpack
  6. from torch.utils.dlpack import from_dlpack
  7. class CupyBackend(object):
  8. def __init__(self):
  9. pass
  10. def torch2cupy(self, tensor):
  11. return cupy.fromDlpack(to_dlpack(tensor))
  12. def cupy2torch(self, cupy_tensor):
  13. return from_dlpack(cupy_tensor.toDlpack())
  14. def compress_by_chunk(self, cupy_bool_tensor, num_chunks):
  15. packed_sign = cupy.packbits(cupy_bool_tensor)
  16. sign_list_packed = cupy.split(packed_sign, num_chunks)
  17. cupy.cuda.get_current_stream().synchronize()
  18. return sign_list_packed