# ds_aio_basic.py
"""
Copyright 2020 The Microsoft DeepSpeed Team
Licensed under the MIT license.
Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
"""
import os
import time
from multiprocessing import Pool, Barrier

import torch
from deepspeed.ops.aio import AsyncIOBuilder

from test_ds_aio_utils import report_results, task_log, task_barrier
  12. def pre_basic(args, tid, read_op):
  13. io_string = "Read" if read_op else "Write"
  14. num_bytes = os.path.getsize(args.read_file) if read_op else args.write_size
  15. file = args.read_file if read_op else f'{args.write_file}.{tid}'
  16. task_log(tid, f'Allocate tensor of size {num_bytes} bytes')
  17. buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cpu').pin_memory()
  18. task_log(
  19. tid,
  20. f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}'
  21. )
  22. ctxt = {}
  23. ctxt['file'] = file
  24. ctxt['num_bytes'] = num_bytes
  25. ctxt['buffer'] = buffer
  26. ctxt['elapsed_sec'] = 0
  27. return ctxt
  28. def pre_basic_read(pool_params):
  29. args, tid = pool_params
  30. ctxt = pre_basic(args, tid, True)
  31. return ctxt
  32. def pre_basic_write(pool_params):
  33. args, tid = pool_params
  34. ctxt = pre_basic(args, tid, False)
  35. return ctxt
  36. def post_basic(pool_params):
  37. _, _, ctxt = pool_params
  38. ctxt["buffer"].detach()
  39. ctxt["buffer"] = None
  40. return ctxt
  41. def main_basic_read(pool_params):
  42. args, tid, ctxt = pool_params
  43. start_time = time.time()
  44. AsyncIOBuilder().load().aio_read(ctxt['buffer'],
  45. ctxt['file'],
  46. args.block_size,
  47. args.queue_depth,
  48. args.single_submit,
  49. args.overlap_events,
  50. args.validate)
  51. end_time = time.time()
  52. ctxt['elapsed_sec'] += end_time - start_time
  53. return ctxt
  54. def main_basic_write(pool_params):
  55. args, tid, ctxt = pool_params
  56. start_time = time.time()
  57. AsyncIOBuilder().load().aio_write(ctxt['buffer'],
  58. ctxt['file'],
  59. args.block_size,
  60. args.queue_depth,
  61. args.single_submit,
  62. args.overlap_events,
  63. args.validate)
  64. end_time = time.time()
  65. ctxt['elapsed_sec'] += end_time - start_time
  66. return ctxt
  67. def get_schedule(args, read_op):
  68. schedule = {}
  69. if read_op:
  70. schedule['pre'] = pre_basic_read
  71. schedule['post'] = post_basic
  72. schedule['main'] = main_basic_read
  73. else:
  74. schedule['pre'] = pre_basic_write
  75. schedule['post'] = post_basic
  76. schedule['main'] = main_basic_write
  77. return schedule
  78. def _aio_handle_tasklet(pool_params):
  79. args, tid, read_op = pool_params
  80. # Create schedule
  81. schedule = get_schedule(args, read_op)
  82. task_log(tid, f'schedule = {schedule}')
  83. task_barrier(aio_barrier, args.threads)
  84. # Run pre task
  85. task_log(tid, f'running pre-task')
  86. ctxt = schedule["pre"]((args, tid))
  87. task_barrier(aio_barrier, args.threads)
  88. # Run main tasks in a loop
  89. ctxt["main_task_sec"] = 0
  90. for i in range(args.loops):
  91. task_log(tid, f'running main task {i}')
  92. start_time = time.time()
  93. ctxt = schedule["main"]((args, tid, ctxt))
  94. task_barrier(aio_barrier, args.threads)
  95. stop_time = time.time()
  96. ctxt["main_task_sec"] += stop_time - start_time
  97. # Run post task
  98. task_log(tid, f'running post-task')
  99. ctxt = schedule["post"]((args, tid, ctxt))
  100. task_barrier(aio_barrier, args.threads)
  101. return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops
  102. def _init_tasklet(b):
  103. global aio_barrier
  104. aio_barrier = b
  105. def aio_basic_multiprocessing(args, read_op):
  106. b = Barrier(args.threads)
  107. pool_params = [(args, p, read_op) for p in range(args.threads)]
  108. with Pool(processes=args.threads, initializer=_init_tasklet, initargs=(b, )) as p:
  109. pool_results = p.map(_aio_handle_tasklet, pool_params)
  110. report_results(args, read_op, pool_results)