real_accelerator.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. try:
  5. from accelerator.abstract_accelerator import DeepSpeedAccelerator as dsa1
  6. except ImportError as e:
  7. dsa1 = None
  8. try:
  9. from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator as dsa2
  10. except ImportError as e:
  11. dsa2 = None
  12. ds_accelerator = None
  13. def _validate_accelerator(accel_obj):
  14. # because abstract_accelerator has different path during
  15. # build time (accelerator.abstract_accelerator)
  16. # and run time (deepspeed.accelerator.abstract_accelerator)
  17. # and extension would import the
  18. # run time abstract_accelerator/DeepSpeedAccelerator as its base
  19. # class, so we need to compare accel_obj with both base class.
  20. # if accel_obj is instance of DeepSpeedAccelerator in one of
  21. # accelerator.abstractor_accelerator
  22. # or deepspeed.accelerator.abstract_accelerator, consider accel_obj
  23. # is a conforming object
  24. if not ((dsa1 != None and isinstance(accel_obj, dsa1)) or (dsa2 != None and isinstance(accel_obj, dsa2))):
  25. raise AssertionError(f'{accel_obj.__class__.__name__} accelerator is not subclass of DeepSpeedAccelerator')
  26. # TODO: turn off is_available test since this breaks tests
  27. #assert accel_obj.is_available(), \
  28. # f'{accel_obj.__class__.__name__} accelerator fails is_available() test'
  29. def get_accelerator():
  30. global ds_accelerator
  31. if ds_accelerator is None:
  32. try:
  33. from intel_extension_for_deepspeed import XPU_Accelerator
  34. except ImportError as e:
  35. pass
  36. else:
  37. ds_accelerator = XPU_Accelerator()
  38. _validate_accelerator(ds_accelerator)
  39. return ds_accelerator
  40. from .cuda_accelerator import CUDA_Accelerator
  41. ds_accelerator = CUDA_Accelerator()
  42. _validate_accelerator(ds_accelerator)
  43. return ds_accelerator
  44. def set_accelerator(accel_obj):
  45. global ds_accelerator
  46. _validate_accelerator(accel_obj)
  47. ds_accelerator = accel_obj
'''
-----------[code] test_get.py -----------
from deepspeed.accelerator import get_accelerator
my_accelerator = get_accelerator()
print(f'{my_accelerator._name=}')
print(f'{my_accelerator._communication_backend=}')
print(f'{my_accelerator.HalfTensor().device=}')
print(f'{my_accelerator.total_memory()=}')
-----------[code] test_get.py -----------
---[output] python test_get.py---------
my_accelerator._name='cuda'
my_accelerator._communication_backend='nccl'
my_accelerator.HalfTensor().device=device(type='cuda', index=0)
my_accelerator.total_memory()=34089730048
---[output] python test_get.py---------
**************************************************************************
-----------[code] test_set.py -----------
from deepspeed.accelerator.cuda_accelerator import CUDA_Accelerator
cu_accel = CUDA_Accelerator()
print(f'{id(cu_accel)=}')
from deepspeed.accelerator import set_accelerator, get_accelerator
set_accelerator(cu_accel)
my_accelerator = get_accelerator()
print(f'{id(my_accelerator)=}')
print(f'{my_accelerator._name=}')
print(f'{my_accelerator._communication_backend=}')
print(f'{my_accelerator.HalfTensor().device=}')
print(f'{my_accelerator.total_memory()=}')
-----------[code] test_set.py -----------
---[output] python test_set.py---------
id(cu_accel)=139648165478304
id(my_accelerator)=139648165478304
my_accelerator._name='cuda'
my_accelerator._communication_backend='nccl'
my_accelerator.HalfTensor().device=device(type='cuda', index=0)
my_accelerator.total_memory()=34089730048
---[output] python test_set.py---------
'''