From eb9b750830ea628033d2afb81241ac74c2d16914 Mon Sep 17 00:00:00 2001 From: Crutcher Dunnavant Date: Thu, 23 Mar 2023 22:14:52 +0000 Subject: [PATCH] format and guard --- .../test_mscclpp.cpython-39-pytest-7.2.0.pyc | Bin 1652 -> 1662 bytes python/src/_py_mscclpp.cpp | 200 +++++++++--------- 2 files changed, 100 insertions(+), 100 deletions(-) diff --git a/python/mscclpp/__pycache__/test_mscclpp.cpython-39-pytest-7.2.0.pyc b/python/mscclpp/__pycache__/test_mscclpp.cpython-39-pytest-7.2.0.pyc index 91fce504cdddcd4888539c3ccd55220efff36cce..4fd716c340f36b9eb65253eff985cdcfc07b71d3 100644 GIT binary patch delta 103 zcmeyu^N)uwk(ZZ?0SNXVl}Yhn-N^Tyh5Z&wQDR>9t;syB(e|wIC8@Z4G_?L~fir+auH@7$hD6a=rrO8qx0aQ{X2_mFGg!1G$tTzCU)*2iD delta 76 zcmeyz^M!{mk(ZZ?0SKZF$)x;d*~s^ug*}R;C^0WPYBCRNv>Z!GYH^7cP`;Q2q>539 av5L<*KR35H1SqaIxt~>xQE~Dz)*Ap90u;Ug diff --git a/python/src/_py_mscclpp.cpp b/python/src/_py_mscclpp.cpp index e708fd47..bfce09d3 100644 --- a/python/src/_py_mscclpp.cpp +++ b/python/src/_py_mscclpp.cpp @@ -1,129 +1,129 @@ +#include #include #include -#include #include #include #include -#include #include +#include namespace nb = nanobind; using namespace nb::literals; // This is a poorman's substitute for std::format, which is a C++20 feature. -template -std::string string_format( const std::string& format, Args ... args ) -{ - int size_s = std::snprintf( nullptr, 0, format.c_str(), args ... ) + 1; // Extra space for '\0' - if( size_s <= 0 ){ throw std::runtime_error( "Error during formatting." ); } - auto size = static_cast( size_s ); - std::unique_ptr buf( new char[ size ] ); - std::snprintf( buf.get(), size, format.c_str(), args ... ); - return std::string( buf.get(), buf.get() + size - 1 ); // We don't want the '\0' inside +template +std::string string_format(const std::string &format, Args... args) { + int size_s = std::snprintf(nullptr, 0, format.c_str(), args...) + + 1; // Extra space for '\0' + if (size_s <= 0) { + throw std::runtime_error("Error during formatting."); + } + auto size = static_cast(size_s); + std::unique_ptr buf(new char[size]); + std::snprintf(buf.get(), size, format.c_str(), args...); + return std::string(buf.get(), + buf.get() + size - 1); // We don't want the '\0' inside } -template -Val maybe(mscclppResult_t status, Val val, const std::string& format, Args ... args) { - switch (status) { - case mscclppSuccess: - return val; +template +Val maybe(mscclppResult_t status, Val val, const std::string &format, + Args... args) { + switch (status) { + case mscclppSuccess: + return val; - case mscclppUnhandledCudaError: - case mscclppSystemError: - case mscclppInternalError: - case mscclppRemoteError: - case mscclppInProgress: - case mscclppNumResults: - throw std::runtime_error(string_format(format, args ...)); + case mscclppUnhandledCudaError: + case mscclppSystemError: + case mscclppInternalError: + case mscclppRemoteError: + case mscclppInProgress: + case mscclppNumResults: + throw std::runtime_error(string_format(format, args...)); - case mscclppInvalidArgument: - case mscclppInvalidUsage: - default: - throw std::invalid_argument(string_format(format, args ...)); - } + case mscclppInvalidArgument: + case mscclppInvalidUsage: + default: + throw std::invalid_argument(string_format(format, args...)); + } } struct MscclppComm { mscclppComm_t internal; }; - NB_MODULE(_py_mscclpp, m) { - m.doc() = "Python bindings for MSCCLPP"; + m.doc() = "Python bindings for MSCCLPP: which is not NCCL"; - m.attr("MSCCLPP_UNIQUE_ID_BYTES") = MSCCLPP_UNIQUE_ID_BYTES; + m.attr("MSCCLPP_UNIQUE_ID_BYTES") = MSCCLPP_UNIQUE_ID_BYTES; - nb::class_(m, "MscclppUniqueId") - .def_static("from_context", []() { - mscclppUniqueId uniqueId; - return maybe( - mscclppGetUniqueId(&uniqueId), - uniqueId, - "Failed to get MSCCLP Unique Id." - ); - }) - .def_static("from_bytes", [](nb::bytes source) { - if (source.size() != MSCCLPP_UNIQUE_ID_BYTES) { - throw std::invalid_argument( - string_format( - "Requires exactly %d bytes; found %d", - MSCCLPP_UNIQUE_ID_BYTES, - source.size() - ) - ); - } + nb::class_(m, "MscclppUniqueId") + .def_static( + "from_context", + []() { + mscclppUniqueId uniqueId; + return maybe(mscclppGetUniqueId(&uniqueId), uniqueId, + "Failed to get MSCCLP Unique Id."); + }, + nb::call_guard()) + .def_static("from_bytes", + [](nb::bytes source) { + if (source.size() != MSCCLPP_UNIQUE_ID_BYTES) { + throw std::invalid_argument(string_format( + "Requires exactly %d bytes; found %d", + MSCCLPP_UNIQUE_ID_BYTES, source.size())); + } - mscclppUniqueId uniqueId; - std::memcpy(uniqueId.internal, source.c_str(), sizeof(uniqueId.internal)); - return uniqueId; - }) - .def("bytes", [](mscclppUniqueId id){ - return nb::bytes(id.internal, sizeof(id.internal)); - }); + mscclppUniqueId uniqueId; + std::memcpy(uniqueId.internal, source.c_str(), + sizeof(uniqueId.internal)); + return uniqueId; + }) + .def("bytes", [](mscclppUniqueId id) { + return nb::bytes(id.internal, sizeof(id.internal)); + }); nb::class_(m, "MscclppComm") - .def_static( - "init_rank_from_address", - [](const std::string &address, int rank, int world_size) { - MscclppComm comm = { 0 }; - return maybe( - mscclppCommInitRank(&comm.internal, world_size, rank, address.c_str()), - comm, - "Failed to initialize comms: %s rank=%d world_size=%d", - address, - rank, - world_size); - }, - "address"_a, "rank"_a, "world_size"_a, - "Initialize comms given an IP address, rank, and world_size" - ) - .def_static("init_rank_from_id", [](const mscclppUniqueId &id, int rank, int world_size) { - MscclppComm comm = { 0 }; - return maybe( - mscclppCommInitRankFromId(&comm.internal, world_size, id, rank), - comm, - "Failed to initialize comms: %02X%s rank=%d world_size=%d", - id.internal, - rank, - world_size); - }) - .def("close", [](MscclppComm &comm) { - maybe( - mscclppCommDestroy(comm.internal), - nb::none(), - "Failed to close comm channel" - ); - comm.internal = 0; - }) - .def("__del__", [](MscclppComm &comm) { - maybe( - mscclppCommDestroy(comm.internal), - nb::none(), - "Failed to close comm channel" - ); - comm.internal = 0; - }); - + .def_static( + "init_rank_from_address", + [](const std::string &address, int rank, int world_size) { + MscclppComm comm = {0}; + return maybe(mscclppCommInitRank(&comm.internal, world_size, rank, + address.c_str()), + comm, + "Failed to initialize comms: %s rank=%d world_size=%d", + address, rank, world_size); + }, + nb::call_guard(), "address"_a, "rank"_a, + "world_size"_a, + "Initialize comms given an IP address, rank, and world_size") + .def_static( + "init_rank_from_id", + [](const mscclppUniqueId &id, int rank, int world_size) { + MscclppComm comm = {0}; + return maybe( + mscclppCommInitRankFromId(&comm.internal, world_size, id, rank), + comm, + "Failed to initialize comms: %02X%s rank=%d world_size=%d", + id.internal, rank, world_size); + }, + nb::call_guard(), "id"_a, "rank"_a, + "world_size"_a, + "Initialize comms given u UniqueID, rank, and world_size") + .def( + "close", + [](MscclppComm &comm) { + maybe(mscclppCommDestroy(comm.internal), nb::none(), + "Failed to close comm channel"); + comm.internal = 0; + }, + nb::call_guard()) + .def( + "__del__", + [](MscclppComm &comm) { + maybe(mscclppCommDestroy(comm.internal), nb::none(), + "Failed to close comm channel"); + comm.internal = 0; + }, + nb::call_guard()); } -