cupy.py 657 B

123456789101112131415161718192021222324
  1. '''
  2. Copyright 2020 The Microsoft DeepSpeed Team
  3. '''
  4. import cupy
  5. from torch.utils.dlpack import to_dlpack
  6. from torch.utils.dlpack import from_dlpack
  7. class CupyBackend(object):
  8. def __init__(self):
  9. pass
  10. def torch2cupy(self, tensor):
  11. return cupy.fromDlpack(to_dlpack(tensor))
  12. def cupy2torch(self, cupy_tensor):
  13. return from_dlpack(cupy_tensor.toDlpack())
  14. def compress_by_chunk(self, cupy_bool_tensor, num_chunks):
  15. packed_sign = cupy.packbits(cupy_bool_tensor)
  16. sign_list_packed = cupy.split(packed_sign, num_chunks)
  17. cupy.cuda.get_current_stream().synchronize()
  18. return sign_list_packed