# init_on_device.py
  1. '''
  2. Copyright 2020 The Microsoft DeepSpeed Team
  3. '''
  4. import torch
  5. from typing import Callable
  6. from torch import Tensor
  7. from packaging import version as pkg_version
  8. class OnDevice(object):
  9. """
  10. Create modules/tensors w. specific devices and dtypes. Examples:
  11. Create MyModule which consists of many different sub-modules and parameters. In this case we can create
  12. MyModule as a collection of 'meta' tensors by passing `device='meta'` or we can create the module _directly_
  13. on a CUDA device by passing `device=f'cuda:{local_rank}'` (where `local_rank` is the local GPU id.
  14. with OnDevice(dtype=torch.float16, device='meta'):
  15. model = MyModel()
  16. with OnDevice(dtype=torch.float16, device=f'cuda:{local_rank}'):
  17. model = MyModel()
  18. """
  19. _orig_torch_empty = torch.empty
  20. _orig_torch_zeros = torch.zeros
  21. _orig_torch_ones = torch.ones
  22. _orig_torch_full = torch.full
  23. def __init__(self, dtype, device="meta", enabled=True):
  24. self.dtype = dtype
  25. self.enabled = enabled
  26. self.device = device
  27. if device == "meta":
  28. if pkg_version.parse('1.10') > pkg_version.parse(torch.__version__):
  29. raise NotImplementedError(
  30. "Meta tensor support is not available, please upgrade to torch 1.10+"
  31. )
  32. def fp_tensor_constructor(self,
  33. fn: Callable,
  34. target_fp_dtype: torch.dtype) -> Callable:
  35. def wrapped_fn(*args, **kwargs) -> Tensor:
  36. if kwargs.get("device", None) is None:
  37. kwargs['device'] = self.device
  38. tensor: Tensor = fn(*args, **kwargs)
  39. if tensor.is_floating_point():
  40. tensor = tensor.to(target_fp_dtype)
  41. return tensor
  42. return wrapped_fn
  43. def get_new_tensor_fn_for_dtype(self, dtype: torch.dtype) -> Callable:
  44. def new_tensor(cls, *args) -> Tensor:
  45. tensor = OnDevice._orig_torch_empty(0, device=self.device).new_empty(*args)
  46. if tensor.is_floating_point():
  47. tensor = tensor.to(dtype)
  48. return tensor
  49. return new_tensor
  50. def __enter__(self):
  51. if not self.enabled:
  52. return
  53. torch.Tensor.__old_new__ = torch.Tensor.__new__
  54. torch.Tensor.__new__ = self.get_new_tensor_fn_for_dtype(self.dtype)
  55. torch.empty = self.fp_tensor_constructor(self._orig_torch_empty, self.dtype)
  56. torch.zeros = self.fp_tensor_constructor(self._orig_torch_zeros, self.dtype)
  57. torch.ones = self.fp_tensor_constructor(self._orig_torch_ones, self.dtype)
  58. torch.full = self.fp_tensor_constructor(self._orig_torch_full, self.dtype)
  59. def __exit__(self, exc_type, exc_value, traceback):
  60. if not self.enabled:
  61. return
  62. torch.Tensor.__new__ = torch.Tensor.__old_new__
  63. torch.empty = self._orig_torch_empty
  64. torch.zeros = self._orig_torch_zeros
  65. torch.ones = self._orig_torch_ones
  66. torch.full = self._orig_torch_full