// deepspeed_aio_thread.cpp
  1. // Copyright (c) Microsoft Corporation.
  2. // SPDX-License-Identifier: Apache-2.0
  3. // DeepSpeed Team
  4. /*
  5. Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
  6. */
#include "deepspeed_aio_thread.h"

#include <memory>
  8. using namespace std;
  9. io_op_desc_t::io_op_desc_t(const bool read_op,
  10. const torch::Tensor& buffer,
  11. const int fd,
  12. const char* filename,
  13. const long long int num_bytes,
  14. const bool validate)
  15. : _read_op(read_op),
  16. _buffer(buffer),
  17. _fd(fd),
  18. _filename(filename),
  19. _num_bytes(num_bytes),
  20. _validate(validate)
  21. {
  22. _cpu_buffer = _buffer.is_cuda() ? _buffer.to(torch::kCPU).pin_memory() : _buffer;
  23. _contiguous_buffer = _cpu_buffer.contiguous();
  24. }
  25. char* io_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_ptr(); }
  26. void io_op_desc_t::fini()
  27. {
  28. if (_read_op && _buffer.is_cuda()) { _buffer.copy_(_cpu_buffer.to(torch::kCUDA)); }
  29. }
  30. deepspeed_aio_thread_t::deepspeed_aio_thread_t(const int tid, deepspeed_aio_config_t& aio_config)
  31. : _tid(tid),
  32. _aio_config(aio_config),
  33. _aio_ctxt(new aio_context(aio_config._block_size, aio_config._queue_depth)),
  34. _time_to_exit(false)
  35. {
  36. }
  37. deepspeed_aio_thread_t::~deepspeed_aio_thread_t() {}
  38. void deepspeed_aio_thread_t::run()
  39. {
  40. while (true) {
  41. std::shared_ptr<struct io_op_desc_t> next_io_op = nullptr;
  42. {
  43. std::unique_lock<std::mutex> lock(_work_sync._mutex);
  44. _work_sync._cond_var.wait(lock,
  45. [this] { return (!_work_queue.empty() || _time_to_exit); });
  46. if (!_work_queue.empty()) {
  47. next_io_op = _work_queue.front();
  48. _work_queue.pop();
  49. }
  50. }
  51. if (next_io_op) {
  52. const auto base_offset = next_io_op->_num_bytes * _tid;
  53. std::unique_ptr<io_xfer_ctxt> xfer_ctxt(new io_xfer_ctxt(
  54. next_io_op->_fd, base_offset, next_io_op->_num_bytes, next_io_op->data_ptr()));
  55. if (_aio_config._overlap_events) {
  56. do_aio_operation_overlap(
  57. next_io_op->_read_op, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr);
  58. } else {
  59. do_aio_operation_sequential(
  60. next_io_op->_read_op, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr);
  61. }
  62. {
  63. std::lock_guard<std::mutex> lock(_complete_sync._mutex);
  64. _complete_queue.push(next_io_op);
  65. }
  66. _complete_sync._cond_var.notify_one();
  67. }
  68. if (_time_to_exit) { break; }
  69. }
  70. }