Merge main branch

This commit is contained in:
Qinghua Zhou
2026-04-13 22:06:38 +00:00
223 changed files with 14129 additions and 4492 deletions

View File

@@ -4,6 +4,10 @@
add_subdirectory(csrc)
add_subdirectory(test)
target_compile_definitions(mscclpp_py PRIVATE
$<$<BOOL:${MSCCLPP_DISABLE_NB_LEAK_WARNINGS}>:MSCCLPP_DISABLE_NB_LEAK_WARNINGS>
)
add_custom_target(pytest_lib_copy ALL
COMMAND ${CMAKE_COMMAND} -E copy_if_different
${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/_mscclpp.*.so
@@ -12,4 +16,4 @@ add_custom_target(pytest_lib_copy ALL
${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/_ext.*.so
${CMAKE_CURRENT_SOURCE_DIR}/test/_cpp
DEPENDS mscclpp_py mscclpp_py_test
)
)

View File

@@ -24,4 +24,7 @@ set_target_properties(mscclpp_py PROPERTIES OUTPUT_NAME _mscclpp)
set_target_properties(mscclpp_py PROPERTIES INSTALL_RPATH "\$ORIGIN/lib")
target_link_libraries(mscclpp_py PRIVATE dlpack mscclpp mscclpp_collectives ${GPU_LIBRARIES})
target_include_directories(mscclpp_py SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
if(MSCCLPP_USE_ROCM)
target_compile_definitions(mscclpp_py PRIVATE MSCCLPP_USE_ROCM)
endif()
install(TARGETS mscclpp_py LIBRARY DESTINATION .)

View File

@@ -16,14 +16,16 @@ namespace nb = nanobind;
using namespace mscclpp;
void register_algorithm(nb::module_& m) {
nb::enum_<CollectiveBufferMode>(m, "CollectiveBufferMode")
nb::enum_<CollectiveBufferMode>(m, "CppCollectiveBufferMode")
.value("ANY", CollectiveBufferMode::Any)
.value("IN_PLACE", CollectiveBufferMode::InPlace)
.value("OUT_OF_PLACE", CollectiveBufferMode::OutOfPlace);
nb::enum_<AlgorithmType>(m, "AlgorithmType").value("NATIVE", AlgorithmType::Native).value("DSL", AlgorithmType::DSL);
nb::enum_<AlgorithmType>(m, "CppAlgorithmType")
.value("NATIVE", AlgorithmType::Native)
.value("DSL", AlgorithmType::DSL);
nb::enum_<CommResult>(m, "CommResult")
nb::enum_<CommResult>(m, "CppCommResult")
.value("COMM_SUCCESS", CommResult::CommSuccess)
.value("COMM_UNHANDLED_CUDA_ERROR", CommResult::CommUnhandledCudaError)
.value("COMM_SYSTEM_ERROR", CommResult::CommSystemError)
@@ -34,13 +36,13 @@ void register_algorithm(nb::module_& m) {
.value("COMM_IN_PROGRESS", CommResult::CommInProgress)
.value("COMM_NUM_RESULTS", CommResult::CommNumResults);
nb::enum_<ReduceOp>(m, "ReduceOp")
nb::enum_<ReduceOp>(m, "CppReduceOp")
.value("SUM", ReduceOp::SUM)
.value("MIN", ReduceOp::MIN)
.value("NOP", ReduceOp::NOP);
auto algorithmClass =
nb::class_<Algorithm>(m, "Algorithm")
nb::class_<Algorithm>(m, "CppAlgorithm")
.def_static(
"from_native_capsule",
[](nb::capsule cap) {
@@ -58,6 +60,12 @@ void register_algorithm(nb::module_& m) {
.def_prop_ro("name", &Algorithm::name)
.def_prop_ro("collective", &Algorithm::collective)
.def_prop_ro("message_range", &Algorithm::messageRange)
.def(
"set_message_size_range",
[](Algorithm& self, size_t minMessageSize, size_t maxMessageSize) {
self.setMessageSizeRange(minMessageSize, maxMessageSize);
},
nb::arg("min_message_size"), nb::arg("max_message_size"))
.def_prop_ro("tags", &Algorithm::tags)
.def_prop_ro("buffer_mode", &Algorithm::bufferMode)
.def_prop_ro("constraint", &Algorithm::constraint)
@@ -67,16 +75,19 @@ void register_algorithm(nb::module_& m) {
"execute",
[](Algorithm& self, std::shared_ptr<Communicator> comm, uintptr_t input, uintptr_t output,
size_t inputSize, size_t outputSize, DataType dtype, ReduceOp op, uintptr_t stream,
std::shared_ptr<Executor> executor, int nBlocks, int nThreadsPerBlock,
std::unordered_map<std::string, uintptr_t> extras) {
std::shared_ptr<Executor> executor, int nBlocks, int nThreadsPerBlock, bool symmetricMemory,
std::unordered_map<std::string, uintptr_t> extras, int32_t accumDtype) {
return self.execute(comm, reinterpret_cast<const void*>(input), reinterpret_cast<void*>(output),
inputSize, outputSize, dtype, op, reinterpret_cast<cudaStream_t>(stream), executor,
nBlocks, nThreadsPerBlock, extras);
nBlocks, nThreadsPerBlock, symmetricMemory, extras,
static_cast<DataType>(accumDtype));
},
nb::arg("comm"), nb::arg("input"), nb::arg("output"), nb::arg("input_size"), nb::arg("output_size"),
nb::arg("dtype"), nb::arg("op") = ReduceOp::NOP, nb::arg("stream") = 0, nb::arg("executor") = nullptr,
nb::arg("n_blocks") = 0, nb::arg("n_threads_per_block") = 0,
nb::arg("extras") = std::unordered_map<std::string, uintptr_t>());
nb::arg("n_blocks") = 0, nb::arg("n_threads_per_block") = 0, nb::arg("symmetric_memory") = false,
nb::arg("extras") = std::unordered_map<std::string, uintptr_t>(),
nb::arg("accum_dtype") = static_cast<int32_t>(DataType::AUTO))
.def("reset", &Algorithm::reset);
nb::class_<Algorithm::Constraint>(algorithmClass, "Constraint")
.def(nb::init<>())
@@ -84,21 +95,21 @@ void register_algorithm(nb::module_& m) {
.def_rw("world_size", &Algorithm::Constraint::worldSize)
.def_rw("n_ranks_per_node", &Algorithm::Constraint::nRanksPerNode);
nb::class_<AlgorithmBuilder>(m, "AlgorithmBuilder").def("build", &AlgorithmBuilder::build);
nb::class_<AlgorithmBuilder>(m, "CppAlgorithmBuilder").def("build", &AlgorithmBuilder::build);
nb::class_<DslAlgorithm, Algorithm>(m, "DslAlgorithm")
nb::class_<DslAlgorithm, Algorithm>(m, "CppDslAlgorithm")
.def(nb::init<std::string, ExecutionPlan, std::unordered_map<std::string, uint64_t>, Algorithm::Constraint>(),
nb::arg("id"), nb::arg("plan"), nb::arg("tags") = std::unordered_map<std::string, uint64_t>(),
nb::arg("constraint") = Algorithm::Constraint())
.def("build", &DslAlgorithm::build);
nb::class_<AlgorithmCollection>(m, "AlgorithmCollection")
nb::class_<AlgorithmCollection>(m, "CppAlgorithmCollection")
.def("register_algorithm", &AlgorithmCollection::registerAlgorithm, nb::arg("collective"), nb::arg("algo_name"),
nb::arg("algorithm"))
.def("get_algorithms_by_collective", &AlgorithmCollection::getAlgorithmsByCollective, nb::arg("collective"))
.def("to_list", &AlgorithmCollection::getAllAlgorithms);
nb::class_<CollectiveRequest>(m, "CollectiveRequest")
nb::class_<CollectiveRequest>(m, "CppCollectiveRequest")
.def_ro("world_size", &CollectiveRequest::worldSize)
.def_ro("n_ranks_per_node", &CollectiveRequest::nRanksPerNode)
.def_ro("rank", &CollectiveRequest::rank)
@@ -107,8 +118,22 @@ void register_algorithm(nb::module_& m) {
.def_prop_ro("output_buffer",
[](const CollectiveRequest& self) { return reinterpret_cast<uintptr_t>(self.outputBuffer); })
.def_ro("message_size", &CollectiveRequest::messageSize)
.def_prop_ro("stream", [](const CollectiveRequest& self) { return reinterpret_cast<uintptr_t>(self.stream); })
.def_prop_ro("collective", [](const CollectiveRequest& self) { return self.collective; })
.def_ro("dtype", &CollectiveRequest::dtype)
.def_prop_ro("hints", [](const CollectiveRequest& self) { return self.hints; })
.def("buffer_mode", &CollectiveRequest::bufferMode);
m.def(
"cpp_get_flag_buffer",
[]() {
auto [buffer, size] = getFlagBuffer();
uintptr_t ptr = reinterpret_cast<uintptr_t>(buffer.get());
// Transfer shared_ptr ownership into a capsule so Python's GC manages the lifetime.
auto prevent = std::make_unique<std::shared_ptr<void>>(std::move(buffer));
nb::capsule owner(prevent.get(), [](void* p) noexcept { delete static_cast<std::shared_ptr<void>*>(p); });
prevent.release(); // capsule now owns the pointer
return nb::make_tuple(ptr, size, owner);
},
"Get the default flag buffer. Returns a tuple of (buffer_ptr, buffer_size, owner).");
}

View File

@@ -32,21 +32,25 @@ extern void register_algorithm_collection_builder(nb::module_& m);
template <typename T>
void def_shared_future(nb::handle& m, const std::string& typestr) {
std::string pyclass_name = std::string("shared_future_") + typestr;
std::string pyclass_name = std::string("CppSharedFuture_") + typestr;
nb::class_<std::shared_future<T>>(m, pyclass_name.c_str()).def("get", &std::shared_future<T>::get);
}
void register_core(nb::module_& m) {
m.def("version", &version);
nb::enum_<DataType>(m, "DataType")
nb::enum_<DataType>(m, "CppDataType")
.value("int32", DataType::INT32)
.value("uint32", DataType::UINT32)
.value("float16", DataType::FLOAT16)
.value("float32", DataType::FLOAT32)
.value("bfloat16", DataType::BFLOAT16);
.value("bfloat16", DataType::BFLOAT16)
.value("float8_e4m3", DataType::FLOAT8_E4M3)
.value("float8_e5m2", DataType::FLOAT8_E5M2)
.value("uint8", DataType::UINT8)
.value("float8_e4m3b15", DataType::FLOAT8_E4M3B15);
nb::class_<Bootstrap>(m, "Bootstrap")
nb::class_<Bootstrap>(m, "CppBootstrap")
.def("get_rank", &Bootstrap::getRank)
.def("get_n_ranks", &Bootstrap::getNranks)
.def("get_n_ranks_per_node", &Bootstrap::getNranksPerNode)
@@ -71,7 +75,7 @@ void register_core(nb::module_& m) {
.def("recv", static_cast<void (Bootstrap::*)(std::vector<char>&, int, int)>(&Bootstrap::recv), nb::arg("data"),
nb::arg("peer"), nb::arg("tag"));
nb::class_<UniqueId>(m, "UniqueId")
nb::class_<UniqueId>(m, "CppUniqueId")
.def(nb::init<>())
.def("__setstate__",
[](UniqueId& self, nb::bytes b) {
@@ -81,7 +85,7 @@ void register_core(nb::module_& m) {
.def("__getstate__",
[](const UniqueId& self) { return nb::bytes(reinterpret_cast<const char*>(self.data()), UniqueIdBytes); });
nb::class_<TcpBootstrap, Bootstrap>(m, "TcpBootstrap")
nb::class_<TcpBootstrap, Bootstrap>(m, "CppTcpBootstrap")
.def(nb::init<int, int>(), "Do not use this constructor. Use create instead.")
.def_static(
"create", [](int rank, int nRanks) { return std::make_shared<TcpBootstrap>(rank, nRanks); }, nb::arg("rank"),
@@ -93,7 +97,7 @@ void register_core(nb::module_& m) {
.def("initialize", static_cast<void (TcpBootstrap::*)(const std::string&, int64_t)>(&TcpBootstrap::initialize),
nb::call_guard<nb::gil_scoped_release>(), nb::arg("if_ip_port_trio"), nb::arg("timeout_sec") = 30);
nb::enum_<Transport>(m, "Transport")
nb::enum_<Transport>(m, "CppTransport")
.value("Unknown", Transport::Unknown)
.value("CudaIpc", Transport::CudaIpc)
.value("IB0", Transport::IB0)
@@ -106,7 +110,7 @@ void register_core(nb::module_& m) {
.value("IB7", Transport::IB7)
.value("NumTransports", Transport::NumTransports);
nb::class_<TransportFlags>(m, "TransportFlags")
nb::class_<TransportFlags>(m, "CppTransportFlags")
.def(nb::init<>())
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
.def("has", &TransportFlags::has, nb::arg("transport"))
@@ -130,12 +134,12 @@ void register_core(nb::module_& m) {
.def(nb::self == nb::self)
.def(nb::self != nb::self);
nb::enum_<DeviceType>(m, "DeviceType")
nb::enum_<DeviceType>(m, "CppDeviceType")
.value("Unknown", DeviceType::Unknown)
.value("CPU", DeviceType::CPU)
.value("GPU", DeviceType::GPU);
nb::class_<Device>(m, "Device")
nb::class_<Device>(m, "CppDevice")
.def(nb::init<>())
.def(nb::init_implicit<DeviceType>(), nb::arg("type"))
.def(nb::init<DeviceType, int>(), nb::arg("type"), nb::arg("id") = -1)
@@ -147,24 +151,33 @@ void register_core(nb::module_& m) {
return ss.str();
});
nb::class_<EndpointConfig::Ib>(m, "EndpointConfigIb")
nb::enum_<EndpointConfig::Ib::Mode>(m, "CppIbMode")
.value("Default", EndpointConfig::Ib::Mode::Default)
.value("Host", EndpointConfig::Ib::Mode::Host)
.value("HostNoAtomic", EndpointConfig::Ib::Mode::HostNoAtomic);
nb::class_<EndpointConfig::Ib>(m, "CppEndpointConfigIb")
.def(nb::init<>())
.def(nb::init<int, int, int, int, int, int, int>(), nb::arg("device_index") = -1,
.def(nb::init<int, int, int, int, int, int, int, int, EndpointConfig::Ib::Mode>(), nb::arg("device_index") = -1,
nb::arg("port") = EndpointConfig::Ib::DefaultPort,
nb::arg("gid_index") = EndpointConfig::Ib::DefaultGidIndex,
nb::arg("max_cq_size") = EndpointConfig::Ib::DefaultMaxCqSize,
nb::arg("max_cq_poll_num") = EndpointConfig::Ib::DefaultMaxCqPollNum,
nb::arg("max_send_wr") = EndpointConfig::Ib::DefaultMaxSendWr,
nb::arg("max_wr_per_send") = EndpointConfig::Ib::DefaultMaxWrPerSend)
nb::arg("max_recv_wr") = EndpointConfig::Ib::DefaultMaxRecvWr,
nb::arg("max_wr_per_send") = EndpointConfig::Ib::DefaultMaxWrPerSend,
nb::arg("mode") = EndpointConfig::Ib::Mode::Default)
.def_rw("device_index", &EndpointConfig::Ib::deviceIndex)
.def_rw("port", &EndpointConfig::Ib::port)
.def_rw("gid_index", &EndpointConfig::Ib::gidIndex)
.def_rw("max_cq_size", &EndpointConfig::Ib::maxCqSize)
.def_rw("max_cq_poll_num", &EndpointConfig::Ib::maxCqPollNum)
.def_rw("max_send_wr", &EndpointConfig::Ib::maxSendWr)
.def_rw("max_wr_per_send", &EndpointConfig::Ib::maxWrPerSend);
.def_rw("max_recv_wr", &EndpointConfig::Ib::maxRecvWr)
.def_rw("max_wr_per_send", &EndpointConfig::Ib::maxWrPerSend)
.def_rw("mode", &EndpointConfig::Ib::mode);
nb::class_<RegisteredMemory>(m, "RegisteredMemory")
nb::class_<RegisteredMemory>(m, "CppRegisteredMemory")
.def(nb::init<>())
.def("data", [](RegisteredMemory& self) { return reinterpret_cast<uintptr_t>(self.data()); })
.def("size", &RegisteredMemory::size)
@@ -172,7 +185,7 @@ void register_core(nb::module_& m) {
.def("serialize", &RegisteredMemory::serialize)
.def_static("deserialize", &RegisteredMemory::deserialize, nb::arg("data"));
nb::class_<Endpoint>(m, "Endpoint")
nb::class_<Endpoint>(m, "CppEndpoint")
.def("config", &Endpoint::config)
.def("transport", &Endpoint::transport)
.def("device", &Endpoint::device)
@@ -180,7 +193,7 @@ void register_core(nb::module_& m) {
.def("serialize", &Endpoint::serialize)
.def_static("deserialize", &Endpoint::deserialize, nb::arg("data"));
nb::class_<Connection>(m, "Connection")
nb::class_<Connection>(m, "CppConnection")
.def("write", &Connection::write, nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("srcOffset"),
nb::arg("size"))
.def(
@@ -197,7 +210,7 @@ void register_core(nb::module_& m) {
.def("local_device", &Connection::localDevice)
.def("get_max_write_queue_size", &Connection::getMaxWriteQueueSize);
nb::class_<EndpointConfig>(m, "EndpointConfig")
nb::class_<EndpointConfig>(m, "CppEndpointConfig")
.def(nb::init<>())
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
.def(nb::init<Transport, Device, int, EndpointConfig::Ib>(), nb::arg("transport"), nb::arg("device"),
@@ -223,12 +236,18 @@ void register_core(nb::module_& m) {
.def_prop_rw(
"ib_max_send_wr", [](EndpointConfig& self) { return self.ib.maxSendWr; },
[](EndpointConfig& self, int v) { self.ib.maxSendWr = v; })
.def_prop_rw(
"ib_max_recv_wr", [](EndpointConfig& self) { return self.ib.maxRecvWr; },
[](EndpointConfig& self, int v) { self.ib.maxRecvWr = v; })
.def_prop_rw(
"ib_max_wr_per_send", [](EndpointConfig& self) { return self.ib.maxWrPerSend; },
[](EndpointConfig& self, int v) { self.ib.maxWrPerSend = v; })
.def_prop_rw(
"ib_mode", [](EndpointConfig& self) { return self.ib.mode; },
[](EndpointConfig& self, EndpointConfig::Ib::Mode v) { self.ib.mode = v; })
.def_rw("max_write_queue_size", &EndpointConfig::maxWriteQueueSize);
nb::class_<Context>(m, "Context")
nb::class_<Context>(m, "CppContext")
.def_static("create", &Context::create)
.def(
"register_memory",
@@ -239,13 +258,13 @@ void register_core(nb::module_& m) {
.def("create_endpoint", &Context::createEndpoint, nb::arg("config"))
.def("connect", &Context::connect, nb::arg("local_endpoint"), nb::arg("remote_endpoint"));
nb::class_<SemaphoreStub>(m, "SemaphoreStub")
nb::class_<SemaphoreStub>(m, "CppSemaphoreStub")
.def(nb::init<const Connection&>(), nb::arg("connection"))
.def("memory", &SemaphoreStub::memory)
.def("serialize", &SemaphoreStub::serialize)
.def_static("deserialize", &SemaphoreStub::deserialize, nb::arg("data"));
nb::class_<Semaphore>(m, "Semaphore")
nb::class_<Semaphore>(m, "CppSemaphore")
.def(nb::init<>())
.def(nb::init<const SemaphoreStub&, const SemaphoreStub&>(), nb::arg("local_stub"), nb::arg("remote_stub"))
.def("connection", &Semaphore::connection)
@@ -256,7 +275,7 @@ void register_core(nb::module_& m) {
def_shared_future<Connection>(m, "Connection");
def_shared_future<Semaphore>(m, "Semaphore");
nb::class_<Communicator>(m, "Communicator")
nb::class_<Communicator>(m, "CppCommunicator")
.def(nb::init<std::shared_ptr<Bootstrap>, std::shared_ptr<Context>>(), nb::arg("bootstrap"),
nb::arg("context") = nullptr)
.def("bootstrap", &Communicator::bootstrap)
@@ -289,6 +308,9 @@ void register_core(nb::module_& m) {
}
NB_MODULE(_mscclpp, m) {
#ifdef MSCCLPP_DISABLE_NB_LEAK_WARNINGS
nb::set_leak_warnings(false);
#endif
register_env(m);
register_error(m);
register_port_channel(m);
@@ -306,4 +328,4 @@ NB_MODULE(_mscclpp, m) {
// ext
register_algorithm_collection_builder(m);
}
}

View File

@@ -11,7 +11,7 @@ namespace nb = nanobind;
using namespace mscclpp;
void register_env(nb::module_& m) {
nb::class_<Env>(m, "Env")
nb::class_<Env>(m, "CppEnv")
.def_ro("debug", &Env::debug)
.def_ro("debug_subsys", &Env::debugSubsys)
.def_ro("debug_file", &Env::debugFile)
@@ -20,9 +20,11 @@ void register_env(nb::module_& m) {
.def_ro("socket_family", &Env::socketFamily)
.def_ro("socket_ifname", &Env::socketIfname)
.def_ro("comm_id", &Env::commId)
.def_ro("execution_plan_dir", &Env::executionPlanDir)
.def_ro("ibv_mode", &Env::ibvMode)
.def_ro("cache_dir", &Env::cacheDir)
.def_ro("npkit_dump_dir", &Env::npkitDumpDir)
.def_ro("cuda_ipc_use_default_stream", &Env::cudaIpcUseDefaultStream);
.def_ro("cuda_ipc_use_default_stream", &Env::cudaIpcUseDefaultStream)
.def_ro("ib_gid_index", &Env::ibGidIndex);
m.def("env", &env);
}

View File

@@ -11,18 +11,18 @@ using namespace mscclpp;
#define REGISTER_EXCEPTION_TRANSLATOR(name_) \
nb::register_exception_translator( \
[](const std::exception_ptr &p, void *payload) { \
[](const std::exception_ptr& p, void* payload) { \
try { \
std::rethrow_exception(p); \
} catch (const name_ &e) { \
PyErr_SetObject(reinterpret_cast<PyObject *>(payload), \
} catch (const name_& e) { \
PyErr_SetObject(reinterpret_cast<PyObject*>(payload), \
PyTuple_Pack(2, PyLong_FromLong(long(e.getErrorCode())), PyUnicode_FromString(e.what()))); \
} \
}, \
m.attr(#name_).ptr());
void register_error(nb::module_ &m) {
nb::enum_<ErrorCode>(m, "ErrorCode")
void register_error(nb::module_& m) {
nb::enum_<ErrorCode>(m, "CppErrorCode")
.value("SystemError", ErrorCode::SystemError)
.value("InternalError", ErrorCode::InternalError)
.value("RemoteError", ErrorCode::RemoteError)

View File

@@ -15,16 +15,16 @@ namespace nb = nanobind;
using namespace mscclpp;
void register_executor(nb::module_& m) {
nb::enum_<PacketType>(m, "PacketType").value("LL8", PacketType::LL8).value("LL16", PacketType::LL16);
nb::enum_<PacketType>(m, "CppPacketType").value("LL8", PacketType::LL8).value("LL16", PacketType::LL16);
nb::class_<ExecutionPlan>(m, "ExecutionPlan")
nb::class_<ExecutionPlan>(m, "CppExecutionPlan")
.def(nb::init<const std::string&, int>(), nb::arg("planPath"), nb::arg("rank"))
.def_prop_ro("name", [](const ExecutionPlan& self) -> std::string { return self.name(); })
.def_prop_ro("collective", [](const ExecutionPlan& self) -> std::string { return self.collective(); })
.def_prop_ro("min_message_size", [](const ExecutionPlan& self) -> size_t { return self.minMessageSize(); })
.def_prop_ro("max_message_size", [](const ExecutionPlan& self) -> size_t { return self.maxMessageSize(); });
nb::class_<Executor>(m, "Executor")
nb::class_<Executor>(m, "CppExecutor")
.def(nb::init<std::shared_ptr<Communicator>>(), nb::arg("comm"))
.def(
"execute",

View File

@@ -4,6 +4,7 @@
#include <nanobind/nanobind.h>
#include <nanobind/stl/function.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/unordered_map.h>
#include <nanobind/stl/vector.h>
@@ -15,7 +16,7 @@ using namespace mscclpp;
using namespace mscclpp::collective;
void register_algorithm_collection_builder(nb::module_& m) {
nb::class_<AlgorithmCollectionBuilder>(m, "AlgorithmCollectionBuilder")
nb::class_<AlgorithmCollectionBuilder>(m, "CppAlgorithmCollectionBuilder")
.def_static("get_instance", &AlgorithmCollectionBuilder::getInstance)
.def("add_algorithm_builder", &AlgorithmCollectionBuilder::addAlgorithmBuilder, nb::arg("builder"))
.def(
@@ -29,6 +30,6 @@ void register_algorithm_collection_builder(nb::module_& m) {
nb::arg("selector"))
.def("build", &AlgorithmCollectionBuilder::build)
.def("build_default_algorithms", &AlgorithmCollectionBuilder::buildDefaultAlgorithms, nb::arg("scratch_buffer"),
nb::arg("scratch_buffer_size"), nb::arg("rank"))
nb::arg("scratch_buffer_size"), nb::arg("flag_buffer"), nb::arg("flag_buffer_size"), nb::arg("rank"))
.def_static("reset", &AlgorithmCollectionBuilder::reset);
}

View File

@@ -9,7 +9,7 @@ namespace nb = nanobind;
using namespace mscclpp;
void register_fifo(nb::module_& m) {
nb::class_<ProxyTrigger>(m, "ProxyTrigger")
nb::class_<ProxyTrigger>(m, "CppProxyTrigger")
.def_prop_rw(
"fst", [](const ProxyTrigger& self) { return self.fst; },
[](ProxyTrigger& self, uint64_t v) { self.fst = v; })
@@ -17,7 +17,7 @@ void register_fifo(nb::module_& m) {
"snd", [](const ProxyTrigger& self) { return self.snd; },
[](ProxyTrigger& self, uint64_t v) { self.snd = v; });
nb::class_<FifoDeviceHandle>(m, "FifoDeviceHandle")
nb::class_<FifoDeviceHandle>(m, "CppFifoDeviceHandle")
.def_rw("triggers", &FifoDeviceHandle::triggers)
.def_rw("tail", &FifoDeviceHandle::tail)
.def_rw("head", &FifoDeviceHandle::head)
@@ -26,7 +26,7 @@ void register_fifo(nb::module_& m) {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
nb::class_<Fifo>(m, "Fifo")
nb::class_<Fifo>(m, "CppFifo")
.def(nb::init<int>(), nb::arg("size") = DEFAULT_FIFO_SIZE)
.def("poll", &Fifo::poll)
.def("pop", &Fifo::pop)

View File

@@ -34,6 +34,19 @@ static DLDataType getDlType(std::string type) {
return DLDataType{kDLBfloat, 16, 1};
} else if (type == "torch.float16") {
return DLDataType{kDLFloat, 16, 1};
} else if (type == "torch.float8_e4m3fn") {
return DLDataType{kDLFloat8_e4m3fn, 8, 1};
} else if (type == "torch.float8_e4m3fnuz") {
return DLDataType{kDLFloat8_e4m3fnuz, 8, 1};
} else if (type == "torch.float8_e5m2") {
return DLDataType{kDLFloat8_e5m2, 8, 1};
} else if (type == "torch.float8_e5m2fnuz") {
return DLDataType{kDLFloat8_e5m2fnuz, 8, 1};
} else if (type == "torch.uint8") {
return DLDataType{kDLUInt, 8, 1};
} else if (type == "fp8_e4m3b15") {
// No standard DLPack code for fp8_e4m3b15; store as raw uint8 bytes.
return DLDataType{kDLUInt, 8, 1};
} else {
throw Error("Unsupported type: " + type, ErrorCode::InvalidUsage);
}
@@ -101,7 +114,7 @@ static nb::capsule toDlpack(GpuBuffer<char> buffer, std::string dataType, std::v
void register_gpu_utils(nb::module_& m) {
m.def("is_nvls_supported", &isNvlsSupported);
nb::class_<GpuBuffer<char>>(m, "RawGpuBuffer")
nb::class_<GpuBuffer<char>>(m, "CppRawGpuBuffer")
.def(nb::init<size_t>(), nb::arg("nelems"))
.def("nelems", &GpuBuffer<char>::nelems)
.def("bytes", &GpuBuffer<char>::bytes)

View File

@@ -11,20 +11,20 @@ namespace nb = nanobind;
using namespace mscclpp;
void register_memory_channel(nb::module_& m) {
nb::class_<BaseMemoryChannel>(m, "BaseMemoryChannel")
nb::class_<BaseMemoryChannel>(m, "CppBaseMemoryChannel")
.def(nb::init<>())
.def(nb::init<std::shared_ptr<MemoryDevice2DeviceSemaphore>>(), nb::arg("semaphore"))
.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
.def("device_handle", &BaseMemoryChannel::deviceHandle);
nb::class_<BaseMemoryChannel::DeviceHandle>(m, "BaseMemoryChannelDeviceHandle")
nb::class_<BaseMemoryChannel::DeviceHandle>(m, "CppBaseMemoryChannelDeviceHandle")
.def(nb::init<>())
.def_rw("semaphore_", &BaseMemoryChannel::DeviceHandle::semaphore_)
.def_prop_ro("raw", [](const BaseMemoryChannel::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
nb::class_<MemoryChannel>(m, "MemoryChannel")
nb::class_<MemoryChannel>(m, "CppMemoryChannel")
.def(nb::init<>())
.def(
"__init__",
@@ -42,7 +42,7 @@ void register_memory_channel(nb::module_& m) {
nb::arg("semaphore"), nb::arg("dst"), nb::arg("src"), nb::arg("packet_buffer") = 0)
.def("device_handle", &MemoryChannel::deviceHandle);
nb::class_<MemoryChannel::DeviceHandle>(m, "MemoryChannelDeviceHandle")
nb::class_<MemoryChannel::DeviceHandle>(m, "CppMemoryChannelDeviceHandle")
.def(nb::init<>())
.def_rw("semaphore_", &MemoryChannel::DeviceHandle::semaphore_)
.def_rw("dst_", &MemoryChannel::DeviceHandle::dst_)

View File

@@ -8,8 +8,8 @@
namespace nb = nanobind;
void register_npkit(nb::module_ &m) {
nb::module_ sub_m = m.def_submodule("npkit", "NPKit functions");
void register_npkit(nb::module_& m) {
nb::module_ sub_m = m.def_submodule("cpp_npkit", "NPKit functions");
sub_m.def("init", &NpKit::Init);
sub_m.def("dump", &NpKit::Dump);
sub_m.def("shutdown", &NpKit::Shutdown);

View File

@@ -6,8 +6,8 @@ int getDeviceNumaNode(int cudaDev);
void numaBind(int node);
}; // namespace mscclpp
void register_numa(nb::module_ &m) {
nb::module_ sub_m = m.def_submodule("numa", "numa functions");
void register_numa(nb::module_& m) {
nb::module_ sub_m = m.def_submodule("cpp_numa", "numa functions");
sub_m.def("get_device_numa_node", &mscclpp::getDeviceNumaNode);
sub_m.def("numa_bind", &mscclpp::numaBind);
}

View File

@@ -11,11 +11,11 @@ namespace nb = nanobind;
using namespace mscclpp;
void register_port_channel(nb::module_& m) {
nb::class_<BaseProxyService>(m, "BaseProxyService")
nb::class_<BaseProxyService>(m, "CppBaseProxyService")
.def("start_proxy", &BaseProxyService::startProxy, nb::arg("blocking") = false)
.def("stop_proxy", &BaseProxyService::stopProxy);
nb::class_<ProxyService, BaseProxyService>(m, "ProxyService")
nb::class_<ProxyService, BaseProxyService>(m, "CppProxyService")
.def(nb::init<int>(), nb::arg("fifo_size") = DEFAULT_FIFO_SIZE)
.def("start_proxy", &ProxyService::startProxy, nb::arg("blocking") = false)
.def("stop_proxy", &ProxyService::stopProxy)
@@ -31,13 +31,13 @@ void register_port_channel(nb::module_& m) {
.def("base_port_channel", &ProxyService::basePortChannel, nb::arg("id"))
.def("port_channel", &ProxyService::portChannel, nb::arg("id"), nb::arg("dst"), nb::arg("src"));
nb::class_<BasePortChannel>(m, "BasePortChannel")
nb::class_<BasePortChannel>(m, "CppBasePortChannel")
.def(nb::init<>())
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>>(),
nb::arg("semaphore_id"), nb::arg("semaphore"), nb::arg("proxy"))
.def("device_handle", &BasePortChannel::deviceHandle);
nb::class_<BasePortChannel::DeviceHandle>(m, "BasePortChannelDeviceHandle")
nb::class_<BasePortChannel::DeviceHandle>(m, "CppBasePortChannelDeviceHandle")
.def(nb::init<>())
.def_rw("semaphore_id_", &BasePortChannel::DeviceHandle::semaphoreId_)
.def_rw("semaphore_", &BasePortChannel::DeviceHandle::semaphore_)
@@ -46,13 +46,13 @@ void register_port_channel(nb::module_& m) {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
nb::class_<PortChannel>(m, "PortChannel")
nb::class_<PortChannel>(m, "CppPortChannel")
.def(nb::init<>())
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>, MemoryId, MemoryId>(),
nb::arg("semaphore_id"), nb::arg("semaphore"), nb::arg("proxy"), nb::arg("dst"), nb::arg("src"))
.def("device_handle", &PortChannel::deviceHandle);
nb::class_<PortChannel::DeviceHandle>(m, "PortChannelDeviceHandle")
nb::class_<PortChannel::DeviceHandle>(m, "CppPortChannelDeviceHandle")
.def(nb::init<>())
.def_rw("semaphore_id_", &PortChannel::DeviceHandle::semaphoreId_)
.def_rw("semaphore_", &PortChannel::DeviceHandle::semaphore_)

View File

@@ -10,7 +10,7 @@ namespace nb = nanobind;
using namespace mscclpp;
void register_semaphore(nb::module_& m) {
nb::class_<Host2DeviceSemaphore> host2DeviceSemaphore(m, "Host2DeviceSemaphore");
nb::class_<Host2DeviceSemaphore> host2DeviceSemaphore(m, "CppHost2DeviceSemaphore");
host2DeviceSemaphore.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
.def(nb::init<Communicator&, const Connection&>(), nb::arg("communicator"), nb::arg("connection"))
.def("connection", &Host2DeviceSemaphore::connection)
@@ -25,7 +25,7 @@ void register_semaphore(nb::module_& m) {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
nb::class_<Host2HostSemaphore>(m, "Host2HostSemaphore")
nb::class_<Host2HostSemaphore>(m, "CppHost2HostSemaphore")
.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
.def(nb::init<Communicator&, const Connection&>(), nb::arg("communicator"), nb::arg("connection"))
.def("connection", &Host2HostSemaphore::connection)
@@ -34,7 +34,7 @@ void register_semaphore(nb::module_& m) {
.def("wait", &Host2HostSemaphore::wait, nb::call_guard<nb::gil_scoped_release>(),
nb::arg("max_spin_count") = 10000000);
nb::class_<MemoryDevice2DeviceSemaphore> memoryDevice2DeviceSemaphore(m, "MemoryDevice2DeviceSemaphore");
nb::class_<MemoryDevice2DeviceSemaphore> memoryDevice2DeviceSemaphore(m, "CppMemoryDevice2DeviceSemaphore");
memoryDevice2DeviceSemaphore.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
.def(nb::init<Communicator&, const Connection&>(), nb::arg("communicator"), nb::arg("connection"))
.def("connection", &MemoryDevice2DeviceSemaphore::connection)
@@ -43,7 +43,6 @@ void register_semaphore(nb::module_& m) {
nb::class_<MemoryDevice2DeviceSemaphore::DeviceHandle>(memoryDevice2DeviceSemaphore, "DeviceHandle")
.def(nb::init<>())
.def_rw("inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::inboundToken)
.def_rw("outbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::outboundToken)
.def_rw("remote_inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::remoteInboundToken)
.def_rw("expected_inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::expectedInboundToken)
.def_prop_ro("raw", [](const MemoryDevice2DeviceSemaphore::DeviceHandle& self) -> nb::bytes {

View File

@@ -15,11 +15,11 @@ namespace nb = nanobind;
using namespace mscclpp;
void register_nvls(nb::module_& m) {
nb::class_<SwitchChannel>(m, "SwitchChannel")
nb::class_<SwitchChannel>(m, "CppSwitchChannel")
.def("get_device_ptr", [](SwitchChannel* self) { return (uintptr_t)self->getDevicePtr(); })
.def("device_handle", &SwitchChannel::deviceHandle);
nb::class_<SwitchChannel::DeviceHandle>(m, "DeviceHandle")
nb::class_<SwitchChannel::DeviceHandle>(m, "CppSwitchChannelDeviceHandle")
.def(nb::init<>())
.def_rw("device_ptr", &SwitchChannel::DeviceHandle::devicePtr)
.def_rw("mc_ptr", &SwitchChannel::DeviceHandle::mcPtr)
@@ -28,7 +28,7 @@ void register_nvls(nb::module_& m) {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
nb::class_<NvlsConnection>(m, "NvlsConnection")
nb::class_<NvlsConnection>(m, "CppNvlsConnection")
.def("bind_allocated_memory", &NvlsConnection::bindAllocatedMemory, nb::arg("device_ptr"), nb::arg("size"));
m.def("connect_nvls_collective", &connectNvlsCollective, nb::arg("communicator"), nb::arg("all_ranks"),

View File

@@ -23,35 +23,37 @@ version = {
from ._core import *
from ._mscclpp import (
Device,
DeviceType,
Communicator,
Connection,
CppDevice as Device,
CppDeviceType as DeviceType,
CppCommunicator as Communicator,
CppConnection as Connection,
connect_nvls_collective,
EndpointConfig,
Fifo,
Semaphore,
Host2DeviceSemaphore,
Host2HostSemaphore,
numa,
ProxyService,
RegisteredMemory,
PortChannel,
MemoryChannel,
MemoryDevice2DeviceSemaphore,
TcpBootstrap,
Transport,
TransportFlags,
DataType,
ErrorCode,
Executor,
ExecutionPlan,
PacketType,
RawGpuBuffer,
ReduceOp,
CppEndpointConfig as EndpointConfig,
CppEndpointConfigIb as EndpointConfigIb,
CppIbMode as IbMode,
CppFifo as Fifo,
CppSemaphore as Semaphore,
CppHost2DeviceSemaphore as Host2DeviceSemaphore,
CppHost2HostSemaphore as Host2HostSemaphore,
cpp_numa as numa,
CppProxyService as ProxyService,
CppRegisteredMemory as RegisteredMemory,
CppPortChannel as PortChannel,
CppMemoryChannel as MemoryChannel,
CppMemoryDevice2DeviceSemaphore as MemoryDevice2DeviceSemaphore,
CppTcpBootstrap as TcpBootstrap,
CppTransport as Transport,
CppTransportFlags as TransportFlags,
CppDataType as DataType,
CppErrorCode as ErrorCode,
CppExecutor as Executor,
CppExecutionPlan as ExecutionPlan,
CppPacketType as PacketType,
CppRawGpuBuffer as RawGpuBuffer,
CppReduceOp as ReduceOp,
env,
is_nvls_supported,
npkit,
cpp_npkit as npkit,
)
__all__ = [
@@ -61,6 +63,8 @@ __all__ = [
"Connection",
"connect_nvls_collective",
"EndpointConfig",
"EndpointConfigIb",
"IbMode",
"ErrorCode",
"Fifo",
"Semaphore",

View File

@@ -6,7 +6,7 @@ import shutil
import argparse
from pathlib import Path
from mscclpp.language import default_algos as def_algo
from mscclpp import default_algos as def_algo
from mscclpp.language.collectives import *
from mscclpp.language.utils import AlgoSpec
@@ -57,7 +57,7 @@ default_algo_configs = [
def create_default_plans():
plan_dir = os.environ.get("MSCCLPP_EXECUTION_PLAN_DIR", Path.home() / ".cache/mscclpp_default")
plan_dir = os.path.join(os.environ.get("MSCCLPP_CACHE_DIR", Path.home() / ".cache/mscclpp"), "default")
plan_path = Path(plan_dir)
if plan_path.exists():
shutil.rmtree(plan_path)

View File

@@ -5,9 +5,3 @@ from .algorithm import *
from .comm import *
from .compiler import *
from .buffer import *
__all__ = []
__all__ += algorithm.__all__
__all__ += comm.__all__
__all__ += compiler.__all__
__all__ += buffer.__all__

View File

@@ -4,18 +4,22 @@
from __future__ import annotations
from typing import Optional, Tuple, Dict
from functools import cached_property
import cupy as cp
from mscclpp._mscclpp import (
Algorithm as _Algorithm,
DslAlgorithm as _DslAlgorithm,
AlgorithmType as _AlgorithmType,
Communicator,
CollectiveBufferMode,
DataType,
Executor,
ExecutionPlan,
ReduceOp,
CppAlgorithm,
CppDslAlgorithm,
CppAlgorithmType,
CppCommunicator,
CppCollectiveBufferMode,
CppDataType,
CppExecutor,
CppExecutionPlan,
CppReduceOp,
CppAlgorithmBuilder,
CppAlgorithmCollection,
cpp_get_flag_buffer,
)
__all__ = ["Algorithm", "AlgorithmBuilder", "AlgorithmCollection"]
@@ -45,7 +49,7 @@ class Algorithm:
"""
def __init__(self, world_size: int = 0, n_ranks_per_node: int = 0):
self._constraint = _Algorithm.Constraint(world_size, n_ranks_per_node)
self._constraint = CppAlgorithm.Constraint(world_size, n_ranks_per_node)
@property
def world_size(self) -> int:
@@ -58,23 +62,23 @@ class Algorithm:
def __init__(
self,
id: Optional[str] = None,
execution_plan: Optional[ExecutionPlan] = None,
native_handle: Optional[_Algorithm] = None,
execution_plan: Optional[CppExecutionPlan] = None,
native_handle: Optional[CppAlgorithm] = None,
tags: Optional[Dict[str, int]] = None,
constraint: Optional[Constraint] = None,
):
if execution_plan is not None:
self._algorithm = _DslAlgorithm(
self._algorithm = CppDslAlgorithm(
id,
execution_plan,
tags=tags if tags is not None else {},
constraint=constraint._constraint if constraint is not None else _Algorithm.Constraint(),
constraint=constraint._constraint if constraint is not None else CppAlgorithm.Constraint(),
)
elif native_handle is not None:
self._algorithm = native_handle
@classmethod
def create_from_native_handle(cls, handle: _Algorithm):
def create_from_native_handle(cls, handle: CppAlgorithm):
"""Create an Algorithm instance from a native C++ algorithm handle.
Args:
@@ -97,7 +101,7 @@ class Algorithm:
Returns:
A new Algorithm instance wrapping the algorithm from the capsule.
"""
handle = _Algorithm.from_native_capsule(obj)
handle = CppAlgorithm.from_native_capsule(obj)
return cls(native_handle=handle)
@cached_property
@@ -110,18 +114,31 @@ class Algorithm:
"""The collective operation this algorithm implements (e.g., "allreduce", "allgather")."""
return self._algorithm.collective
@cached_property
@property
def message_size_range(self) -> Tuple[int, int]:
"""The valid message size range (min_size, max_size) in bytes."""
return (self._algorithm.message_range[0], self._algorithm.message_range[1])
def set_message_size_range(self, min_message_size: int, max_message_size: int):
"""Set the valid message size range in bytes.
Args:
min_message_size: Minimum supported message size in bytes.
max_message_size: Maximum supported message size in bytes.
Only supported for native algorithms. Raises TypeError for DSL algorithms.
"""
if self.is_dsl_algorithm():
raise TypeError("set_message_size_range is only supported for native algorithms")
self._algorithm.set_message_size_range(min_message_size, max_message_size)
@cached_property
def tags(self) -> Dict[str, int]:
"""Dictionary of tag names to tag values for algorithm selection hints."""
return self._algorithm.tags
@cached_property
def buffer_mode(self) -> CollectiveBufferMode:
def buffer_mode(self) -> CppCollectiveBufferMode:
"""The buffer mode supported by this algorithm (IN_PLACE, OUT_OF_PLACE, or ANY)."""
return self._algorithm.buffer_mode
@@ -131,7 +148,7 @@ class Algorithm:
Returns:
True if this algorithm is defined using DSL/execution plan, False otherwise.
"""
if self._algorithm.type == _AlgorithmType.DSL:
if self._algorithm.type == CppAlgorithmType.DSL:
return True
return False
@@ -141,24 +158,26 @@ class Algorithm:
Returns:
True if this algorithm is implemented natively, False otherwise.
"""
if self._algorithm.type == _AlgorithmType.NATIVE:
if self._algorithm.type == CppAlgorithmType.NATIVE:
return True
return False
def execute(
self,
comm: Communicator,
comm: CppCommunicator,
input_buffer: int,
output_buffer: int,
input_size: int,
output_size: int,
dtype: DataType,
op: ReduceOp = ReduceOp.NOP,
dtype: CppDataType,
op: CppReduceOp = CppReduceOp.NOP,
stream: int = 0,
executor: Optional[Executor] = None,
executor: Optional[CppExecutor] = None,
nblocks=0,
nthreads_per_block=0,
symmetric_memory: bool = False,
extras: Optional[Dict[str, int]] = None,
accum_dtype: Optional[CppDataType] = None,
) -> int:
"""Execute the collective algorithm.
@@ -174,11 +193,16 @@ class Algorithm:
executor: The executor for DSL algorithms (required for DSL, optional for native).
nblocks: Number of CUDA blocks (0 for auto-selection).
nthreads_per_block: Number of threads per block (0 for auto-selection).
symmetric_memory: Whether to use symmetric memory optimization (default: False).
extras: Additional algorithm-specific parameters.
accum_dtype: Data type for accumulation during reduction. If None, defaults to
the same as dtype. Use DataType.float32 for high-precision FP8 accumulation.
Returns:
The result code (0 for success).
"""
merged_extras = dict(extras) if extras is not None else {}
accum_dtype = accum_dtype if accum_dtype is not None else dtype
return self._algorithm.execute(
comm,
int(input_buffer),
@@ -191,12 +215,18 @@ class Algorithm:
executor,
nblocks,
nthreads_per_block,
extras if extras is not None else {},
symmetric_memory,
merged_extras,
int(accum_dtype),
)
def reset(self):
"""Reset the internal state of the algorithm, if applicable."""
self._algorithm.reset()
class AlgorithmBuilder:
def __init__(self, algorithm_builder: _AlgorithmBuilder):
def __init__(self, algorithm_builder: CppAlgorithmBuilder):
self._algorithm_builder = algorithm_builder
def build(self) -> Algorithm:
@@ -204,7 +234,7 @@ class AlgorithmBuilder:
class AlgorithmCollection:
def __init__(self, native_collection: _AlgorithmCollection):
def __init__(self, native_collection: CppAlgorithmCollection):
self._native_collection = native_collection
self._algorithms = [Algorithm.create_from_native_handle(algo) for algo in self._native_collection.to_list()]
@@ -228,3 +258,24 @@ class AlgorithmCollection:
"""Register an algorithm for a collective operation."""
self._native_collection.register_algorithm(collective, algo_name, algorithm._algorithm)
self._algorithms.append(algorithm)
_flag_buffer_cache = None
def get_flag_buffer() -> cp.ndarray:
"""Get the default flag buffer for algorithm selection.
This buffer is used internally by default algorithms to store selection flags.
It is allocated as a shared GPU buffer and can be accessed from Python.
The result is cached so all callers share the same buffer.
Returns:
A CuPy array representing the flag buffer on the GPU.
"""
global _flag_buffer_cache
if _flag_buffer_cache is None:
buffer_ptr, buffer_size, owner = cpp_get_flag_buffer()
memptr = cp.cuda.MemoryPointer(cp.cuda.UnownedMemory(buffer_ptr, buffer_size, owner), 0)
_flag_buffer_cache = cp.ndarray((buffer_size // 4,), dtype=cp.uint32, memptr=memptr)
return _flag_buffer_cache

View File

@@ -6,7 +6,7 @@ from typing import Union, Tuple
import cupy as cp
import numpy as np
from mscclpp._mscclpp import RawGpuBuffer
from mscclpp._mscclpp import CppRawGpuBuffer
__all__ = ["GpuBuffer"]
@@ -25,6 +25,6 @@ class GpuBuffer(cp.ndarray):
if any(s <= 0 for s in shape):
raise ValueError("Shape must be positive.")
# Create the buffer
buffer = RawGpuBuffer(np.prod(shape) * np.dtype(dtype).itemsize)
buffer = CppRawGpuBuffer(np.prod(shape) * np.dtype(dtype).itemsize)
memptr = cp.cuda.MemoryPointer(cp.cuda.UnownedMemory(buffer.data(), buffer.bytes(), buffer), 0)
return cp.ndarray(shape, dtype=dtype, strides=strides, order=order, memptr=memptr)

View File

@@ -6,21 +6,21 @@ from typing import Type
import cupy as cp
from mscclpp._mscclpp import (
Communicator,
Connection,
CppCommunicator,
CppConnection,
connect_nvls_collective,
EndpointConfig,
Semaphore,
ProxyService,
RegisteredMemory,
PortChannel,
MemoryChannel,
TcpBootstrap,
Transport,
TransportFlags,
CppEndpointConfig,
CppSemaphore,
CppProxyService,
CppRegisteredMemory,
CppPortChannel,
CppMemoryChannel,
CppTcpBootstrap,
CppTransport,
CppTransportFlags,
)
import mpi4py
import numpy as np
import pickle
from mscclpp.utils import is_torch_tensor
@@ -29,27 +29,47 @@ __all__ = ["CommGroup"]
class CommGroup:
def __init__(
self, mpi_comm: mpi4py.MPI.Comm = None, interfaceIpPortTrio: str = "", rank: int = None, size: int = None
self,
mpi_comm: "mpi4py.MPI.Comm" = None,
torch_group: "dist.ProcessGroup" = None,
interfaceIpPortTrio: str = "",
rank: int = None,
size: int = None,
):
if interfaceIpPortTrio == "":
self.bootstrap = TcpBootstrap.create(mpi_comm.rank, mpi_comm.size)
if interfaceIpPortTrio == "" and (mpi_comm is not None or torch_group is not None):
uniq_id = None
if mpi_comm.rank == 0:
# similar to NCCL's unique id
rank, size = (
(mpi_comm.Get_rank(), mpi_comm.Get_size())
if mpi_comm is not None
else (torch_group.rank(), torch_group.size())
)
self.bootstrap = CppTcpBootstrap.create(rank, size)
if rank == 0:
uniq_id = self.bootstrap.create_unique_id()
uniq_id_global = mpi_comm.bcast(uniq_id, 0)
if mpi_comm is not None:
import mpi4py
uniq_id_global = mpi_comm.bcast(uniq_id, 0)
else:
import torch
import torch.distributed as dist
if rank == 0:
uniq_id_global = uniq_id
pickled_data = pickle.dumps(uniq_id)
data_tensor = torch.frombuffer(bytearray(pickled_data), dtype=torch.uint8).clone()
else:
data_tensor = torch.zeros(256, dtype=torch.uint8)
dist.broadcast(data_tensor, src=0, group=torch_group)
uniq_id_global = pickle.loads(data_tensor.numpy().tobytes())
self.bootstrap.initialize(uniq_id_global)
elif mpi_comm:
# use this instead
self.bootstrap = TcpBootstrap.create(mpi_comm.rank, mpi_comm.size)
self.bootstrap.initialize(interfaceIpPortTrio)
elif not interfaceIpPortTrio == "":
assert rank >= 0 and size >= 1
self.bootstrap = TcpBootstrap.create(rank, size)
self.bootstrap = CppTcpBootstrap.create(rank, size)
self.bootstrap.initialize(interfaceIpPortTrio)
else:
raise RuntimeError("Either the interface or mpi_group need to be specified")
self.communicator = Communicator(self.bootstrap)
self.communicator = CppCommunicator(self.bootstrap)
self.my_rank = self.bootstrap.get_rank()
self.nranks = self.bootstrap.get_n_ranks()
self.nranks_per_node = self.bootstrap.get_n_ranks_per_node()
@@ -63,43 +83,43 @@ class CommGroup:
def recv(self, tensor: np.ndarray, peer: int, tag: int):
self.bootstrap.recv(tensor.ctypes.data, tensor.size * tensor.itemsize, peer, tag)
def my_ib_device(self, local_rank: int) -> Transport:
def my_ib_device(self, local_rank: int) -> CppTransport:
if local_rank == 0:
return Transport.IB0
return CppTransport.IB0
if local_rank == 1:
return Transport.IB1
return CppTransport.IB1
if local_rank == 2:
return Transport.IB2
return CppTransport.IB2
if local_rank == 3:
return Transport.IB3
return CppTransport.IB3
if local_rank == 4:
return Transport.IB4
return CppTransport.IB4
if local_rank == 5:
return Transport.IB5
return CppTransport.IB5
if local_rank == 6:
return Transport.IB6
return CppTransport.IB6
if local_rank == 7:
return Transport.IB7
return CppTransport.IB7
else:
assert False # only 8 IBs are supported
def make_connection(
self,
all_ranks: list[int],
endpoints: EndpointConfig | Transport | dict[int, EndpointConfig] | dict[int, Transport],
endpoints: CppEndpointConfig | CppTransport | dict[int, CppEndpointConfig] | dict[int, CppTransport],
use_switch: bool = False,
) -> dict[int, Connection]:
if type(endpoints) is Transport:
endpoints = EndpointConfig(endpoints)
) -> dict[int, CppConnection]:
if type(endpoints) is CppTransport:
endpoints = CppEndpointConfig(endpoints)
elif type(endpoints) is dict:
endpoints = {k: EndpointConfig(v) if type(v) is Transport else v for k, v in endpoints.items()}
endpoints = {k: CppEndpointConfig(v) if type(v) is CppTransport else v for k, v in endpoints.items()}
connections = {}
for rank in all_ranks:
if type(endpoints) is dict:
endpoint = endpoints[rank]
else:
endpoint = endpoints
if endpoint.transport == Transport.CudaIpc and use_switch:
if endpoint.transport == CppTransport.CudaIpc and use_switch:
return connect_nvls_collective(self.communicator, all_ranks, 2**30)
else:
connections[rank] = self.communicator.connect(endpoint, rank)
@@ -107,8 +127,8 @@ class CommGroup:
return connections
def register_tensor_with_connections(
self, tensor: Type[cp.ndarray] | Type[np.ndarray], connections: dict[int, Connection]
) -> dict[int, RegisteredMemory]:
self, tensor: Type[cp.ndarray] | Type[np.ndarray], connections: dict[int, CppConnection]
) -> dict[int, CppRegisteredMemory]:
local_reg_memory = self.register_local_memory(tensor, connections)
all_registered_memories = {}
all_registered_memories[self.my_rank] = local_reg_memory
@@ -121,8 +141,8 @@ class CommGroup:
return all_registered_memories
def _register_memory_with_connections(
self, memory: RegisteredMemory, connections: dict[int, Connection]
) -> dict[int, RegisteredMemory]:
self, memory: CppRegisteredMemory, connections: dict[int, CppConnection]
) -> dict[int, CppRegisteredMemory]:
all_registered_memories = {}
all_registered_memories[self.my_rank] = memory
future_memories = {}
@@ -133,18 +153,20 @@ class CommGroup:
all_registered_memories[rank] = future_memories[rank].get()
return all_registered_memories
def make_semaphores(self, connections: dict[int, Connection]) -> dict[int, Semaphore]:
def make_semaphores(self, connections: dict[int, CppConnection]) -> dict[int, CppSemaphore]:
future_semaphores = {}
for rank in connections:
future_semaphores[rank] = self.communicator.build_semaphore(connections[rank], rank)
return {rank: future.get() for rank, future in future_semaphores.items()}
def make_memory_channels(self, tensor: cp.ndarray, connections: dict[int, Connection]) -> dict[int, MemoryChannel]:
def make_memory_channels(
self, tensor: cp.ndarray, connections: dict[int, CppConnection]
) -> dict[int, CppMemoryChannel]:
semaphores = self.make_semaphores(connections)
registered_memories = self.register_tensor_with_connections(tensor, connections)
channels = {}
for rank in connections:
channels[rank] = MemoryChannel(
channels[rank] = CppMemoryChannel(
semaphores[rank], registered_memories[rank], registered_memories[self.my_rank]
)
return channels
@@ -152,9 +174,9 @@ class CommGroup:
def make_memory_channels_with_scratch(
self,
tensor: cp.ndarray,
registeredScratchBuffer: RegisteredMemory,
connections: dict[int, Connection],
) -> dict[int, MemoryChannel]:
registeredScratchBuffer: CppRegisteredMemory,
connections: dict[int, CppConnection],
) -> dict[int, CppMemoryChannel]:
semaphores = self.make_semaphores(connections)
registered_memories = self._register_memory_with_connections(registeredScratchBuffer, connections)
channels = {}
@@ -162,17 +184,17 @@ class CommGroup:
tensor_size = (
tensor.numel() * tensor.element_size() if is_torch_tensor(tensor) else tensor.size * tensor.itemsize
)
local_registered_memory = self.communicator.register_memory(tensor_data_ptr, tensor_size, TransportFlags())
local_registered_memory = self.communicator.register_memory(tensor_data_ptr, tensor_size, CppTransportFlags())
scratch_data_ptr = registeredScratchBuffer.data()
for rank in connections:
channels[rank] = MemoryChannel(
channels[rank] = CppMemoryChannel(
semaphores[rank], registered_memories[rank], local_registered_memory, scratch_data_ptr
)
return channels
def make_port_channels(
self, proxy_service: ProxyService, tensor: cp.ndarray, connections: dict[int, Connection]
) -> dict[int, PortChannel]:
self, proxy_service: CppProxyService, tensor: cp.ndarray, connections: dict[int, CppConnection]
) -> dict[int, CppPortChannel]:
semaphores = self.make_semaphores(connections)
registered_memories = self.register_tensor_with_connections(tensor, connections)
memory_ids = {}
@@ -188,12 +210,12 @@ class CommGroup:
def make_port_channels_with_scratch(
self,
proxy_service: ProxyService,
proxy_service: CppProxyService,
tensor: cp.ndarray,
registeredScratchBuffer: RegisteredMemory,
connections: dict[int, Connection],
) -> dict[int, PortChannel]:
transport_flags = TransportFlags()
registeredScratchBuffer: CppRegisteredMemory,
connections: dict[int, CppConnection],
) -> dict[int, CppPortChannel]:
transport_flags = CppTransportFlags()
for rank in connections:
transport_flags |= connections[rank].transport()
data_ptr = (
@@ -223,8 +245,8 @@ class CommGroup:
return channels
def register_semaphore_with_proxy(
self, proxy_service: ProxyService, connections: dict[int, Connection]
) -> dict[int, PortChannel]:
self, proxy_service: CppProxyService, connections: dict[int, CppConnection]
) -> dict[int, CppPortChannel]:
semaphores = self.make_semaphores(connections)
semaphore_ids = {}
for rank in semaphores:
@@ -235,7 +257,7 @@ class CommGroup:
return channels
def register_memory_with_proxy(
self, proxy_service: ProxyService, tensor: cp.ndarray, connections: dict[int, Connection]
self, proxy_service: CppProxyService, tensor: cp.ndarray, connections: dict[int, CppConnection]
) -> dict[int, int]:
registered_memories = self.register_tensor_with_connections(tensor, connections)
memory_ids = {}
@@ -243,8 +265,8 @@ class CommGroup:
memory_ids[rank] = proxy_service.add_memory(registered_memories[rank])
return memory_ids
def register_local_memory(self, tensor: cp.ndarray, connections: dict[int, Connection]) -> RegisteredMemory:
transport_flags = TransportFlags()
def register_local_memory(self, tensor: cp.ndarray, connections: dict[int, CppConnection]) -> CppRegisteredMemory:
transport_flags = CppTransportFlags()
for rank in connections:
transport_flags |= connections[rank].transport()
data_ptr = (

View File

@@ -26,9 +26,7 @@ from mscclpp.language.program import CollectiveProgram
from mscclpp.language.utils import AlgoSpec
from mscclpp.utils import get_device_arch
from mscclpp._mscclpp import (
ExecutionPlan,
)
from mscclpp._mscclpp import CppExecutionPlan, env
logging.basicConfig(level=logging.INFO)
@@ -51,7 +49,7 @@ class DslCompiler:
into execution plans that can be run on GPUs. The compiled plans are cached
to disk for reuse.
The cache location can be configured via the `MSCCLPP_EXECUTION_PLAN_DIR`
The cache location can be configured via the `MSCCLPP_CACHE_DIR`
environment variable (defaults to `~/.cache/mscclpp`).
Example:
@@ -138,7 +136,7 @@ class DslCompiler:
)
).hexdigest()
plan_dir = os.environ.get("MSCCLPP_EXECUTION_PLAN_DIR", Path.home() / ".cache/mscclpp")
plan_dir = Path(env().cache_dir)
os.makedirs(plan_dir, exist_ok=True)
filename = f"{plan_id}.json"
plan_path = os.path.join(plan_dir, filename)
@@ -157,7 +155,7 @@ class DslCompiler:
os.remove(tmp_path)
except Exception:
Path(plan_path).unlink(missing_ok=True)
execution_plan = ExecutionPlan(plan_path, rank)
execution_plan = CppExecutionPlan(plan_path, rank)
return Algorithm(
id=plan_id,
execution_plan=execution_plan,
@@ -179,8 +177,8 @@ class NativeCodeCompiler:
based on the runtime environment. Compiled modules are cached to avoid
recompilation.
The cache location can be configured via the `MSCCLPP_NATIVE_CACHE_DIR`
environment variable (defaults to `~/.cache/mscclpp/native`).
The cache location can be configured via the `MSCCLPP_CACHE_DIR`
environment variable (defaults to `~/.cache/mscclpp`).
Attributes:
_is_hip: True if running on AMD/ROCm, False for NVIDIA/CUDA.
@@ -226,8 +224,7 @@ class NativeCodeCompiler:
"-L" + os.path.join(self._lib_home, "lib"),
"-lmscclpp",
]
cache_root = os.environ.get("MSCCLPP_NATIVE_CACHE_DIR", Path.home() / ".cache/mscclpp/native")
self._cache_dir = Path(cache_root)
self._cache_dir = Path(env().cache_dir) / "native"
self._cache_dir.mkdir(parents=True, exist_ok=True)
def _get_compiler(self) -> str:
@@ -283,7 +280,7 @@ class NativeCodeCompiler:
Note:
- The source file should include pybind11 bindings to expose functions.
- MSCCLPP headers are automatically included in the compilation.
- The module is cached in `MSCCLPP_NATIVE_CACHE_DIR` (default: ~/.cache/mscclpp/native).
- The module is cached in `MSCCLPP_CACHE_DIR` (default: ~/.cache/mscclpp).
- File locking is used to prevent race conditions during parallel compilation.
Example:

View File

@@ -3,12 +3,10 @@
from __future__ import annotations
from typing import Union
from mscclpp._core.algorithm import Algorithm, AlgorithmBuilder, AlgorithmCollection
from mscclpp._core.algorithm import Algorithm, AlgorithmBuilder, AlgorithmCollection, get_flag_buffer
import atexit
from mscclpp._mscclpp import (
AlgorithmCollectionBuilder as _AlgorithmCollectionBuilder,
)
from mscclpp._mscclpp import CppAlgorithmCollectionBuilder
__all__ = ["AlgorithmCollectionBuilder"]
@@ -24,13 +22,14 @@ class AlgorithmCollectionBuilder:
@classmethod
def reset(cls):
if cls._instance is not None:
_AlgorithmCollectionBuilder.reset()
CppAlgorithmCollectionBuilder.reset()
cls._instance = None
def __init__(self):
if not hasattr(self, "_initialized"):
self._builder = _AlgorithmCollectionBuilder.get_instance()
self._builder = CppAlgorithmCollectionBuilder.get_instance()
self._initialized = True
self._flag_buffer = None
def add_algorithm_builder(self, algorithm_builder: Union[AlgorithmBuilder, Algorithm]):
if isinstance(algorithm_builder, AlgorithmBuilder):
@@ -52,8 +51,17 @@ class AlgorithmCollectionBuilder:
collection = self._builder.build()
return AlgorithmCollection(collection)
def build_default_algorithms(self, scratch_buffer: int, scratch_buffer_size: int, rank: int) -> AlgorithmCollection:
native_collection = self._builder.build_default_algorithms(int(scratch_buffer), scratch_buffer_size, rank)
def build_default_algorithms(
self,
scratch_buffer: int,
scratch_buffer_size: int,
rank: int,
) -> AlgorithmCollection:
if self._flag_buffer is None:
self._flag_buffer = get_flag_buffer()
native_collection = self._builder.build_default_algorithms(
int(scratch_buffer), scratch_buffer_size, self._flag_buffer.data.ptr, self._flag_buffer.nbytes, rank
)
return AlgorithmCollection(native_collection)

View File

@@ -140,7 +140,7 @@ class MemoryChannel:
for tb_id in tb_list:
tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type)
tb_channel_ids = get_program().setup_channel(tb, self)
tb_channel_ids = get_program().setup_channel(tb_id, self)
op = GetOperation(
src_buff=[RemoteChunk(src_chunk.buffer, src_chunk.index, src_chunk.size, tb_chunk_id)],
dst_buff=[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)],

View File

@@ -534,6 +534,7 @@ class PutOperation(BaseOperation):
self.dst_buff = dst_buff
self.channel_ids = channel_ids
self.channel_type = channel_type
self.from_packet = from_packet
self.to_packet = to_packet
self.with_signal = with_signal
self.with_signal_and_flush = with_signal_and_flush
@@ -579,6 +580,25 @@ class PutOperation(BaseOperation):
with_signal=self.with_signal,
with_signal_and_flush=self.with_signal_and_flush,
)
elif (
isinstance(other, PutOperation)
and self.name == Instruction.read_put_packet
and self.name == other.name
and self.src_buff == other.src_buff
and self.channel_type == other.channel_type
and self.tbg_info == other.tbg_info
):
fused_operation = PutOperation(
src_buff=self.src_buff,
dst_buff=self.dst_buff + other.dst_buff,
channel_ids=self.channel_ids + other.channel_ids,
channel_type=self.channel_type,
tbg_info=self.tbg_info,
from_packet=self.from_packet,
to_packet=self.to_packet,
with_signal=self.with_signal,
with_signal_and_flush=self.with_signal_and_flush,
)
return fused_operation
@@ -725,7 +745,7 @@ class ReduceOperation(BaseOperation):
remote_dst_buff=self.remote_dst_buff + other.dst_buff,
channel_ids=self.channel_ids,
put_channel_ids=self.put_channel_ids + other.channel_ids,
channel_type=self.channel_type,
channel_type=other.channel_type,
reduce_operation=self.reduce_operation,
tbg_info=self.tbg_info,
packet=self.packet,

View File

@@ -0,0 +1,78 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
from mscclpp.language.channel import *
from mscclpp.language.rank import *
from mscclpp.language.general import *
from mscclpp.language.program import *
from mscclpp.language.collectives import *
def allgather_example(name, gpu_size, num_threads_per_block, min_message_size, max_message_size):
chunksperloop = 1
collective = AllGather(gpu_size, chunksperloop, True)
with CollectiveProgram(
name,
collective,
gpu_size,
protocol="LL",
num_threads_per_block=num_threads_per_block,
use_double_scratch_buffer=True,
min_message_size=min_message_size,
max_message_size=max_message_size,
):
# Creating Scratch Buffers
scratch_buffer = []
for gpu in range(gpu_size):
scratch_buffer.append(Buffer(gpu, 2 * gpu_size))
# Copying it to scratch buffer
for gpu in range(gpu_size):
rank = Rank(gpu)
scratch_offset = gpu_size
input_buffer = rank.get_input_buffer()
rank.copy_packets(
scratch_buffer[gpu][scratch_offset + gpu : scratch_offset + gpu + 1], input_buffer[0:1], tb=0
)
# Putting packets in the remote scratch buffer
for gpu in range(gpu_size):
rank = Rank(gpu)
output_buffer = rank.get_output_buffer()
for peer in range(1, gpu_size):
dst_rank = (gpu + peer) % gpu_size
ch = MemoryChannel(dst_rank, gpu)
tb = 0
ch.read_put_packets(
scratch_buffer[dst_rank][gpu : gpu + 1],
scratch_buffer[gpu][scratch_offset + gpu : scratch_offset + gpu + 1],
tb,
)
# Copying packets from local scratch buffer to local buffer
for gpu in range(gpu_size):
rank = Rank(gpu)
output_buffer = rank.get_output_buffer()
for peer in range(1, gpu_size):
dst_rank = (gpu + peer) % gpu_size
rank.unpack_packets(
output_buffer[dst_rank : dst_rank + 1],
scratch_buffer[gpu][dst_rank : dst_rank + 1],
tb=0,
)
print(JSON())
parser = argparse.ArgumentParser()
parser.add_argument("--name", type=str, help="name of the program")
parser.add_argument("--num_gpus", type=int, help="number of gpus")
parser.add_argument("--num_threads_per_block", type=int, default=1024, help="number of threads per block")
parser.add_argument("--min_message_size", type=int, default=0, help="minimum message size")
parser.add_argument("--max_message_size", type=int, default=2**64 - 1, help="maximum message size")
args = parser.parse_args()
allgather_example(args.name, args.num_gpus, args.num_threads_per_block, args.min_message_size, args.max_message_size)

View File

@@ -11,7 +11,7 @@ from typing import Any, Type, Union
import cupy as cp
import numpy as np
from mscclpp._mscclpp import DataType
from mscclpp._mscclpp import CppDataType as DataType
try:
import torch
@@ -192,5 +192,13 @@ def torch_dtype_to_mscclpp_dtype(dtype: "torch.dtype") -> DataType:
return DataType.int32
elif dtype == torch.bfloat16:
return DataType.bfloat16
# Hardware supports either OCP format or FNUZ format for float8.
# Mapping both to the same MSCClPP data type.
elif dtype == torch.float8_e5m2 or dtype == torch.float8_e5m2fnuz:
return DataType.float8_e5m2
elif dtype == torch.float8_e4m3fn or dtype == torch.float8_e4m3fnuz:
return DataType.float8_e4m3
elif dtype == torch.uint8:
return DataType.uint8
else:
raise ValueError(f"Unknown data type: {dtype}")

View File

@@ -6,4 +6,5 @@ pytest
numpy
matplotlib
sortedcontainers @ git+https://github.com/grantjenks/python-sortedcontainers.git@3ac358631f58c1347f1d6d2d92784117db0f38ed
blake3
blake3
pybind11

View File

@@ -0,0 +1,10 @@
mpi4py
cupy
prettytable
netifaces
pytest
numpy
matplotlib
sortedcontainers @ git+https://github.com/grantjenks/python-sortedcontainers.git@3ac358631f58c1347f1d6d2d92784117db0f38ed
blake3
pybind11

View File

@@ -63,10 +63,13 @@ class MyProxyService {
};
NB_MODULE(_ext, m) {
#ifdef MSCCLPP_DISABLE_NB_LEAK_WARNINGS
nb::set_leak_warnings(false);
#endif
nb::class_<MyProxyService>(m, "MyProxyService")
.def(nb::init<int, int, int, nb::list, nb::list>(), nb::arg("rank"), nb::arg("nranks"), nb::arg("data_size"),
nb::arg("reg_mem_list"), nb::arg("sem_list"))
.def("fifo_device_handle", &MyProxyService::fifoDeviceHandle)
.def("start", &MyProxyService::start)
.def("stop", &MyProxyService::stop);
}
}

View File

@@ -11,7 +11,7 @@ from mscclpp import (
env,
)
from mscclpp import CommGroup, GpuBuffer
from mscclpp.utils import KernelBuilder, GpuBuffer, pack
from mscclpp.utils import KernelBuilder, pack
import os
import struct

View File

@@ -0,0 +1,397 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Correctness test for FP8 allreduce with different accumulation types.
#
# Verifies that FP8 allreduce with higher-precision accumulation produces
# results at least as accurate as native FP8 accumulation, by comparing
# against a float32 reference.
#
# Usage:
# mpirun -np 8 pytest python/test/test_fp8_accum.py -v
import cupy as cp
import numpy as np
import pytest
from mscclpp import CommGroup, GpuBuffer, DataType, ReduceOp, is_nvls_supported
from mscclpp.ext import AlgorithmCollectionBuilder
from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group
# FP8 E4M3 (hardware) requires SM >= 89 (Ada / Hopper) on NVIDIA GPUs.
# On AMD/ROCm (e.g. MI300X), FP8 is supported natively — no skip needed.
_is_hip = hasattr(cp.cuda.runtime, "is_hip") and cp.cuda.runtime.is_hip
_skip_fp8 = not _is_hip and int(cp.cuda.Device().compute_capability) < 89
pytestmark = pytest.mark.skipif(_skip_fp8, reason="FP8 accum tests require SM >= 89 on CUDA")
# ---------------------------------------------------------------------------
# FP8 E4M3FN helpers (bias=7, no infinity, NaN = exp=15 & mant=7)
# ---------------------------------------------------------------------------
def e4m3fn_to_float(uint8_array):
"""Decode a cupy uint8 array of E4M3FN bit patterns to float32."""
bits = uint8_array.astype(cp.int32)
sign = (bits >> 7) & 1
exp = (bits >> 3) & 0xF
mant = bits & 0x7
# Normal: (-1)^s * 2^(exp-7) * (1 + mant/8)
normal_val = cp.ldexp(cp.float32(1.0) + mant.astype(cp.float32) / cp.float32(8.0), (exp - 7).astype(cp.int32))
# Subnormal (exp==0): (-1)^s * 2^(-6) * (mant/8)
subnormal_val = cp.ldexp(mant.astype(cp.float32) / cp.float32(8.0), cp.int32(-6))
result = cp.where(exp == 0, subnormal_val, normal_val)
result = cp.where(sign == 1, -result, result)
# Zero
result = cp.where((exp == 0) & (mant == 0), cp.float32(0.0), result)
# NaN: exp==15 & mant==7
nan_mask = (exp == 15) & (mant == 7)
result = cp.where(nan_mask, cp.float32(float("nan")), result)
return result
def float_to_e4m3fn(f32_array, chunk_size=65536):
"""Encode a cupy float32 array to uint8 E4M3FN bit patterns.
Uses a lookup-table approach: precompute all 128 positive E4M3FN values,
then find nearest match per element via chunked broadcast comparison.
"""
# Build lookup table of all 128 positive E4M3FN values (0x00..0x7F)
all_bytes = cp.arange(128, dtype=cp.uint8)
all_floats = e4m3fn_to_float(all_bytes) # (128,) float32
# Mark NaN entries as inf so they're never selected as nearest
all_floats = cp.where(cp.isnan(all_floats), cp.float32(float("inf")), all_floats)
# Clamp input and extract sign
clamped = f32_array.astype(cp.float32)
clamped = cp.clip(clamped, -448.0, 448.0)
signs = (clamped < 0).astype(cp.uint8)
absval = cp.abs(clamped)
result = cp.zeros(absval.shape, dtype=cp.uint8)
n = absval.size
absval_flat = absval.ravel()
result_flat = result.ravel()
for start in range(0, n, chunk_size):
end = min(start + chunk_size, n)
chunk = absval_flat[start:end]
# (chunk_size, 128) difference matrix
diffs = cp.abs(chunk[:, None] - all_floats[None, :])
result_flat[start:end] = cp.argmin(diffs, axis=1).astype(cp.uint8)
# Combine with sign bit
result = result_flat.reshape(absval.shape)
result = result | (signs << 7)
# Handle exact zero
result = cp.where(absval == 0, cp.uint8(0), result)
return result
# ---------------------------------------------------------------------------
# FP8 E4M3B15 helpers (bias=15, max=0.9375, NaN = exp==15 or bits==0x80)
# ---------------------------------------------------------------------------
def e4m3b15_to_float(uint8_array):
    """Decode a cupy uint8 array of E4M3B15 bit patterns to float32.

    Layout: 1 sign bit, 4 exponent bits (bias 15), 3 mantissa bits.
    exp == 15 patterns and the 0x80 byte (negative zero) decode to NaN.
    """
    raw = uint8_array.astype(cp.int32)
    sign_bit = (raw >> 7) & 1
    exponent = (raw >> 3) & 0xF
    mantissa = raw & 0x7
    frac = mantissa.astype(cp.float32) / cp.float32(8.0)
    # Normal numbers: (-1)^s * 2^(exp-15) * (1 + mant/8)
    normals = cp.ldexp(cp.float32(1.0) + frac, (exponent - 15).astype(cp.int32))
    # Subnormals (exp == 0): (-1)^s * 2^(-14) * (mant/8)
    subnormals = cp.ldexp(frac, cp.int32(-14))
    decoded = cp.where(exponent == 0, subnormals, normals)
    decoded = cp.where(sign_bit == 1, -decoded, decoded)
    # exp == 0 and mant == 0 encodes exact zero.
    decoded = cp.where((exponent == 0) & (mantissa == 0), cp.float32(0.0), decoded)
    # NaN: any exp == 15 pattern, plus the negative-zero byte 0x80.
    is_nan = (exponent == 15) | (raw == 0x80)
    return cp.where(is_nan, cp.float32(float("nan")), decoded)
def float_to_e4m3b15(f32_array, chunk_size=65536):
    """Encode a cupy float32 array to uint8 E4M3B15 bit patterns.

    Same lookup-table approach as float_to_e4m3fn.
    """
    # Table of all 128 positive codes; NaN entries become +inf so they lose
    # every nearest-match comparison.
    lut = e4m3b15_to_float(cp.arange(128, dtype=cp.uint8))
    lut = cp.where(cp.isnan(lut), cp.float32(float("inf")), lut)
    # Saturate to the representable range [-0.9375, 0.9375]; keep sign aside.
    values = cp.clip(f32_array.astype(cp.float32), -0.9375, 0.9375)
    sign_bits = (values < 0).astype(cp.uint8)
    magnitudes = cp.abs(values)
    codes = cp.zeros(magnitudes.shape, dtype=cp.uint8)
    flat_mag = magnitudes.ravel()
    flat_codes = codes.ravel()
    total = magnitudes.size
    start = 0
    while start < total:
        stop = min(start + chunk_size, total)
        # (chunk, 128) distance matrix: nearest positive code per element.
        dist = cp.abs(flat_mag[start:stop, None] - lut[None, :])
        flat_codes[start:stop] = cp.argmin(dist, axis=1).astype(cp.uint8)
        start = stop
    # Re-attach the sign bit; exact zeros always map to the +0 byte.
    codes = flat_codes.reshape(magnitudes.shape)
    codes = codes | (sign_bits << 7)
    return cp.where(magnitudes == 0, cp.uint8(0), codes)
# ---------------------------------------------------------------------------
# Shared test helpers
# ---------------------------------------------------------------------------
def setup_algorithms(mpi_group):
    """Build the default algorithm collection for this process.

    Returns a (comm_group, algo_map, scratch) tuple: the communicator group,
    a mapping of algorithm name -> algorithm object, and the 128 MB GPU
    scratch buffer the algorithms were built against (kept alive by callers).
    """
    group = CommGroup(mpi_group.comm)
    scratch_buf = GpuBuffer(1 << 27, dtype=cp.uint8)  # 128 MB
    # Reset builder state before constructing a fresh collection
    # (NOTE(review): assumes reset() clears prior registrations — confirm).
    AlgorithmCollectionBuilder.reset()
    algos = AlgorithmCollectionBuilder().build_default_algorithms(
        scratch_buffer=scratch_buf.data.ptr,
        scratch_buffer_size=scratch_buf.nbytes,
        rank=group.my_rank,
    )
    by_name = {algo.name: algo for algo in algos}
    return group, by_name, scratch_buf
def run_allreduce(algo, comm_group, buffer, dtype, accum_dtype=None, nblocks=0, nthreads_per_block=0):
    """Run allreduce in-place on buffer and return a copy of the result.

    The same device pointer is passed as both input and output, so the
    reduction overwrites `buffer`; the returned array is a snapshot taken
    after a full device synchronize.
    """
    status = algo.execute(
        comm=comm_group.communicator,
        input_buffer=buffer.data.ptr,
        output_buffer=buffer.data.ptr,  # in-place: output aliases input
        input_size=buffer.nbytes,
        output_size=buffer.nbytes,
        dtype=dtype,
        op=ReduceOp.SUM,
        stream=cp.cuda.get_current_stream().ptr,
        nblocks=nblocks,
        nthreads_per_block=nthreads_per_block,
        symmetric_memory=True,
        accum_dtype=accum_dtype,
    )
    # Block until the kernel completes before inspecting the buffer.
    cp.cuda.Device().synchronize()
    assert status == 0, f"Allreduce failed with error code {status}"
    return buffer.copy()
# ---------------------------------------------------------------------------
# Test: FP8 E4M3 accumulation correctness
# ---------------------------------------------------------------------------
@parametrize_mpi_groups(8)
@pytest.mark.parametrize(
    "algo_name",
    [
        "default_allreduce_packet",
        "default_allreduce_nvls_packet",
        "default_allreduce_fullmesh",
        "default_allreduce_rsag_zero_copy",
        "default_allreduce_allpair_packet",
    ],
)
@pytest.mark.parametrize("size", [1024, 4096, 16384, 65536, 262144, 1048576])
def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
    """Verify that FP8 E4M3 allreduce with higher-precision accumulation is at
    least as accurate as native FP8 accumulation, across all algorithm variants."""
    rank = mpi_group.comm.rank
    world_size = mpi_group.comm.size
    comm_group, algo_map, scratch = setup_algorithms(mpi_group)
    if algo_name not in algo_map:
        pytest.skip(f"{algo_name} not available")
    if "nvls" in algo_name and not is_nvls_supported():
        pytest.skip(f"{algo_name} requires NVLS which is not supported on this platform")
    algo = algo_map[algo_name]
    buf = GpuBuffer(size, dtype=cp.uint8)
    # Accumulation dtypes under comparison: native fp8 vs. fp16/fp32 upcast.
    accum_configs = [
        ("fp8_native", DataType.float8_e4m3),
        ("float16", DataType.float16),
        ("float32", DataType.float32),
    ]
    # rsag_zero_copy and fullmesh need explicit block/thread counts
    if "rsag" in algo_name:
        nb = max(1, min(32, size // (world_size * 32)))
        nt = 1024
    elif "fullmesh" in algo_name:
        nb = 35
        nt = 512
    else:
        # 0 lets the algorithm choose its own launch configuration.
        nb = 0
        nt = 0
    errors = {}  # accum label -> mean absolute error vs. float32 reference
    for accum_label, accum_dtype in accum_configs:
        # Generate deterministic per-rank data (use numpy to avoid hipRAND issues on ROCm)
        rng = np.random.RandomState(42 + rank)
        src_f32 = cp.asarray(rng.randn(size).astype(np.float32))
        # Clip individual inputs to ±240, well inside E4M3's finite range.
        src_f32 = cp.clip(src_f32, -240.0, 240.0)
        src_fp8 = float_to_e4m3fn(src_f32)
        # Copy into symmetric buffer
        buf[:] = src_fp8
        cp.cuda.Device().synchronize()
        # Run allreduce
        result = run_allreduce(
            algo,
            comm_group,
            buf,
            dtype=DataType.float8_e4m3,
            accum_dtype=accum_dtype,
            nblocks=nb,
            nthreads_per_block=nt,
        )
        result_f32 = e4m3fn_to_float(result)
        # Compute float32 reference: sum all ranks' quantized FP8 inputs in float32.
        # Each rank's input is re-derived locally from its seed (42 + r), so no
        # extra communication is needed to build the reference.
        ref_f32 = cp.zeros(size, dtype=cp.float32)
        for r in range(world_size):
            rng_r = np.random.RandomState(42 + r)
            rank_data = cp.asarray(rng_r.randn(size).astype(np.float32))
            rank_data = cp.clip(rank_data, -240.0, 240.0)
            rank_data_fp8 = float_to_e4m3fn(rank_data)
            ref_f32 += e4m3fn_to_float(rank_data_fp8)
        # Compute errors
        abs_err = cp.abs(result_f32 - ref_f32)
        mean_abs_err = float(cp.mean(abs_err))
        errors[accum_label] = mean_abs_err
        # Reset between runs
        algo.reset()
    # Higher-precision accumulation should be at least as accurate as native fp8
    assert (
        errors["float16"] <= errors["fp8_native"] + 1e-6
    ), f"float16 accum ({errors['float16']:.6f}) worse than native ({errors['fp8_native']:.6f})"
    assert (
        errors["float32"] <= errors["fp8_native"] + 1e-6
    ), f"float32 accum ({errors['float32']:.6f}) worse than native ({errors['fp8_native']:.6f})"
# ---------------------------------------------------------------------------
# Test: FP8 E4M3B15 accumulation correctness
# ---------------------------------------------------------------------------
@parametrize_mpi_groups(8)
@pytest.mark.parametrize(
    "algo_name",
    [
        "default_allreduce_packet",
        "default_allreduce_nvls_packet",
        "default_allreduce_rsag_zero_copy",
        "default_allreduce_fullmesh",
        "default_allreduce_allpair_packet",
    ],
)
@pytest.mark.parametrize("size", [1024, 4096, 65536])
def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int):
    """Verify that FP8 E4M3B15 allreduce with higher-precision accumulation is at
    least as accurate as native E4M3B15 accumulation."""
    rank = mpi_group.comm.rank
    world_size = mpi_group.comm.size
    comm_group, algo_map, scratch = setup_algorithms(mpi_group)
    if algo_name not in algo_map:
        pytest.skip(f"{algo_name} not available")
    if "nvls" in algo_name and not is_nvls_supported():
        pytest.skip(f"{algo_name} requires NVLS which is not supported on this platform")
    algo = algo_map[algo_name]
    buf = GpuBuffer(size, dtype=cp.uint8)
    # Accumulation dtypes under comparison: native e4m3b15 vs. fp16/fp32.
    accum_configs = [
        ("e4m3b15_native", DataType.float8_e4m3b15),
        ("float16", DataType.float16),
        ("float32", DataType.float32),
    ]
    # rsag_zero_copy needs explicit block/thread counts, scaled to data size
    if "rsag" in algo_name:
        nb = max(1, min(32, size // (world_size * 32)))
        nt = 1024
    else:
        # 0 lets the algorithm choose its own launch configuration.
        nb = 0
        nt = 0
    errors = {}  # accum label -> mean absolute error vs. clamped fp32 reference
    for accum_label, accum_dtype in accum_configs:
        # Generate deterministic per-rank random uint8 values in valid e4m3b15 range
        # (magnitude bytes 0x00..0x77 keep exp < 15, avoiding NaN encodings).
        rng = np.random.RandomState(42 + rank)
        raw = cp.asarray(rng.randint(0, 0x78, (size,)).astype(np.uint8))
        signs = cp.asarray(rng.randint(0, 2, (size,)).astype(np.uint8)) << 7
        src_uint8 = raw | signs
        # Fix negative zero -> positive zero (0x80 decodes as NaN in e4m3b15)
        src_uint8 = cp.where(src_uint8 == 0x80, cp.uint8(0), src_uint8)
        # Copy into symmetric buffer
        buf[:] = src_uint8
        cp.cuda.Device().synchronize()
        # Run allreduce
        result = run_allreduce(
            algo,
            comm_group,
            buf,
            dtype=DataType.float8_e4m3b15,
            accum_dtype=accum_dtype,
            nblocks=nb,
            nthreads_per_block=nt,
        )
        # Decode result
        result_f32 = e4m3b15_to_float(result)
        # Compute float32 reference by re-deriving each rank's input from its
        # seed (42 + r); no extra communication needed.
        ref_f32 = cp.zeros(size, dtype=cp.float32)
        for r in range(world_size):
            rng_r = np.random.RandomState(42 + r)
            raw_r = cp.asarray(rng_r.randint(0, 0x78, (size,)).astype(np.uint8))
            signs_r = cp.asarray(rng_r.randint(0, 2, (size,)).astype(np.uint8)) << 7
            bits_r = raw_r | signs_r
            bits_r = cp.where(bits_r == 0x80, cp.uint8(0), bits_r)
            ref_f32 += e4m3b15_to_float(bits_r)
        # Clamp reference to e4m3b15 representable range
        ref_f32 = cp.clip(ref_f32, -0.9375, 0.9375)
        # Compute errors (only on valid entries)
        valid = ~cp.isnan(result_f32) & ~cp.isnan(ref_f32)
        abs_err = cp.abs(result_f32[valid] - ref_f32[valid])
        mean_abs_err = float(cp.mean(abs_err)) if abs_err.size > 0 else 0.0
        errors[accum_label] = mean_abs_err
        # Reset algorithm state between accumulation configurations.
        algo.reset()
    # Higher-precision accumulation should be at least as accurate as native
    assert (
        errors["float16"] <= errors["e4m3b15_native"] + 1e-8
    ), f"float16 accum ({errors['float16']:.8f}) worse than native ({errors['e4m3b15_native']:.8f})"
    assert (
        errors["float32"] <= errors["e4m3b15_native"] + 1e-8
    ), f"float32 accum ({errors['float32']:.8f}) worse than native ({errors['e4m3b15_native']:.8f})"

View File

@@ -162,13 +162,10 @@ def create_connection(group: CommGroup, connection_type: str):
def create_group_and_connection(mpi_group: MpiGroup, connection_type: str):
if (connection_type == "NVLink" or connection_type == "NVLS") and all_ranks_on_the_same_node(mpi_group) is False:
pytest.skip("cannot use nvlink/nvls for cross node")
if connection_type == "IB" and os.environ.get("MSCCLPP_DISABLE_IB_TESTS", "0") != "0":
pytest.skip("IB tests are disabled via MSCCLPP_DISABLE_IB_TESTS=1")
group = CommGroup(mpi_group.comm)
try:
connection = create_connection(group, connection_type)
except Error as e:
if connection_type == "IB" and e.args[0] == ErrorCode.InvalidUsage:
pytest.skip("IB not supported on this node")
raise
connection = create_connection(group, connection_type)
return group, connection
@@ -281,6 +278,8 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, connection_type: str,
@parametrize_mpi_groups(2, 4, 8, 16)
def test_h2h_semaphores(mpi_group: MpiGroup):
if os.environ.get("MSCCLPP_DISABLE_IB_TESTS", "0") != "0":
pytest.skip("IB tests are disabled via MSCCLPP_DISABLE_IB_TESTS=1")
group = CommGroup(mpi_group.comm)
tran = group.my_ib_device(group.my_rank % 8)
endpoint = EndpointConfig(tran, Device(DeviceType.CPU))
@@ -301,6 +300,8 @@ def test_h2h_semaphores(mpi_group: MpiGroup):
@parametrize_mpi_groups(2, 4, 8, 16)
def test_h2h_semaphores_gil_release(mpi_group: MpiGroup):
if os.environ.get("MSCCLPP_DISABLE_IB_TESTS", "0") != "0":
pytest.skip("IB tests are disabled via MSCCLPP_DISABLE_IB_TESTS=1")
group = CommGroup(mpi_group.comm)
tran = group.my_ib_device(group.my_rank % 8)
endpoint = EndpointConfig(tran, Device(DeviceType.CPU))