aio_config.py 1.1 KB

1234567891011121314151617181920212223242526272829
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from deepspeed.runtime.config_utils import get_scalar_param
  5. from deepspeed.runtime.swap_tensor.constants import *
  6. AIO_DEFAULT_DICT = {
  7. AIO_BLOCK_SIZE: AIO_BLOCK_SIZE_DEFAULT,
  8. AIO_QUEUE_DEPTH: AIO_QUEUE_DEPTH_DEFAULT,
  9. AIO_THREAD_COUNT: AIO_THREAD_COUNT_DEFAULT,
  10. AIO_SINGLE_SUBMIT: AIO_SINGLE_SUBMIT_DEFAULT,
  11. AIO_OVERLAP_EVENTS: AIO_OVERLAP_EVENTS_DEFAULT
  12. }
  13. def get_aio_config(param_dict):
  14. if AIO in param_dict.keys() and param_dict[AIO] is not None:
  15. aio_dict = param_dict[AIO]
  16. return {
  17. AIO_BLOCK_SIZE: get_scalar_param(aio_dict, AIO_BLOCK_SIZE, AIO_BLOCK_SIZE_DEFAULT),
  18. AIO_QUEUE_DEPTH: get_scalar_param(aio_dict, AIO_QUEUE_DEPTH, AIO_QUEUE_DEPTH_DEFAULT),
  19. AIO_THREAD_COUNT: get_scalar_param(aio_dict, AIO_THREAD_COUNT, AIO_THREAD_COUNT_DEFAULT),
  20. AIO_SINGLE_SUBMIT: get_scalar_param(aio_dict, AIO_SINGLE_SUBMIT, AIO_SINGLE_SUBMIT_DEFAULT),
  21. AIO_OVERLAP_EVENTS: get_scalar_param(aio_dict, AIO_OVERLAP_EVENTS, AIO_OVERLAP_EVENTS_DEFAULT)
  22. }
  23. return AIO_DEFAULT_DICT