123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211 |
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include <cstdint>
#include <cstring>
#include <iostream>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

#include <torch/extension.h>
#include <oneapi/ccl.hpp>
- std::set<int> _comm_ids;
- std::set<int> _colors;
- ccl::vector_class<ccl::communicator> _ccl_comms;
- #define CCLCHECK(cmd) \
- do { \
- cmd; \
- } while (0)
- #define KVS_CREATE_SUCCESS 0
- #define KVS_CREATE_FAILURE -1
- bool is_initialized = 0;
- int world_rank = -1;
- int world_size = -1;
- ccl::shared_ptr_class<ccl::kvs> kvs;
- void initialize(int size, int rank, torch::Tensor& kvs_data)
- {
- if (is_initialized) return;
- world_size = size;
- world_rank = rank;
- is_initialized = 1;
- ccl::kvs::address_type main_addr;
- if (rank != 0) {
- memcpy(main_addr.data(), kvs_data.data_ptr(), main_addr.size());
- kvs = ccl::create_kvs(main_addr);
- }
- _ccl_comms.emplace_back(ccl::create_communicator(size, rank, kvs));
- }
- /*
- rank == 0: create main kvs and return its address
- rank == else: return an empty address
- */
- std::vector<uint8_t> get_kvs_addr(int rank)
- {
- if (rank == 0) {
- kvs = ccl::create_main_kvs();
- ccl::kvs::address_type main_addr = kvs->get_address();
- auto ccl_kvs_addr = std::vector<uint8_t>(main_addr.begin(), main_addr.end());
- return ccl_kvs_addr;
- } else {
- ccl::kvs::address_type main_addr;
- auto ccl_kvs_addr = std::vector<uint8_t>(main_addr.begin(), main_addr.end());
- return ccl_kvs_addr;
- }
- }
- int get_rank(int group = 0) { return world_rank; }
- int get_world_size(int group = 0) { return world_size; }
// Return the smallest non-negative integer not present in `s`.
// E.g. {0,1,2,7} -> 3, {0,1,2} -> 3, {0} -> 1, {1,2} -> 0, {} -> 0.
// Negative entries are ignored so they can never mask a free id.
// Takes the set by const reference to avoid copying it on every call.
int next_unique_val(const std::set<int>& s)
{
    int candidate = 0;
    // std::set iterates in ascending order, so the first element that skips past
    // `candidate` marks the first gap.
    for (const int value : s) {
        if (value < candidate) continue;  // only negatives can be below the candidate
        if (value != candidate) break;
        ++candidate;
    }
    return candidate;
}
- py::object new_group(std::vector<int> ranks)
- {
- int comm_id = next_unique_val(_comm_ids);
- int color = next_unique_val(_colors);
- std::cout << "RANK: " << get_rank() << " COMM_ID: " << comm_id << " COLOR: " << color
- << std::endl;
- }
- ccl::datatype get_ccl_datatype(c10::ScalarType type)
- {
- ccl::datatype ccl_type;
- switch (type) {
- case c10::ScalarType::Int: ccl_type = ccl::datatype::int32; break;
- case c10::ScalarType::Float: ccl_type = ccl::datatype::float32; break;
- case c10::ScalarType::Double: ccl_type = ccl::datatype::float64; break;
- case c10::ScalarType::BFloat16: ccl_type = ccl::datatype::bfloat16; break;
- case c10::ScalarType::Half: ccl_type = ccl::datatype::float16; break;
- default: ccl_type = ccl::datatype::int8;
- }
- return ccl_type;
- }
- ccl::reduction get_ccl_reduce_op(py::object op, at::Tensor& input)
- {
- py::object ReduceOp = py::module_::import("deepspeed.comm").attr("ReduceOp");
- if (!py::isinstance(op, ReduceOp)) {
- throw std::runtime_error("Error: Op must be of type ReduceOp");
- }
- int op_val = py::int_(op.attr("value"));
- ccl::reduction ccl_op;
- if (input.scalar_type() == at::kBool) {
- if (op_val == (int)py::int_(ReduceOp.attr("SUM").attr("value"))) {
- // For bool tensors, map sum to max, which both represent a bitwise or.
- // This is to prevent overflow issues with sum, since we use uint8 to
- // represent a bool (see cclDataType mapping).
- ccl_op = ccl::reduction::max;
- } else if (op_val == (int)py::int_(ReduceOp.attr("AVG").attr("value"))) {
- throw std::runtime_error("Error: For bool tensors, op must be of type ReduceOp");
- }
- }
- if (op_val == (int)py::int_(ReduceOp.attr("SUM").attr("value"))) {
- ccl_op = ccl::reduction::sum;
- } else if (op_val == (int)py::int_(ReduceOp.attr("MIN").attr("value"))) {
- ccl_op = ccl::reduction::min;
- } else if (op_val == (int)py::int_(ReduceOp.attr("MAX").attr("value"))) {
- ccl_op = ccl::reduction::max;
- } else if (op_val == (int)py::int_(ReduceOp.attr("PRODUCT").attr("value"))) {
- ccl_op = ccl::reduction::prod;
- } else {
- throw std::runtime_error("Error: Unrecognized ReduceOp type");
- }
- return ccl_op;
- }
- ccl::communicator& _get_comm_from_group() { return _ccl_comms[0]; }
- ccl::communicator& _get_comm_from_group(py::object group) { return _ccl_comms[0]; }
- void broadcast(torch::Tensor& data, int src, py::object group, bool async_op)
- {
- CCLCHECK(ccl::broadcast(data.data_ptr(),
- data.numel(),
- get_ccl_datatype(data.scalar_type()),
- src,
- _get_comm_from_group(group))
- .wait());
- }
- // TODO: implement torch's async_op behavior, document it.
- void all_reduce(torch::Tensor& data, py::object op, py::object group, bool async_op)
- {
- CCLCHECK(ccl::allreduce(data.data_ptr(),
- data.data_ptr(),
- data.numel(),
- get_ccl_datatype(data.scalar_type()),
- get_ccl_reduce_op(op, data),
- _get_comm_from_group(group))
- .wait());
- }
- void all_reduce_caching(torch::Tensor& data,
- py::object op,
- std::string match_id,
- py::object group,
- bool async_op)
- {
- ccl::allreduce_attr attr = ccl::default_allreduce_attr;
- auto match_str = ccl::v1::string(match_id);
- attr.template set<ccl::operation_attr_id::to_cache>(true);
- attr.template set<ccl::operation_attr_id::match_id>(match_str);
- // To control this, use operation attribute and set true value for to_cache field and unique
- // string (for example, tensor name) for match_id field. Note that:
- // match_id should be the same for a specific communication operation across all ranks.
- // If the same tensor is a part of different communication operations, match_id should have
- // different values for each of these operations.
- CCLCHECK(ccl::allreduce(data.data_ptr(),
- data.data_ptr(),
- data.numel(),
- get_ccl_datatype(data.scalar_type()),
- get_ccl_reduce_op(op, data),
- _get_comm_from_group(group),
- attr)
- .wait());
- }
- void barrier(py::object group, bool async_op)
- {
- CCLCHECK(ccl::barrier(_get_comm_from_group(group)).wait());
- }
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
- {
- m.def("get_kvs_addr", &get_kvs_addr, "create and get main kvs addr");
- m.def("initialize", &initialize, "ccl initialize");
- m.def("get_rank", &get_rank, "get rank");
- m.def("get_world_size", &get_world_size, "get world size");
- m.def("broadcast", &broadcast, "ccl broadcast");
- m.def("all_reduce", &all_reduce, "ccl all_reduce");
- m.def("all_reduce_caching", &all_reduce_caching, "ccl all_reduce with caching");
- m.def("barrier", &barrier, "barrier");
- }
|