# ds_aio_handle.py

  1. """
  2. Copyright 2020 The Microsoft DeepSpeed Team
  3. Licensed under the MIT license.
  4. Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
  5. """
  6. import torch
  7. import os
  8. import time
  9. from multiprocessing import Pool, Barrier
  10. from deepspeed.ops.aio import AsyncIOBuilder
  11. from test_ds_aio_utils import report_results, task_log, task_barrier


def pre_handle(args, tid, read_op):
    io_string = "Read" if read_op else "Write"
    num_bytes = os.path.getsize(args.read_file) if read_op else args.write_size
    file = args.read_file if read_op else f'{args.write_file}.{tid}'

    # Allocate the I/O buffer either in GPU memory or in pinned host memory.
    task_log(tid, f'Allocate tensor of size {num_bytes} bytes')
    if args.gpu:
        buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cuda')
    else:
        buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cpu').pin_memory()
    task_log(
        tid,
        f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}'
    )

    # Create the asynchronous I/O handle with the configured block size,
    # queue depth, submission mode, and degree of intra-handle parallelism.
    io_parallel = args.io_parallel if args.io_parallel else 1
    handle = AsyncIOBuilder().load().aio_handle(args.block_size,
                                                args.queue_depth,
                                                args.single_submit,
                                                args.overlap_events,
                                                io_parallel)
    task_log(tid, 'created deepspeed aio handle')

    ctxt = {}
    ctxt['file'] = file
    ctxt['num_bytes'] = num_bytes
    ctxt['handle'] = handle
    ctxt['buffer'] = buffer
    ctxt['elapsed_sec'] = 0

    return ctxt


def pre_handle_read(pool_params):
    args, tid = pool_params
    ctxt = pre_handle(args, tid, True)
    return ctxt


def pre_handle_write(pool_params):
    args, tid = pool_params
    ctxt = pre_handle(args, tid, False)
    return ctxt


def post_handle(pool_params):
    _, _, ctxt = pool_params
    ctxt["buffer"].detach()
    ctxt["buffer"] = None
    return ctxt


def main_parallel_read(pool_params):
    # Asynchronous read: submit with pread() and block on wait() for completion.
    args, tid, ctxt = pool_params
    handle = ctxt['handle']

    start_time = time.time()
    ret = handle.pread(ctxt['buffer'], ctxt['file'], args.validate, True)
    assert ret != -1
    handle.wait()
    end_time = time.time()
    ctxt['elapsed_sec'] += end_time - start_time

    return ctxt


def main_parallel_write(pool_params):
    # Asynchronous write: submit with pwrite() and block on wait() for completion.
    args, tid, ctxt = pool_params
    handle = ctxt['handle']

    start_time = time.time()
    ret = handle.pwrite(ctxt['buffer'], ctxt['file'], args.validate, True)
    assert ret != -1
    handle.wait()
    end_time = time.time()
    ctxt['elapsed_sec'] += end_time - start_time

    return ctxt


def main_handle_read(pool_params):
    # Blocking read: handle.read() completes the I/O before returning.
    args, tid, ctxt = pool_params
    handle = ctxt['handle']

    start_time = time.time()
    ret = handle.read(ctxt['buffer'], ctxt['file'], args.validate)
    assert ret != -1
    end_time = time.time()
    ctxt['elapsed_sec'] += end_time - start_time

    return ctxt


def main_handle_write(pool_params):
    # Blocking write: handle.write() completes the I/O before returning.
    args, tid, ctxt = pool_params
    handle = ctxt['handle']

    start_time = time.time()
    ret = handle.write(ctxt['buffer'], ctxt['file'], args.validate)
    assert ret != -1
    end_time = time.time()
    ctxt['elapsed_sec'] += end_time - start_time

    return ctxt


def get_schedule(args, read_op):
    # Select the pre/main/post callables; use the parallel (asynchronous) main
    # task when io_parallel is set, otherwise the blocking handle task.
    schedule = {}
    if read_op:
        schedule['pre'] = pre_handle_read
        schedule['post'] = post_handle
        schedule['main'] = main_parallel_read if args.io_parallel else main_handle_read
    else:
        schedule['pre'] = pre_handle_write
        schedule['post'] = post_handle
        schedule['main'] = main_parallel_write if args.io_parallel else main_handle_write

    return schedule


def _aio_handle_tasklet(pool_params):
    args, tid, read_op = pool_params

    # Create schedule
    schedule = get_schedule(args, read_op)
    task_log(tid, f'schedule = {schedule}')
    task_barrier(aio_barrier, args.threads)

    # Run pre task
    task_log(tid, 'running pre-task')
    ctxt = schedule["pre"]((args, tid))
    task_barrier(aio_barrier, args.threads)

    # Run main tasks in a loop
    ctxt["main_task_sec"] = 0
    for i in range(args.loops):
        task_log(tid, f'running main task {i}')
        start_time = time.time()
        ctxt = schedule["main"]((args, tid, ctxt))
        task_barrier(aio_barrier, args.threads)
        stop_time = time.time()
        ctxt["main_task_sec"] += stop_time - start_time

    # Run post task
    task_log(tid, 'running post-task')
    ctxt = schedule["post"]((args, tid, ctxt))
    task_barrier(aio_barrier, args.threads)

    return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops


def _init_tasklet(b):
    # Make the shared barrier visible to each worker process in the pool.
    global aio_barrier
    aio_barrier = b


def aio_handle_multiprocessing(args, read_op):
    b = Barrier(args.threads)
    pool_params = [(args, p, read_op) for p in range(args.threads)]
    with Pool(processes=args.threads, initializer=_init_tasklet, initargs=(b, )) as p:
        pool_results = p.map(_aio_handle_tasklet, pool_params)

    report_results(args, read_op, pool_results)
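

# --- Usage sketch (illustrative, not part of the original benchmark) ---
# A minimal example of how aio_handle_multiprocessing might be invoked directly.
# The attribute names on `example_args` mirror the ones this module reads
# (read_file, write_file, write_size, block_size, queue_depth, single_submit,
# overlap_events, io_parallel, validate, gpu, threads, loops); the concrete
# values, the file paths, and the Namespace wrapper are assumptions here, and
# a driver script would typically build `args` from a command-line parser.
if __name__ == '__main__':
    from argparse import Namespace

    example_args = Namespace(read_file='/tmp/ds_aio_input.bin',   # hypothetical input file
                             write_file='/tmp/ds_aio_output',     # per-thread suffix is appended in write mode
                             write_size=1024 * 1024,
                             block_size=1024 * 1024,
                             queue_depth=32,
                             single_submit=False,
                             overlap_events=True,
                             io_parallel=0,     # 0 selects the blocking read/write path
                             validate=False,
                             gpu=False,
                             threads=1,
                             loops=1)

    # Benchmark reads (read_op=True); the input file must already exist.
    aio_handle_multiprocessing(example_args, read_op=True)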