init_on_device.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import torch
  5. from typing import Callable
  6. from torch import Tensor
  7. from packaging import version as pkg_version
  8. class OnDevice(object):
  9. """
  10. Create modules/tensors w. specific devices and dtypes. Examples:
  11. Create MyModule which consists of many different sub-modules and parameters. In this case we can create
  12. MyModule as a collection of 'meta' tensors by passing `device='meta'` or we can create the module _directly_
  13. on a CUDA device by passing `device=f'cuda:{local_rank}'` (where `local_rank` is the local GPU id.
  14. with OnDevice(dtype=torch.float16, device='meta'):
  15. model = MyModel()
  16. with OnDevice(dtype=torch.float16, device=f'cuda:{local_rank}'):
  17. model = MyModel()
  18. """
  19. _orig_torch_empty = torch.empty
  20. _orig_torch_zeros = torch.zeros
  21. _orig_torch_ones = torch.ones
  22. _orig_torch_full = torch.full
  23. def __init__(self, dtype, device="meta", enabled=True):
  24. self.dtype = dtype
  25. self.enabled = enabled
  26. self.device = device
  27. if device == "meta":
  28. if pkg_version.parse('1.10') > pkg_version.parse(torch.__version__):
  29. raise NotImplementedError("Meta tensor support is not available, please upgrade to torch 1.10+")
  30. def fp_tensor_constructor(self, fn: Callable, target_fp_dtype: torch.dtype) -> Callable:
  31. def wrapped_fn(*args, **kwargs) -> Tensor:
  32. if kwargs.get("device", None) is None:
  33. kwargs['device'] = self.device
  34. tensor: Tensor = fn(*args, **kwargs)
  35. if tensor.is_floating_point():
  36. tensor = tensor.to(target_fp_dtype)
  37. return tensor
  38. return wrapped_fn
  39. def get_new_tensor_fn_for_dtype(self, dtype: torch.dtype) -> Callable:
  40. def new_tensor(cls, *args) -> Tensor:
  41. tensor = OnDevice._orig_torch_empty(0, device=self.device).new_empty(*args)
  42. if tensor.is_floating_point():
  43. tensor = tensor.to(dtype)
  44. return tensor
  45. return new_tensor
  46. def __enter__(self):
  47. if not self.enabled:
  48. return
  49. torch.Tensor.__old_new__ = torch.Tensor.__new__
  50. torch.Tensor.__new__ = self.get_new_tensor_fn_for_dtype(self.dtype)
  51. torch.empty = self.fp_tensor_constructor(self._orig_torch_empty, self.dtype)
  52. torch.zeros = self.fp_tensor_constructor(self._orig_torch_zeros, self.dtype)
  53. torch.ones = self.fp_tensor_constructor(self._orig_torch_ones, self.dtype)
  54. torch.full = self.fp_tensor_constructor(self._orig_torch_full, self.dtype)
  55. def __exit__(self, exc_type, exc_value, traceback):
  56. if not self.enabled:
  57. return
  58. torch.Tensor.__new__ = torch.Tensor.__old_new__
  59. torch.empty = self._orig_torch_empty
  60. torch.zeros = self._orig_torch_zeros
  61. torch.ones = self._orig_torch_ones
  62. torch.full = self._orig_torch_full