inference_parameter.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from typing import Dict
  5. import torch
  6. CORE_PARAM = "_ds_core_param_key"
  7. STR_TO_DTYPE = {
  8. "torch.float32": torch.float32,
  9. "torch.float64": torch.float64,
  10. "torch.float16": torch.float16,
  11. "torch.bfloat16": torch.bfloat16,
  12. "torch.int64": torch.int64,
  13. "torch.int32": torch.int32,
  14. "torch.int16": torch.int16,
  15. "torch.int8": torch.int8,
  16. "torch.uint8": torch.uint8,
  17. "torch.bool": torch.bool,
  18. }
  19. class InferenceParameter(torch.Tensor):
  20. """
  21. An extension of the torch.Tensor class to support our inference focused features. One important
  22. thing to note here is that an InferenceParam can be used a torch.Tensor, but outputs of
  23. torch.Tensor operations will not be InferenceParams.
  24. """
  25. @staticmethod
  26. def __new__(cls, tensor, *args, **kwargs):
  27. new_tensor = super().__new__(cls, tensor, *args, **kwargs)
  28. if hasattr(tensor, "_aux_attrs"):
  29. setattr(new_tensor, "_aux_attrs", tensor.aux_attrs)
  30. return new_tensor
  31. def to(self, *args, **kwargs):
  32. new_tensor = super().to(*args, **kwargs)
  33. if hasattr(self, "_aux_attrs"):
  34. setattr(new_tensor, "_aux_attrs", self.aux_attrs)
  35. try:
  36. _ = torch.device(args[0])
  37. for name, attr in new_tensor.aux_attrs.items():
  38. new_attr = attr.to(*args, **kwargs)
  39. setattr(new_tensor, name, new_attr)
  40. new_tensor.aux_attrs[name] = new_attr
  41. except:
  42. pass
  43. return new_tensor
  44. @classmethod
  45. def initialize(cls, core_param: torch.Tensor, **kwargs) -> 'InferenceParameter':
  46. """
  47. Create the inference parameter.
  48. """
  49. param = InferenceParameter(core_param)
  50. setattr(param, "_aux_attrs", kwargs)
  51. for attr_name, attr in kwargs.items():
  52. if hasattr(param, attr_name):
  53. raise ValueError(f"Attribute {attr_name} already exists on param.")
  54. if not isinstance(attr, torch.Tensor):
  55. raise ValueError(f"Attribute {attr_name} must be a tensor.")
  56. setattr(param, attr_name, attr)
  57. return param
  58. @classmethod
  59. def initialize_raw(self, **kwargs) -> 'InferenceParameter':
  60. """
  61. All kwargs must be torch.Tensors and must include the core parameter.
  62. """
  63. if CORE_PARAM not in kwargs:
  64. raise ValueError(f"Must provide core parameter, with key {CORE_PARAM}.")
  65. return InferenceParameter.initialize(kwargs[CORE_PARAM], **kwargs)
  66. @property
  67. def aux_attrs(self) -> Dict[str, torch.Tensor]:
  68. """
  69. Dictionary of auxiliary attributes.
  70. """
  71. return self._aux_attrs