// deepspeed_py_aio_handle.h
/*
Copyright 2020 The Microsoft DeepSpeed Team
Licensed under the MIT license.

Declares deepspeed_aio_handle_t, a handle for reading and writing tensors to/from
files on (NVMe) storage devices; used for swapping optimizer tensors.
*/
#pragma once

#include <condition_variable>
#include <memory>
#include <thread>
#include <vector>

#include "deepspeed_aio_thread.h"
  9. struct deepspeed_aio_handle_t {
  10. std::unique_ptr<struct aio_context> _aio_ctxt;
  11. const bool _single_submit;
  12. const bool _overlap_events;
  13. const int _num_threads;
  14. deepspeed_aio_config_t _aio_config;
  15. std::vector<std::shared_ptr<struct deepspeed_aio_thread_t>> _thread_contexts;
  16. std::vector<std::thread> _threads;
  17. int _num_pending_ops;
  18. deepspeed_aio_handle_t(const int block_size,
  19. const int queue_depth,
  20. const bool single_submit,
  21. const bool overlap_events,
  22. const int num_threads);
  23. ~deepspeed_aio_handle_t();
  24. const int get_block_size() const;
  25. const int get_queue_depth() const;
  26. const bool get_single_submit() const;
  27. const bool get_overlap_events() const;
  28. const int get_thread_count() const;
  29. int read(torch::Tensor& buffer, const char* filename, const bool validate);
  30. int write(const torch::Tensor& buffer, const char* filename, const bool validate);
  31. int pread(const torch::Tensor& buffer,
  32. const char* filename,
  33. const bool validate,
  34. const bool async);
  35. int pwrite(const torch::Tensor& buffer,
  36. const char* filename,
  37. const bool validate,
  38. const bool async);
  39. int sync_pread(torch::Tensor& buffer, const char* filename);
  40. int sync_pwrite(const torch::Tensor& buffer, const char* filename);
  41. int async_pread(torch::Tensor& buffer, const char* filename);
  42. int async_pwrite(const torch::Tensor& buffer, const char* filename);
  43. int wait();
  44. void _stop_threads();
  45. void _schedule_aio_work(std::shared_ptr<struct io_op_desc_t> scheduled_op);
  46. std::shared_ptr<struct io_op_desc_t> _wait_for_aio_work();
  47. bool _is_valid_parallel_aio_op(const bool read_op, const long long int num_bytes);
  48. };