ds_aio_handle.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
"""
import torch
import os
import time
from multiprocessing import Pool, Barrier
from test_ds_aio_utils import report_results, task_log, task_barrier
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import AsyncIOBuilder

def pre_handle(args, tid, read_op):
    # Build an aio handle and allocate the I/O buffer for this worker.
    io_string = "Read" if read_op else "Write"
    num_bytes = os.path.getsize(args.read_file) if read_op else args.write_size
    file = args.read_file if read_op else f'{args.write_file}.{tid}'
    io_parallel = args.io_parallel if args.io_parallel else 1
    handle = AsyncIOBuilder().load().aio_handle(args.block_size, args.queue_depth, args.single_submit,
                                                args.overlap_events, io_parallel)
    task_log(tid, f'Created deepspeed aio handle')

    # Place the buffer on the accelerator, in accelerator-pinned CPU memory,
    # or in aio-managed page-locked CPU memory.
    if args.gpu:
        buffer = torch.empty(num_bytes, dtype=torch.uint8, device=get_accelerator().device_name())
    else:
        if args.use_accelerator_pin_memory:
            buffer = get_accelerator().pin_memory(torch.empty(num_bytes, dtype=torch.uint8, device='cpu'))
        else:
            buffer = handle.new_cpu_locked_tensor(num_bytes, torch.empty(0, dtype=torch.uint8))
    task_log(tid, f'Allocate tensor of size {num_bytes} bytes')

    ctxt = {}
    ctxt['file'] = file
    ctxt['num_bytes'] = num_bytes
    ctxt['handle'] = handle
    ctxt['buffer'] = buffer
    ctxt['elapsed_sec'] = 0
    task_log(tid, f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}')
    return ctxt


def pre_handle_read(pool_params):
    args, tid = pool_params
    ctxt = pre_handle(args, tid, True)
    return ctxt


def pre_handle_write(pool_params):
    args, tid = pool_params
    ctxt = pre_handle(args, tid, False)
    return ctxt


def post_handle(pool_params):
    # Release the I/O buffer so its memory can be reclaimed.
    _, _, ctxt = pool_params
    ctxt["buffer"].detach()
    ctxt["buffer"] = None
    return ctxt

def main_parallel_read(pool_params):
    # Asynchronous read: submit with async=True, then wait() for completion.
    args, tid, ctxt = pool_params
    handle = ctxt['handle']

    start_time = time.time()
    ret = handle.pread(ctxt['buffer'], ctxt['file'], args.validate, True)
    assert ret != -1
    handle.wait()
    end_time = time.time()
    ctxt['elapsed_sec'] += end_time - start_time
    return ctxt


def main_parallel_write(pool_params):
    # Asynchronous write: submit, then wait() for completion.
    args, tid, ctxt = pool_params
    handle = ctxt['handle']

    start_time = time.time()
    ret = handle.pwrite(ctxt['buffer'], ctxt['file'], args.validate, True)
    assert ret != -1
    handle.wait()
    end_time = time.time()
    ctxt['elapsed_sec'] += end_time - start_time
    return ctxt


def main_handle_read(pool_params):
    # Blocking read: returns only when the I/O has completed.
    args, tid, ctxt = pool_params
    handle = ctxt['handle']

    start_time = time.time()
    ret = handle.read(ctxt['buffer'], ctxt['file'], args.validate)
    assert ret != -1
    end_time = time.time()
    ctxt['elapsed_sec'] += end_time - start_time
    return ctxt


def main_handle_write(pool_params):
    # Blocking write.
    args, tid, ctxt = pool_params
    handle = ctxt['handle']

    start_time = time.time()
    ret = handle.write(ctxt['buffer'], ctxt['file'], args.validate)
    assert ret != -1
    end_time = time.time()
    ctxt['elapsed_sec'] += end_time - start_time
    return ctxt

def get_schedule(args, read_op):
    # Select the async (pread/pwrite) main task when args.io_parallel is set,
    # otherwise the blocking (read/write) variant.
    schedule = {}
    if read_op:
        schedule['pre'] = pre_handle_read
        schedule['post'] = post_handle
        schedule['main'] = main_parallel_read if args.io_parallel else main_handle_read
    else:
        schedule['pre'] = pre_handle_write
        schedule['post'] = post_handle
        schedule['main'] = main_parallel_write if args.io_parallel else main_handle_write
    return schedule


def _aio_handle_tasklet(pool_params):
    args, tid, read_op = pool_params

    # Create schedule
    schedule = get_schedule(args, read_op)
    task_log(tid, f'schedule = {schedule}')
    task_barrier(aio_barrier, args.threads)

    # Run pre task
    task_log(tid, f'running pre-task')
    ctxt = schedule["pre"]((args, tid))
    task_barrier(aio_barrier, args.threads)

    # Run main tasks in a loop
    ctxt["main_task_sec"] = 0
    for i in range(args.loops):
        task_log(tid, f'running main task {i}')
        start_time = time.time()
        ctxt = schedule["main"]((args, tid, ctxt))
        task_barrier(aio_barrier, args.threads)
        stop_time = time.time()
        ctxt["main_task_sec"] += stop_time - start_time

    # Run post task
    task_log(tid, f'running post-task')
    ctxt = schedule["post"]((args, tid, ctxt))
    task_barrier(aio_barrier, args.threads)

    return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops

def _init_tasklet(b):
    # Pool initializer: publish the shared barrier as a global in each worker.
    global aio_barrier
    aio_barrier = b


def aio_handle_multiprocessing(args, read_op):
    # Run one benchmark process per "thread", synchronized by a shared barrier.
    b = Barrier(args.threads)
    pool_params = [(args, p, read_op) for p in range(args.threads)]
    with Pool(processes=args.threads, initializer=_init_tasklet, initargs=(b, )) as p:
        pool_results = p.map(_aio_handle_tasklet, pool_params)

    report_results(args, read_op, pool_results)
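

# --- Usage sketch (illustrative, not part of the upstream benchmark) ---
# This module is normally driven by a companion test script that builds an
# argparse namespace; the sketch below shows one plausible direct invocation.
# All field values, and the file paths, are assumptions for illustration only.
if __name__ == '__main__':
    from types import SimpleNamespace
    example_args = SimpleNamespace(
        read_file='/mnt/nvme/input.bin',  # hypothetical input file
        write_file='/mnt/nvme/output',  # each worker appends a .{tid} suffix
        write_size=1024 * 1024,  # bytes written by each worker
        block_size=1024 * 1024,  # aio block size
        queue_depth=32,  # aio queue depth
        single_submit=False,
        overlap_events=True,
        io_parallel=1,  # truthy selects the async pread/pwrite path
        gpu=False,  # keep the I/O buffer in CPU memory
        use_accelerator_pin_memory=True,
        validate=False,
        threads=2,  # number of worker processes
        loops=3,  # main-task repetitions per worker
    )
    aio_handle_multiprocessing(example_args, read_op=True)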