ccl.cpp

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <torch/extension.h>

#include <oneapi/ccl.hpp>

#include <cstring>
#include <iostream>
#include <set>
#include <vector>

std::set<int> _comm_ids;
std::set<int> _colors;
ccl::vector_class<ccl::communicator> _ccl_comms;

#define CCLCHECK(cmd) \
    do {              \
        cmd;          \
    } while (0)

#define KVS_CREATE_SUCCESS 0
#define KVS_CREATE_FAILURE -1

bool is_initialized = false;

int world_rank = -1;
int world_size = -1;

ccl::shared_ptr_class<ccl::kvs> kvs;
void initialize(int size, int rank, torch::Tensor& kvs_data)
{
    if (is_initialized) return;
    world_size = size;
    world_rank = rank;
    is_initialized = true;

    // Rank 0 creates the main kvs in get_kvs_addr(); every other rank rebuilds a
    // kvs handle from the address that was passed to it in kvs_data.
    ccl::kvs::address_type main_addr;
    if (rank != 0) {
        memcpy(main_addr.data(), kvs_data.data_ptr(), main_addr.size());
        kvs = ccl::create_kvs(main_addr);
    }

    _ccl_comms.emplace_back(ccl::create_communicator(size, rank, kvs));
}
/*
    rank == 0: create the main kvs and return its address
    any other rank: return an empty (default-constructed) address
*/
std::vector<uint8_t> get_kvs_addr(int rank)
{
    if (rank == 0) {
        kvs = ccl::create_main_kvs();
        ccl::kvs::address_type main_addr = kvs->get_address();
        auto ccl_kvs_addr = std::vector<uint8_t>(main_addr.begin(), main_addr.end());
        return ccl_kvs_addr;
    } else {
        ccl::kvs::address_type main_addr;
        auto ccl_kvs_addr = std::vector<uint8_t>(main_addr.begin(), main_addr.end());
        return ccl_kvs_addr;
    }
}
int get_rank(int group = 0) { return world_rank; }

int get_world_size(int group = 0) { return world_size; }
// Find the next ordered, unique value in a set. E.g. <0,1,2,7> --> 3
int next_unique_val(std::set<int> s)
{
    std::set<int>::iterator itr;
    // Base case. Add 0 to start of set.
    if (s.empty() || *s.begin() != 0) {
        return 0;
        // second base case where s = {0} (the case of s = {n != 0} is caught above)
    } else if (s.size() == 1) {
        return 1;
    } else {
        int prev_val = *s.begin();
        for (itr = std::next(s.begin()); itr != s.end(); itr++) {
            if (*itr != prev_val + 1) { return prev_val + 1; }
            prev_val = *itr;
        }
        // No gap found: return one past the largest element. Dereferencing s.end()
        // is undefined behavior, so use the reverse iterator instead.
        return *s.rbegin() + 1;
    }
}
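// A minimal self-check sketch (not part of the original file): it only illustrates
// the behavior of next_unique_val described above. The helper name
// next_unique_val_examples_hold is hypothetical.
[[maybe_unused]] static bool next_unique_val_examples_hold()
{
    return next_unique_val({}) == 0               // empty set -> start at 0
           && next_unique_val({0}) == 1           // {0} -> 1
           && next_unique_val({5}) == 0           // set that does not start at 0 -> 0
           && next_unique_val({0, 1, 2, 7}) == 3  // first gap in the ordered run
           && next_unique_val({0, 1, 2, 3}) == 4; // contiguous -> one past the max
}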
// TODO: group creation is not implemented yet; this only reserves a comm id and a
// color and reports them.
py::object new_group(std::vector<int> ranks)
{
    int comm_id = next_unique_val(_comm_ids);
    int color = next_unique_val(_colors);
    std::cout << "RANK: " << get_rank() << " COMM_ID: " << comm_id << " COLOR: " << color
              << std::endl;
    // Return None for now so this non-void function does not fall off its end
    // (which is undefined behavior).
    return py::none();
}
ccl::datatype get_ccl_datatype(c10::ScalarType type)
{
    ccl::datatype ccl_type;
    switch (type) {
        case c10::ScalarType::Int: ccl_type = ccl::datatype::int32; break;
        case c10::ScalarType::Float: ccl_type = ccl::datatype::float32; break;
        case c10::ScalarType::Double: ccl_type = ccl::datatype::float64; break;
        case c10::ScalarType::BFloat16: ccl_type = ccl::datatype::bfloat16; break;
        case c10::ScalarType::Half: ccl_type = ccl::datatype::float16; break;
        default: ccl_type = ccl::datatype::int8;
    }
    return ccl_type;
}
ccl::reduction get_ccl_reduce_op(py::object op, at::Tensor& input)
{
    py::object ReduceOp = py::module_::import("deepspeed.comm").attr("ReduceOp");
    if (!py::isinstance(op, ReduceOp)) {
        throw std::runtime_error("Error: Op must be of type ReduceOp");
    }

    int op_val = py::int_(op.attr("value"));
    ccl::reduction ccl_op;

    if (input.scalar_type() == at::kBool) {
        if (op_val == (int)py::int_(ReduceOp.attr("SUM").attr("value"))) {
            // For bool tensors, map sum to max, which both represent a bitwise or.
            // This is to prevent overflow issues with sum, since we use uint8 to
            // represent a bool (see cclDataType mapping).
            // Return here so the generic mapping below does not overwrite max with sum.
            ccl_op = ccl::reduction::max;
            return ccl_op;
        } else if (op_val == (int)py::int_(ReduceOp.attr("AVG").attr("value"))) {
            throw std::runtime_error("Error: ReduceOp.AVG is not supported for bool tensors");
        }
    }

    if (op_val == (int)py::int_(ReduceOp.attr("SUM").attr("value"))) {
        ccl_op = ccl::reduction::sum;
    } else if (op_val == (int)py::int_(ReduceOp.attr("MIN").attr("value"))) {
        ccl_op = ccl::reduction::min;
    } else if (op_val == (int)py::int_(ReduceOp.attr("MAX").attr("value"))) {
        ccl_op = ccl::reduction::max;
    } else if (op_val == (int)py::int_(ReduceOp.attr("PRODUCT").attr("value"))) {
        ccl_op = ccl::reduction::prod;
    } else {
        throw std::runtime_error("Error: Unrecognized ReduceOp type");
    }
    return ccl_op;
}
// Only the default (world) communicator is supported for now; any group argument
// resolves to the first communicator.
ccl::communicator& _get_comm_from_group() { return _ccl_comms[0]; }

ccl::communicator& _get_comm_from_group(py::object group) { return _ccl_comms[0]; }
void broadcast(torch::Tensor& data, int src, py::object group, bool async_op)
{
    CCLCHECK(ccl::broadcast(data.data_ptr(),
                            data.numel(),
                            get_ccl_datatype(data.scalar_type()),
                            src,
                            _get_comm_from_group(group))
                 .wait());
}
// TODO: implement torch's async_op behavior, document it.
void all_reduce(torch::Tensor& data, py::object op, py::object group, bool async_op)
{
    CCLCHECK(ccl::allreduce(data.data_ptr(),
                            data.data_ptr(),
                            data.numel(),
                            get_ccl_datatype(data.scalar_type()),
                            get_ccl_reduce_op(op, data),
                            _get_comm_from_group(group))
                 .wait());
}
void all_reduce_caching(torch::Tensor& data,
                        py::object op,
                        std::string match_id,
                        py::object group,
                        bool async_op)
{
    ccl::allreduce_attr attr = ccl::default_allreduce_attr;
    auto match_str = ccl::v1::string(match_id);
    attr.template set<ccl::operation_attr_id::to_cache>(true);
    attr.template set<ccl::operation_attr_id::match_id>(match_str);
    // Caching is enabled through the operation attribute: set to_cache to true and
    // pass a unique string (for example, the tensor name) as match_id. Note that:
    // match_id should be the same for a specific communication operation across all ranks.
    // If the same tensor is a part of different communication operations, match_id should
    // have different values for each of these operations.
    CCLCHECK(ccl::allreduce(data.data_ptr(),
                            data.data_ptr(),
                            data.numel(),
                            get_ccl_datatype(data.scalar_type()),
                            get_ccl_reduce_op(op, data),
                            _get_comm_from_group(group),
                            attr)
                 .wait());
}
void barrier(py::object group, bool async_op)
{
    CCLCHECK(ccl::barrier(_get_comm_from_group(group)).wait());
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("get_kvs_addr", &get_kvs_addr, "create and get main kvs addr");
    m.def("initialize", &initialize, "ccl initialize");
    m.def("get_rank", &get_rank, "get rank");
    m.def("get_world_size", &get_world_size, "get world size");
    m.def("broadcast", &broadcast, "ccl broadcast");
    m.def("all_reduce", &all_reduce, "ccl all_reduce");
    m.def("all_reduce_caching", &all_reduce_caching, "ccl all_reduce with caching");
    m.def("barrier", &barrier, "barrier");
}
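
// Expected call sequence from the Python side, sketched here as a comment only
// (names of the built module and of any transport used to share the address are
// illustrative, not defined in this file): rank 0 calls get_kvs_addr(0) and the
// returned bytes are copied into a tensor that is delivered to every other rank
// by some out-of-band mechanism; every rank then calls
// initialize(world_size, rank, kvs_tensor); after that, broadcast, all_reduce,
// all_reduce_caching and barrier operate on the default (world) group.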