import pytest
import os
import filecmp
import torch
import deepspeed
import deepspeed.comm as dist
from deepspeed.ops.aio import AsyncIOBuilder
from .common import distributed_test

MEGA_BYTE = 1024**2
BLOCK_SIZE = MEGA_BYTE
QUEUE_DEPTH = 2
IO_SIZE = 16 * MEGA_BYTE
IO_PARALLEL = 2
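
# These constants mirror the aio_handle configuration exercised below: as the
# getter names in _validate_handle_state suggest, BLOCK_SIZE is the size of
# each I/O request, QUEUE_DEPTH the number of in-flight requests per
# submission queue, and IO_PARALLEL the number of I/O threads. IO_SIZE is the
# total payload each test reads or writes.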


def _skip_if_no_aio():
    if not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]:
        pytest.skip('Skip tests since async-io is not compatible')


def _do_ref_write(tmpdir, index=0):
    file_suffix = f'{dist.get_rank()}_{index}'
    ref_file = os.path.join(tmpdir, f'_py_random_{file_suffix}.pt')
    ref_buffer = os.urandom(IO_SIZE)
    with open(ref_file, 'wb') as f:
        f.write(ref_buffer)
    return ref_file, ref_buffer


def _get_test_file_and_buffer(tmpdir, ref_buffer, cuda_device, index=0):
    file_suffix = f'{dist.get_rank()}_{index}'
    test_file = os.path.join(tmpdir, f'_aio_write_random_{file_suffix}.pt')
    if cuda_device:
        test_buffer = torch.cuda.ByteTensor(list(ref_buffer))
    else:
        test_buffer = torch.ByteTensor(list(ref_buffer)).pin_memory()
    return test_file, test_buffer
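
# Note: host-side buffers are page-locked via pin_memory() before being handed
# to the aio engine; the tests appear to assume the engine requires pinned
# host memory for its transfers.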


def _validate_handle_state(handle, single_submit, overlap_events):
    assert handle.get_single_submit() == single_submit
    assert handle.get_overlap_events() == overlap_events
    assert handle.get_thread_count() == IO_PARALLEL
    assert handle.get_block_size() == BLOCK_SIZE
    assert handle.get_queue_depth() == QUEUE_DEPTH


@pytest.mark.parametrize('single_submit, overlap_events',
                         [(False, False), (False, True), (True, False), (True, True)])
def test_parallel_read(tmpdir, single_submit, overlap_events):
    _skip_if_no_aio()

    @distributed_test(world_size=[2])
    def _test_parallel_read(single_submit, overlap_events):
        ref_file, _ = _do_ref_write(tmpdir)
        aio_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device='cpu').pin_memory()

        h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit,
                                               overlap_events, IO_PARALLEL)
        _validate_handle_state(h, single_submit, overlap_events)

        read_status = h.sync_pread(aio_buffer, ref_file)
        assert read_status == 1

        with open(ref_file, 'rb') as f:
            ref_buffer = list(f.read())
        assert ref_buffer == aio_buffer.tolist()

    _test_parallel_read(single_submit, overlap_events)
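
# Status convention implied by the assertions in these tests: synchronous ops
# (sync_pread/sync_pwrite) return the number of completed requests, async
# submissions (async_pread/async_pwrite) return 0 on success, and wait()
# returns the number of requests completed since the last wait.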


@pytest.mark.parametrize('single_submit, overlap_events, cuda_device',
                         [(False, False, False), (False, True, False),
                          (True, False, False), (True, True, False),
                          (False, False, True), (True, True, True)])
def test_async_read(tmpdir, single_submit, overlap_events, cuda_device):
    _skip_if_no_aio()

    @distributed_test(world_size=[2])
    def _test_async_read(single_submit, overlap_events, cuda_device):
        ref_file, _ = _do_ref_write(tmpdir)
        if cuda_device:
            aio_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device='cuda')
        else:
            aio_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device='cpu').pin_memory()

        h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit,
                                               overlap_events, IO_PARALLEL)
        _validate_handle_state(h, single_submit, overlap_events)

        read_status = h.async_pread(aio_buffer, ref_file)
        assert read_status == 0

        wait_status = h.wait()
        assert wait_status == 1

        with open(ref_file, 'rb') as f:
            ref_buffer = list(f.read())
        assert ref_buffer == aio_buffer.tolist()

    _test_async_read(single_submit, overlap_events, cuda_device)


@pytest.mark.parametrize('single_submit, overlap_events',
                         [(False, False), (False, True), (True, False), (True, True)])
def test_parallel_write(tmpdir, single_submit, overlap_events):
    _skip_if_no_aio()

    @distributed_test(world_size=[2])
    def _test_parallel_write(single_submit, overlap_events):
        ref_file, ref_buffer = _do_ref_write(tmpdir)
        aio_file, aio_buffer = _get_test_file_and_buffer(tmpdir, ref_buffer, False)

        h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit,
                                               overlap_events, IO_PARALLEL)
        _validate_handle_state(h, single_submit, overlap_events)

        write_status = h.sync_pwrite(aio_buffer, aio_file)
        assert write_status == 1

        assert os.path.isfile(aio_file)

        filecmp.clear_cache()
        assert filecmp.cmp(ref_file, aio_file, shallow=False)

    _test_parallel_write(single_submit, overlap_events)


@pytest.mark.parametrize('single_submit, overlap_events, cuda_device',
                         [(False, False, False), (False, True, False),
                          (True, False, False), (True, True, False),
                          (False, False, True), (True, True, True)])
def test_async_write(tmpdir, single_submit, overlap_events, cuda_device):
    _skip_if_no_aio()

    @distributed_test(world_size=[2])
    def _test_async_write(single_submit, overlap_events, cuda_device):
        ref_file, ref_buffer = _do_ref_write(tmpdir)
        aio_file, aio_buffer = _get_test_file_and_buffer(tmpdir, ref_buffer, cuda_device)

        h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit,
                                               overlap_events, IO_PARALLEL)
        _validate_handle_state(h, single_submit, overlap_events)

        write_status = h.async_pwrite(aio_buffer, aio_file)
        assert write_status == 0

        wait_status = h.wait()
        assert wait_status == 1

        assert os.path.isfile(aio_file)

        filecmp.clear_cache()
        assert filecmp.cmp(ref_file, aio_file, shallow=False)

    _test_async_write(single_submit, overlap_events, cuda_device)


@pytest.mark.parametrize('async_queue, cuda_device',
                         [(2, False), (4, False), (2, True), (4, True)])
def test_async_queue_read(tmpdir, async_queue, cuda_device):
    _skip_if_no_aio()

    @distributed_test(world_size=[2])
    def _test_async_queue_read(async_queue, cuda_device):
        ref_files = []
        for i in range(async_queue):
            f, _ = _do_ref_write(tmpdir, i)
            ref_files.append(f)

        aio_buffers = []
        for i in range(async_queue):
            if cuda_device:
                buf = torch.empty(IO_SIZE, dtype=torch.uint8, device='cuda')
            else:
                buf = torch.empty(IO_SIZE, dtype=torch.uint8, device='cpu').pin_memory()
            aio_buffers.append(buf)

        single_submit = True
        overlap_events = True
        h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit,
                                               overlap_events, IO_PARALLEL)
        _validate_handle_state(h, single_submit, overlap_events)

        for i in range(async_queue):
            read_status = h.async_pread(aio_buffers[i], ref_files[i])
            assert read_status == 0

        wait_status = h.wait()
        assert wait_status == async_queue

        for i in range(async_queue):
            with open(ref_files[i], 'rb') as f:
                ref_buffer = list(f.read())
            assert ref_buffer == aio_buffers[i].tolist()

    _test_async_queue_read(async_queue, cuda_device)


@pytest.mark.parametrize('async_queue, cuda_device',
                         [(2, False), (7, False), (2, True), (7, True)])
def test_async_queue_write(tmpdir, async_queue, cuda_device):
    _skip_if_no_aio()

    @distributed_test(world_size=[2])
    def _test_async_queue_write(async_queue, cuda_device):
        ref_files = []
        ref_buffers = []
        for i in range(async_queue):
            f, buf = _do_ref_write(tmpdir, i)
            ref_files.append(f)
            ref_buffers.append(buf)

        aio_files = []
        aio_buffers = []
        for i in range(async_queue):
            f, buf = _get_test_file_and_buffer(tmpdir, ref_buffers[i], cuda_device, i)
            aio_files.append(f)
            aio_buffers.append(buf)

        single_submit = True
        overlap_events = True
        h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit,
                                               overlap_events, IO_PARALLEL)
        _validate_handle_state(h, single_submit, overlap_events)

        for i in range(async_queue):
            write_status = h.async_pwrite(aio_buffers[i], aio_files[i])
            assert write_status == 0

        wait_status = h.wait()
        assert wait_status == async_queue

        for i in range(async_queue):
            assert os.path.isfile(aio_files[i])

            filecmp.clear_cache()
            assert filecmp.cmp(ref_files[i], aio_files[i], shallow=False)

    _test_async_queue_write(async_queue, cuda_device)
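
# Minimal standalone sketch (not part of the test suite) of driving an
# aio_handle directly, assuming async-io is compatible on this system; here
# some_file is a hypothetical path to an existing file of IO_SIZE bytes:
#
#   h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, QUEUE_DEPTH, True, True, IO_PARALLEL)
#   buf = torch.empty(IO_SIZE, dtype=torch.uint8, device='cpu').pin_memory()
#   assert h.sync_pread(buf, some_file) == 1  # blocking read of the file into buf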