mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 01:10:22 +00:00
Merge main branch
This commit is contained in:
@@ -4,6 +4,10 @@
|
||||
add_subdirectory(csrc)
|
||||
add_subdirectory(test)
|
||||
|
||||
target_compile_definitions(mscclpp_py PRIVATE
|
||||
$<$<BOOL:${MSCCLPP_DISABLE_NB_LEAK_WARNINGS}>:MSCCLPP_DISABLE_NB_LEAK_WARNINGS>
|
||||
)
|
||||
|
||||
add_custom_target(pytest_lib_copy ALL
|
||||
COMMAND ${CMAKE_COMMAND} -E copy_if_different
|
||||
${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/_mscclpp.*.so
|
||||
@@ -12,4 +16,4 @@ add_custom_target(pytest_lib_copy ALL
|
||||
${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/_ext.*.so
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/test/_cpp
|
||||
DEPENDS mscclpp_py mscclpp_py_test
|
||||
)
|
||||
)
|
||||
@@ -24,4 +24,7 @@ set_target_properties(mscclpp_py PROPERTIES OUTPUT_NAME _mscclpp)
|
||||
set_target_properties(mscclpp_py PROPERTIES INSTALL_RPATH "\$ORIGIN/lib")
|
||||
target_link_libraries(mscclpp_py PRIVATE dlpack mscclpp mscclpp_collectives ${GPU_LIBRARIES})
|
||||
target_include_directories(mscclpp_py SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
|
||||
if(MSCCLPP_USE_ROCM)
|
||||
target_compile_definitions(mscclpp_py PRIVATE MSCCLPP_USE_ROCM)
|
||||
endif()
|
||||
install(TARGETS mscclpp_py LIBRARY DESTINATION .)
|
||||
|
||||
@@ -16,14 +16,16 @@ namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_algorithm(nb::module_& m) {
|
||||
nb::enum_<CollectiveBufferMode>(m, "CollectiveBufferMode")
|
||||
nb::enum_<CollectiveBufferMode>(m, "CppCollectiveBufferMode")
|
||||
.value("ANY", CollectiveBufferMode::Any)
|
||||
.value("IN_PLACE", CollectiveBufferMode::InPlace)
|
||||
.value("OUT_OF_PLACE", CollectiveBufferMode::OutOfPlace);
|
||||
|
||||
nb::enum_<AlgorithmType>(m, "AlgorithmType").value("NATIVE", AlgorithmType::Native).value("DSL", AlgorithmType::DSL);
|
||||
nb::enum_<AlgorithmType>(m, "CppAlgorithmType")
|
||||
.value("NATIVE", AlgorithmType::Native)
|
||||
.value("DSL", AlgorithmType::DSL);
|
||||
|
||||
nb::enum_<CommResult>(m, "CommResult")
|
||||
nb::enum_<CommResult>(m, "CppCommResult")
|
||||
.value("COMM_SUCCESS", CommResult::CommSuccess)
|
||||
.value("COMM_UNHANDLED_CUDA_ERROR", CommResult::CommUnhandledCudaError)
|
||||
.value("COMM_SYSTEM_ERROR", CommResult::CommSystemError)
|
||||
@@ -34,13 +36,13 @@ void register_algorithm(nb::module_& m) {
|
||||
.value("COMM_IN_PROGRESS", CommResult::CommInProgress)
|
||||
.value("COMM_NUM_RESULTS", CommResult::CommNumResults);
|
||||
|
||||
nb::enum_<ReduceOp>(m, "ReduceOp")
|
||||
nb::enum_<ReduceOp>(m, "CppReduceOp")
|
||||
.value("SUM", ReduceOp::SUM)
|
||||
.value("MIN", ReduceOp::MIN)
|
||||
.value("NOP", ReduceOp::NOP);
|
||||
|
||||
auto algorithmClass =
|
||||
nb::class_<Algorithm>(m, "Algorithm")
|
||||
nb::class_<Algorithm>(m, "CppAlgorithm")
|
||||
.def_static(
|
||||
"from_native_capsule",
|
||||
[](nb::capsule cap) {
|
||||
@@ -58,6 +60,12 @@ void register_algorithm(nb::module_& m) {
|
||||
.def_prop_ro("name", &Algorithm::name)
|
||||
.def_prop_ro("collective", &Algorithm::collective)
|
||||
.def_prop_ro("message_range", &Algorithm::messageRange)
|
||||
.def(
|
||||
"set_message_size_range",
|
||||
[](Algorithm& self, size_t minMessageSize, size_t maxMessageSize) {
|
||||
self.setMessageSizeRange(minMessageSize, maxMessageSize);
|
||||
},
|
||||
nb::arg("min_message_size"), nb::arg("max_message_size"))
|
||||
.def_prop_ro("tags", &Algorithm::tags)
|
||||
.def_prop_ro("buffer_mode", &Algorithm::bufferMode)
|
||||
.def_prop_ro("constraint", &Algorithm::constraint)
|
||||
@@ -67,16 +75,19 @@ void register_algorithm(nb::module_& m) {
|
||||
"execute",
|
||||
[](Algorithm& self, std::shared_ptr<Communicator> comm, uintptr_t input, uintptr_t output,
|
||||
size_t inputSize, size_t outputSize, DataType dtype, ReduceOp op, uintptr_t stream,
|
||||
std::shared_ptr<Executor> executor, int nBlocks, int nThreadsPerBlock,
|
||||
std::unordered_map<std::string, uintptr_t> extras) {
|
||||
std::shared_ptr<Executor> executor, int nBlocks, int nThreadsPerBlock, bool symmetricMemory,
|
||||
std::unordered_map<std::string, uintptr_t> extras, int32_t accumDtype) {
|
||||
return self.execute(comm, reinterpret_cast<const void*>(input), reinterpret_cast<void*>(output),
|
||||
inputSize, outputSize, dtype, op, reinterpret_cast<cudaStream_t>(stream), executor,
|
||||
nBlocks, nThreadsPerBlock, extras);
|
||||
nBlocks, nThreadsPerBlock, symmetricMemory, extras,
|
||||
static_cast<DataType>(accumDtype));
|
||||
},
|
||||
nb::arg("comm"), nb::arg("input"), nb::arg("output"), nb::arg("input_size"), nb::arg("output_size"),
|
||||
nb::arg("dtype"), nb::arg("op") = ReduceOp::NOP, nb::arg("stream") = 0, nb::arg("executor") = nullptr,
|
||||
nb::arg("n_blocks") = 0, nb::arg("n_threads_per_block") = 0,
|
||||
nb::arg("extras") = std::unordered_map<std::string, uintptr_t>());
|
||||
nb::arg("n_blocks") = 0, nb::arg("n_threads_per_block") = 0, nb::arg("symmetric_memory") = false,
|
||||
nb::arg("extras") = std::unordered_map<std::string, uintptr_t>(),
|
||||
nb::arg("accum_dtype") = static_cast<int32_t>(DataType::AUTO))
|
||||
.def("reset", &Algorithm::reset);
|
||||
|
||||
nb::class_<Algorithm::Constraint>(algorithmClass, "Constraint")
|
||||
.def(nb::init<>())
|
||||
@@ -84,21 +95,21 @@ void register_algorithm(nb::module_& m) {
|
||||
.def_rw("world_size", &Algorithm::Constraint::worldSize)
|
||||
.def_rw("n_ranks_per_node", &Algorithm::Constraint::nRanksPerNode);
|
||||
|
||||
nb::class_<AlgorithmBuilder>(m, "AlgorithmBuilder").def("build", &AlgorithmBuilder::build);
|
||||
nb::class_<AlgorithmBuilder>(m, "CppAlgorithmBuilder").def("build", &AlgorithmBuilder::build);
|
||||
|
||||
nb::class_<DslAlgorithm, Algorithm>(m, "DslAlgorithm")
|
||||
nb::class_<DslAlgorithm, Algorithm>(m, "CppDslAlgorithm")
|
||||
.def(nb::init<std::string, ExecutionPlan, std::unordered_map<std::string, uint64_t>, Algorithm::Constraint>(),
|
||||
nb::arg("id"), nb::arg("plan"), nb::arg("tags") = std::unordered_map<std::string, uint64_t>(),
|
||||
nb::arg("constraint") = Algorithm::Constraint())
|
||||
.def("build", &DslAlgorithm::build);
|
||||
|
||||
nb::class_<AlgorithmCollection>(m, "AlgorithmCollection")
|
||||
nb::class_<AlgorithmCollection>(m, "CppAlgorithmCollection")
|
||||
.def("register_algorithm", &AlgorithmCollection::registerAlgorithm, nb::arg("collective"), nb::arg("algo_name"),
|
||||
nb::arg("algorithm"))
|
||||
.def("get_algorithms_by_collective", &AlgorithmCollection::getAlgorithmsByCollective, nb::arg("collective"))
|
||||
.def("to_list", &AlgorithmCollection::getAllAlgorithms);
|
||||
|
||||
nb::class_<CollectiveRequest>(m, "CollectiveRequest")
|
||||
nb::class_<CollectiveRequest>(m, "CppCollectiveRequest")
|
||||
.def_ro("world_size", &CollectiveRequest::worldSize)
|
||||
.def_ro("n_ranks_per_node", &CollectiveRequest::nRanksPerNode)
|
||||
.def_ro("rank", &CollectiveRequest::rank)
|
||||
@@ -107,8 +118,22 @@ void register_algorithm(nb::module_& m) {
|
||||
.def_prop_ro("output_buffer",
|
||||
[](const CollectiveRequest& self) { return reinterpret_cast<uintptr_t>(self.outputBuffer); })
|
||||
.def_ro("message_size", &CollectiveRequest::messageSize)
|
||||
.def_prop_ro("stream", [](const CollectiveRequest& self) { return reinterpret_cast<uintptr_t>(self.stream); })
|
||||
.def_prop_ro("collective", [](const CollectiveRequest& self) { return self.collective; })
|
||||
.def_ro("dtype", &CollectiveRequest::dtype)
|
||||
.def_prop_ro("hints", [](const CollectiveRequest& self) { return self.hints; })
|
||||
.def("buffer_mode", &CollectiveRequest::bufferMode);
|
||||
|
||||
m.def(
|
||||
"cpp_get_flag_buffer",
|
||||
[]() {
|
||||
auto [buffer, size] = getFlagBuffer();
|
||||
uintptr_t ptr = reinterpret_cast<uintptr_t>(buffer.get());
|
||||
// Transfer shared_ptr ownership into a capsule so Python's GC manages the lifetime.
|
||||
auto prevent = std::make_unique<std::shared_ptr<void>>(std::move(buffer));
|
||||
nb::capsule owner(prevent.get(), [](void* p) noexcept { delete static_cast<std::shared_ptr<void>*>(p); });
|
||||
prevent.release(); // capsule now owns the pointer
|
||||
return nb::make_tuple(ptr, size, owner);
|
||||
},
|
||||
"Get the default flag buffer. Returns a tuple of (buffer_ptr, buffer_size, owner).");
|
||||
}
|
||||
@@ -32,21 +32,25 @@ extern void register_algorithm_collection_builder(nb::module_& m);
|
||||
|
||||
template <typename T>
|
||||
void def_shared_future(nb::handle& m, const std::string& typestr) {
|
||||
std::string pyclass_name = std::string("shared_future_") + typestr;
|
||||
std::string pyclass_name = std::string("CppSharedFuture_") + typestr;
|
||||
nb::class_<std::shared_future<T>>(m, pyclass_name.c_str()).def("get", &std::shared_future<T>::get);
|
||||
}
|
||||
|
||||
void register_core(nb::module_& m) {
|
||||
m.def("version", &version);
|
||||
|
||||
nb::enum_<DataType>(m, "DataType")
|
||||
nb::enum_<DataType>(m, "CppDataType")
|
||||
.value("int32", DataType::INT32)
|
||||
.value("uint32", DataType::UINT32)
|
||||
.value("float16", DataType::FLOAT16)
|
||||
.value("float32", DataType::FLOAT32)
|
||||
.value("bfloat16", DataType::BFLOAT16);
|
||||
.value("bfloat16", DataType::BFLOAT16)
|
||||
.value("float8_e4m3", DataType::FLOAT8_E4M3)
|
||||
.value("float8_e5m2", DataType::FLOAT8_E5M2)
|
||||
.value("uint8", DataType::UINT8)
|
||||
.value("float8_e4m3b15", DataType::FLOAT8_E4M3B15);
|
||||
|
||||
nb::class_<Bootstrap>(m, "Bootstrap")
|
||||
nb::class_<Bootstrap>(m, "CppBootstrap")
|
||||
.def("get_rank", &Bootstrap::getRank)
|
||||
.def("get_n_ranks", &Bootstrap::getNranks)
|
||||
.def("get_n_ranks_per_node", &Bootstrap::getNranksPerNode)
|
||||
@@ -71,7 +75,7 @@ void register_core(nb::module_& m) {
|
||||
.def("recv", static_cast<void (Bootstrap::*)(std::vector<char>&, int, int)>(&Bootstrap::recv), nb::arg("data"),
|
||||
nb::arg("peer"), nb::arg("tag"));
|
||||
|
||||
nb::class_<UniqueId>(m, "UniqueId")
|
||||
nb::class_<UniqueId>(m, "CppUniqueId")
|
||||
.def(nb::init<>())
|
||||
.def("__setstate__",
|
||||
[](UniqueId& self, nb::bytes b) {
|
||||
@@ -81,7 +85,7 @@ void register_core(nb::module_& m) {
|
||||
.def("__getstate__",
|
||||
[](const UniqueId& self) { return nb::bytes(reinterpret_cast<const char*>(self.data()), UniqueIdBytes); });
|
||||
|
||||
nb::class_<TcpBootstrap, Bootstrap>(m, "TcpBootstrap")
|
||||
nb::class_<TcpBootstrap, Bootstrap>(m, "CppTcpBootstrap")
|
||||
.def(nb::init<int, int>(), "Do not use this constructor. Use create instead.")
|
||||
.def_static(
|
||||
"create", [](int rank, int nRanks) { return std::make_shared<TcpBootstrap>(rank, nRanks); }, nb::arg("rank"),
|
||||
@@ -93,7 +97,7 @@ void register_core(nb::module_& m) {
|
||||
.def("initialize", static_cast<void (TcpBootstrap::*)(const std::string&, int64_t)>(&TcpBootstrap::initialize),
|
||||
nb::call_guard<nb::gil_scoped_release>(), nb::arg("if_ip_port_trio"), nb::arg("timeout_sec") = 30);
|
||||
|
||||
nb::enum_<Transport>(m, "Transport")
|
||||
nb::enum_<Transport>(m, "CppTransport")
|
||||
.value("Unknown", Transport::Unknown)
|
||||
.value("CudaIpc", Transport::CudaIpc)
|
||||
.value("IB0", Transport::IB0)
|
||||
@@ -106,7 +110,7 @@ void register_core(nb::module_& m) {
|
||||
.value("IB7", Transport::IB7)
|
||||
.value("NumTransports", Transport::NumTransports);
|
||||
|
||||
nb::class_<TransportFlags>(m, "TransportFlags")
|
||||
nb::class_<TransportFlags>(m, "CppTransportFlags")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
|
||||
.def("has", &TransportFlags::has, nb::arg("transport"))
|
||||
@@ -130,12 +134,12 @@ void register_core(nb::module_& m) {
|
||||
.def(nb::self == nb::self)
|
||||
.def(nb::self != nb::self);
|
||||
|
||||
nb::enum_<DeviceType>(m, "DeviceType")
|
||||
nb::enum_<DeviceType>(m, "CppDeviceType")
|
||||
.value("Unknown", DeviceType::Unknown)
|
||||
.value("CPU", DeviceType::CPU)
|
||||
.value("GPU", DeviceType::GPU);
|
||||
|
||||
nb::class_<Device>(m, "Device")
|
||||
nb::class_<Device>(m, "CppDevice")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init_implicit<DeviceType>(), nb::arg("type"))
|
||||
.def(nb::init<DeviceType, int>(), nb::arg("type"), nb::arg("id") = -1)
|
||||
@@ -147,24 +151,33 @@ void register_core(nb::module_& m) {
|
||||
return ss.str();
|
||||
});
|
||||
|
||||
nb::class_<EndpointConfig::Ib>(m, "EndpointConfigIb")
|
||||
nb::enum_<EndpointConfig::Ib::Mode>(m, "CppIbMode")
|
||||
.value("Default", EndpointConfig::Ib::Mode::Default)
|
||||
.value("Host", EndpointConfig::Ib::Mode::Host)
|
||||
.value("HostNoAtomic", EndpointConfig::Ib::Mode::HostNoAtomic);
|
||||
|
||||
nb::class_<EndpointConfig::Ib>(m, "CppEndpointConfigIb")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<int, int, int, int, int, int, int>(), nb::arg("device_index") = -1,
|
||||
.def(nb::init<int, int, int, int, int, int, int, int, EndpointConfig::Ib::Mode>(), nb::arg("device_index") = -1,
|
||||
nb::arg("port") = EndpointConfig::Ib::DefaultPort,
|
||||
nb::arg("gid_index") = EndpointConfig::Ib::DefaultGidIndex,
|
||||
nb::arg("max_cq_size") = EndpointConfig::Ib::DefaultMaxCqSize,
|
||||
nb::arg("max_cq_poll_num") = EndpointConfig::Ib::DefaultMaxCqPollNum,
|
||||
nb::arg("max_send_wr") = EndpointConfig::Ib::DefaultMaxSendWr,
|
||||
nb::arg("max_wr_per_send") = EndpointConfig::Ib::DefaultMaxWrPerSend)
|
||||
nb::arg("max_recv_wr") = EndpointConfig::Ib::DefaultMaxRecvWr,
|
||||
nb::arg("max_wr_per_send") = EndpointConfig::Ib::DefaultMaxWrPerSend,
|
||||
nb::arg("mode") = EndpointConfig::Ib::Mode::Default)
|
||||
.def_rw("device_index", &EndpointConfig::Ib::deviceIndex)
|
||||
.def_rw("port", &EndpointConfig::Ib::port)
|
||||
.def_rw("gid_index", &EndpointConfig::Ib::gidIndex)
|
||||
.def_rw("max_cq_size", &EndpointConfig::Ib::maxCqSize)
|
||||
.def_rw("max_cq_poll_num", &EndpointConfig::Ib::maxCqPollNum)
|
||||
.def_rw("max_send_wr", &EndpointConfig::Ib::maxSendWr)
|
||||
.def_rw("max_wr_per_send", &EndpointConfig::Ib::maxWrPerSend);
|
||||
.def_rw("max_recv_wr", &EndpointConfig::Ib::maxRecvWr)
|
||||
.def_rw("max_wr_per_send", &EndpointConfig::Ib::maxWrPerSend)
|
||||
.def_rw("mode", &EndpointConfig::Ib::mode);
|
||||
|
||||
nb::class_<RegisteredMemory>(m, "RegisteredMemory")
|
||||
nb::class_<RegisteredMemory>(m, "CppRegisteredMemory")
|
||||
.def(nb::init<>())
|
||||
.def("data", [](RegisteredMemory& self) { return reinterpret_cast<uintptr_t>(self.data()); })
|
||||
.def("size", &RegisteredMemory::size)
|
||||
@@ -172,7 +185,7 @@ void register_core(nb::module_& m) {
|
||||
.def("serialize", &RegisteredMemory::serialize)
|
||||
.def_static("deserialize", &RegisteredMemory::deserialize, nb::arg("data"));
|
||||
|
||||
nb::class_<Endpoint>(m, "Endpoint")
|
||||
nb::class_<Endpoint>(m, "CppEndpoint")
|
||||
.def("config", &Endpoint::config)
|
||||
.def("transport", &Endpoint::transport)
|
||||
.def("device", &Endpoint::device)
|
||||
@@ -180,7 +193,7 @@ void register_core(nb::module_& m) {
|
||||
.def("serialize", &Endpoint::serialize)
|
||||
.def_static("deserialize", &Endpoint::deserialize, nb::arg("data"));
|
||||
|
||||
nb::class_<Connection>(m, "Connection")
|
||||
nb::class_<Connection>(m, "CppConnection")
|
||||
.def("write", &Connection::write, nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("srcOffset"),
|
||||
nb::arg("size"))
|
||||
.def(
|
||||
@@ -197,7 +210,7 @@ void register_core(nb::module_& m) {
|
||||
.def("local_device", &Connection::localDevice)
|
||||
.def("get_max_write_queue_size", &Connection::getMaxWriteQueueSize);
|
||||
|
||||
nb::class_<EndpointConfig>(m, "EndpointConfig")
|
||||
nb::class_<EndpointConfig>(m, "CppEndpointConfig")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
|
||||
.def(nb::init<Transport, Device, int, EndpointConfig::Ib>(), nb::arg("transport"), nb::arg("device"),
|
||||
@@ -223,12 +236,18 @@ void register_core(nb::module_& m) {
|
||||
.def_prop_rw(
|
||||
"ib_max_send_wr", [](EndpointConfig& self) { return self.ib.maxSendWr; },
|
||||
[](EndpointConfig& self, int v) { self.ib.maxSendWr = v; })
|
||||
.def_prop_rw(
|
||||
"ib_max_recv_wr", [](EndpointConfig& self) { return self.ib.maxRecvWr; },
|
||||
[](EndpointConfig& self, int v) { self.ib.maxRecvWr = v; })
|
||||
.def_prop_rw(
|
||||
"ib_max_wr_per_send", [](EndpointConfig& self) { return self.ib.maxWrPerSend; },
|
||||
[](EndpointConfig& self, int v) { self.ib.maxWrPerSend = v; })
|
||||
.def_prop_rw(
|
||||
"ib_mode", [](EndpointConfig& self) { return self.ib.mode; },
|
||||
[](EndpointConfig& self, EndpointConfig::Ib::Mode v) { self.ib.mode = v; })
|
||||
.def_rw("max_write_queue_size", &EndpointConfig::maxWriteQueueSize);
|
||||
|
||||
nb::class_<Context>(m, "Context")
|
||||
nb::class_<Context>(m, "CppContext")
|
||||
.def_static("create", &Context::create)
|
||||
.def(
|
||||
"register_memory",
|
||||
@@ -239,13 +258,13 @@ void register_core(nb::module_& m) {
|
||||
.def("create_endpoint", &Context::createEndpoint, nb::arg("config"))
|
||||
.def("connect", &Context::connect, nb::arg("local_endpoint"), nb::arg("remote_endpoint"));
|
||||
|
||||
nb::class_<SemaphoreStub>(m, "SemaphoreStub")
|
||||
nb::class_<SemaphoreStub>(m, "CppSemaphoreStub")
|
||||
.def(nb::init<const Connection&>(), nb::arg("connection"))
|
||||
.def("memory", &SemaphoreStub::memory)
|
||||
.def("serialize", &SemaphoreStub::serialize)
|
||||
.def_static("deserialize", &SemaphoreStub::deserialize, nb::arg("data"));
|
||||
|
||||
nb::class_<Semaphore>(m, "Semaphore")
|
||||
nb::class_<Semaphore>(m, "CppSemaphore")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<const SemaphoreStub&, const SemaphoreStub&>(), nb::arg("local_stub"), nb::arg("remote_stub"))
|
||||
.def("connection", &Semaphore::connection)
|
||||
@@ -256,7 +275,7 @@ void register_core(nb::module_& m) {
|
||||
def_shared_future<Connection>(m, "Connection");
|
||||
def_shared_future<Semaphore>(m, "Semaphore");
|
||||
|
||||
nb::class_<Communicator>(m, "Communicator")
|
||||
nb::class_<Communicator>(m, "CppCommunicator")
|
||||
.def(nb::init<std::shared_ptr<Bootstrap>, std::shared_ptr<Context>>(), nb::arg("bootstrap"),
|
||||
nb::arg("context") = nullptr)
|
||||
.def("bootstrap", &Communicator::bootstrap)
|
||||
@@ -289,6 +308,9 @@ void register_core(nb::module_& m) {
|
||||
}
|
||||
|
||||
NB_MODULE(_mscclpp, m) {
|
||||
#ifdef MSCCLPP_DISABLE_NB_LEAK_WARNINGS
|
||||
nb::set_leak_warnings(false);
|
||||
#endif
|
||||
register_env(m);
|
||||
register_error(m);
|
||||
register_port_channel(m);
|
||||
@@ -306,4 +328,4 @@ NB_MODULE(_mscclpp, m) {
|
||||
|
||||
// ext
|
||||
register_algorithm_collection_builder(m);
|
||||
}
|
||||
}
|
||||
@@ -11,7 +11,7 @@ namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_env(nb::module_& m) {
|
||||
nb::class_<Env>(m, "Env")
|
||||
nb::class_<Env>(m, "CppEnv")
|
||||
.def_ro("debug", &Env::debug)
|
||||
.def_ro("debug_subsys", &Env::debugSubsys)
|
||||
.def_ro("debug_file", &Env::debugFile)
|
||||
@@ -20,9 +20,11 @@ void register_env(nb::module_& m) {
|
||||
.def_ro("socket_family", &Env::socketFamily)
|
||||
.def_ro("socket_ifname", &Env::socketIfname)
|
||||
.def_ro("comm_id", &Env::commId)
|
||||
.def_ro("execution_plan_dir", &Env::executionPlanDir)
|
||||
.def_ro("ibv_mode", &Env::ibvMode)
|
||||
.def_ro("cache_dir", &Env::cacheDir)
|
||||
.def_ro("npkit_dump_dir", &Env::npkitDumpDir)
|
||||
.def_ro("cuda_ipc_use_default_stream", &Env::cudaIpcUseDefaultStream);
|
||||
.def_ro("cuda_ipc_use_default_stream", &Env::cudaIpcUseDefaultStream)
|
||||
.def_ro("ib_gid_index", &Env::ibGidIndex);
|
||||
|
||||
m.def("env", &env);
|
||||
}
|
||||
|
||||
@@ -11,18 +11,18 @@ using namespace mscclpp;
|
||||
|
||||
#define REGISTER_EXCEPTION_TRANSLATOR(name_) \
|
||||
nb::register_exception_translator( \
|
||||
[](const std::exception_ptr &p, void *payload) { \
|
||||
[](const std::exception_ptr& p, void* payload) { \
|
||||
try { \
|
||||
std::rethrow_exception(p); \
|
||||
} catch (const name_ &e) { \
|
||||
PyErr_SetObject(reinterpret_cast<PyObject *>(payload), \
|
||||
} catch (const name_& e) { \
|
||||
PyErr_SetObject(reinterpret_cast<PyObject*>(payload), \
|
||||
PyTuple_Pack(2, PyLong_FromLong(long(e.getErrorCode())), PyUnicode_FromString(e.what()))); \
|
||||
} \
|
||||
}, \
|
||||
m.attr(#name_).ptr());
|
||||
|
||||
void register_error(nb::module_ &m) {
|
||||
nb::enum_<ErrorCode>(m, "ErrorCode")
|
||||
void register_error(nb::module_& m) {
|
||||
nb::enum_<ErrorCode>(m, "CppErrorCode")
|
||||
.value("SystemError", ErrorCode::SystemError)
|
||||
.value("InternalError", ErrorCode::InternalError)
|
||||
.value("RemoteError", ErrorCode::RemoteError)
|
||||
|
||||
@@ -15,16 +15,16 @@ namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_executor(nb::module_& m) {
|
||||
nb::enum_<PacketType>(m, "PacketType").value("LL8", PacketType::LL8).value("LL16", PacketType::LL16);
|
||||
nb::enum_<PacketType>(m, "CppPacketType").value("LL8", PacketType::LL8).value("LL16", PacketType::LL16);
|
||||
|
||||
nb::class_<ExecutionPlan>(m, "ExecutionPlan")
|
||||
nb::class_<ExecutionPlan>(m, "CppExecutionPlan")
|
||||
.def(nb::init<const std::string&, int>(), nb::arg("planPath"), nb::arg("rank"))
|
||||
.def_prop_ro("name", [](const ExecutionPlan& self) -> std::string { return self.name(); })
|
||||
.def_prop_ro("collective", [](const ExecutionPlan& self) -> std::string { return self.collective(); })
|
||||
.def_prop_ro("min_message_size", [](const ExecutionPlan& self) -> size_t { return self.minMessageSize(); })
|
||||
.def_prop_ro("max_message_size", [](const ExecutionPlan& self) -> size_t { return self.maxMessageSize(); });
|
||||
|
||||
nb::class_<Executor>(m, "Executor")
|
||||
nb::class_<Executor>(m, "CppExecutor")
|
||||
.def(nb::init<std::shared_ptr<Communicator>>(), nb::arg("comm"))
|
||||
.def(
|
||||
"execute",
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/function.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/unordered_map.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
@@ -15,7 +16,7 @@ using namespace mscclpp;
|
||||
using namespace mscclpp::collective;
|
||||
|
||||
void register_algorithm_collection_builder(nb::module_& m) {
|
||||
nb::class_<AlgorithmCollectionBuilder>(m, "AlgorithmCollectionBuilder")
|
||||
nb::class_<AlgorithmCollectionBuilder>(m, "CppAlgorithmCollectionBuilder")
|
||||
.def_static("get_instance", &AlgorithmCollectionBuilder::getInstance)
|
||||
.def("add_algorithm_builder", &AlgorithmCollectionBuilder::addAlgorithmBuilder, nb::arg("builder"))
|
||||
.def(
|
||||
@@ -29,6 +30,6 @@ void register_algorithm_collection_builder(nb::module_& m) {
|
||||
nb::arg("selector"))
|
||||
.def("build", &AlgorithmCollectionBuilder::build)
|
||||
.def("build_default_algorithms", &AlgorithmCollectionBuilder::buildDefaultAlgorithms, nb::arg("scratch_buffer"),
|
||||
nb::arg("scratch_buffer_size"), nb::arg("rank"))
|
||||
nb::arg("scratch_buffer_size"), nb::arg("flag_buffer"), nb::arg("flag_buffer_size"), nb::arg("rank"))
|
||||
.def_static("reset", &AlgorithmCollectionBuilder::reset);
|
||||
}
|
||||
@@ -9,7 +9,7 @@ namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_fifo(nb::module_& m) {
|
||||
nb::class_<ProxyTrigger>(m, "ProxyTrigger")
|
||||
nb::class_<ProxyTrigger>(m, "CppProxyTrigger")
|
||||
.def_prop_rw(
|
||||
"fst", [](const ProxyTrigger& self) { return self.fst; },
|
||||
[](ProxyTrigger& self, uint64_t v) { self.fst = v; })
|
||||
@@ -17,7 +17,7 @@ void register_fifo(nb::module_& m) {
|
||||
"snd", [](const ProxyTrigger& self) { return self.snd; },
|
||||
[](ProxyTrigger& self, uint64_t v) { self.snd = v; });
|
||||
|
||||
nb::class_<FifoDeviceHandle>(m, "FifoDeviceHandle")
|
||||
nb::class_<FifoDeviceHandle>(m, "CppFifoDeviceHandle")
|
||||
.def_rw("triggers", &FifoDeviceHandle::triggers)
|
||||
.def_rw("tail", &FifoDeviceHandle::tail)
|
||||
.def_rw("head", &FifoDeviceHandle::head)
|
||||
@@ -26,7 +26,7 @@ void register_fifo(nb::module_& m) {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<Fifo>(m, "Fifo")
|
||||
nb::class_<Fifo>(m, "CppFifo")
|
||||
.def(nb::init<int>(), nb::arg("size") = DEFAULT_FIFO_SIZE)
|
||||
.def("poll", &Fifo::poll)
|
||||
.def("pop", &Fifo::pop)
|
||||
|
||||
@@ -34,6 +34,19 @@ static DLDataType getDlType(std::string type) {
|
||||
return DLDataType{kDLBfloat, 16, 1};
|
||||
} else if (type == "torch.float16") {
|
||||
return DLDataType{kDLFloat, 16, 1};
|
||||
} else if (type == "torch.float8_e4m3fn") {
|
||||
return DLDataType{kDLFloat8_e4m3fn, 8, 1};
|
||||
} else if (type == "torch.float8_e4m3fnuz") {
|
||||
return DLDataType{kDLFloat8_e4m3fnuz, 8, 1};
|
||||
} else if (type == "torch.float8_e5m2") {
|
||||
return DLDataType{kDLFloat8_e5m2, 8, 1};
|
||||
} else if (type == "torch.float8_e5m2fnuz") {
|
||||
return DLDataType{kDLFloat8_e5m2fnuz, 8, 1};
|
||||
} else if (type == "torch.uint8") {
|
||||
return DLDataType{kDLUInt, 8, 1};
|
||||
} else if (type == "fp8_e4m3b15") {
|
||||
// No standard DLPack code for fp8_e4m3b15; store as raw uint8 bytes.
|
||||
return DLDataType{kDLUInt, 8, 1};
|
||||
} else {
|
||||
throw Error("Unsupported type: " + type, ErrorCode::InvalidUsage);
|
||||
}
|
||||
@@ -101,7 +114,7 @@ static nb::capsule toDlpack(GpuBuffer<char> buffer, std::string dataType, std::v
|
||||
void register_gpu_utils(nb::module_& m) {
|
||||
m.def("is_nvls_supported", &isNvlsSupported);
|
||||
|
||||
nb::class_<GpuBuffer<char>>(m, "RawGpuBuffer")
|
||||
nb::class_<GpuBuffer<char>>(m, "CppRawGpuBuffer")
|
||||
.def(nb::init<size_t>(), nb::arg("nelems"))
|
||||
.def("nelems", &GpuBuffer<char>::nelems)
|
||||
.def("bytes", &GpuBuffer<char>::bytes)
|
||||
|
||||
@@ -11,20 +11,20 @@ namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_memory_channel(nb::module_& m) {
|
||||
nb::class_<BaseMemoryChannel>(m, "BaseMemoryChannel")
|
||||
nb::class_<BaseMemoryChannel>(m, "CppBaseMemoryChannel")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<std::shared_ptr<MemoryDevice2DeviceSemaphore>>(), nb::arg("semaphore"))
|
||||
.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def("device_handle", &BaseMemoryChannel::deviceHandle);
|
||||
|
||||
nb::class_<BaseMemoryChannel::DeviceHandle>(m, "BaseMemoryChannelDeviceHandle")
|
||||
nb::class_<BaseMemoryChannel::DeviceHandle>(m, "CppBaseMemoryChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphore_", &BaseMemoryChannel::DeviceHandle::semaphore_)
|
||||
.def_prop_ro("raw", [](const BaseMemoryChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<MemoryChannel>(m, "MemoryChannel")
|
||||
nb::class_<MemoryChannel>(m, "CppMemoryChannel")
|
||||
.def(nb::init<>())
|
||||
.def(
|
||||
"__init__",
|
||||
@@ -42,7 +42,7 @@ void register_memory_channel(nb::module_& m) {
|
||||
nb::arg("semaphore"), nb::arg("dst"), nb::arg("src"), nb::arg("packet_buffer") = 0)
|
||||
.def("device_handle", &MemoryChannel::deviceHandle);
|
||||
|
||||
nb::class_<MemoryChannel::DeviceHandle>(m, "MemoryChannelDeviceHandle")
|
||||
nb::class_<MemoryChannel::DeviceHandle>(m, "CppMemoryChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphore_", &MemoryChannel::DeviceHandle::semaphore_)
|
||||
.def_rw("dst_", &MemoryChannel::DeviceHandle::dst_)
|
||||
|
||||
@@ -8,8 +8,8 @@
|
||||
|
||||
namespace nb = nanobind;
|
||||
|
||||
void register_npkit(nb::module_ &m) {
|
||||
nb::module_ sub_m = m.def_submodule("npkit", "NPKit functions");
|
||||
void register_npkit(nb::module_& m) {
|
||||
nb::module_ sub_m = m.def_submodule("cpp_npkit", "NPKit functions");
|
||||
sub_m.def("init", &NpKit::Init);
|
||||
sub_m.def("dump", &NpKit::Dump);
|
||||
sub_m.def("shutdown", &NpKit::Shutdown);
|
||||
|
||||
@@ -6,8 +6,8 @@ int getDeviceNumaNode(int cudaDev);
|
||||
void numaBind(int node);
|
||||
}; // namespace mscclpp
|
||||
|
||||
void register_numa(nb::module_ &m) {
|
||||
nb::module_ sub_m = m.def_submodule("numa", "numa functions");
|
||||
void register_numa(nb::module_& m) {
|
||||
nb::module_ sub_m = m.def_submodule("cpp_numa", "numa functions");
|
||||
sub_m.def("get_device_numa_node", &mscclpp::getDeviceNumaNode);
|
||||
sub_m.def("numa_bind", &mscclpp::numaBind);
|
||||
}
|
||||
|
||||
@@ -11,11 +11,11 @@ namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_port_channel(nb::module_& m) {
|
||||
nb::class_<BaseProxyService>(m, "BaseProxyService")
|
||||
nb::class_<BaseProxyService>(m, "CppBaseProxyService")
|
||||
.def("start_proxy", &BaseProxyService::startProxy, nb::arg("blocking") = false)
|
||||
.def("stop_proxy", &BaseProxyService::stopProxy);
|
||||
|
||||
nb::class_<ProxyService, BaseProxyService>(m, "ProxyService")
|
||||
nb::class_<ProxyService, BaseProxyService>(m, "CppProxyService")
|
||||
.def(nb::init<int>(), nb::arg("fifo_size") = DEFAULT_FIFO_SIZE)
|
||||
.def("start_proxy", &ProxyService::startProxy, nb::arg("blocking") = false)
|
||||
.def("stop_proxy", &ProxyService::stopProxy)
|
||||
@@ -31,13 +31,13 @@ void register_port_channel(nb::module_& m) {
|
||||
.def("base_port_channel", &ProxyService::basePortChannel, nb::arg("id"))
|
||||
.def("port_channel", &ProxyService::portChannel, nb::arg("id"), nb::arg("dst"), nb::arg("src"));
|
||||
|
||||
nb::class_<BasePortChannel>(m, "BasePortChannel")
|
||||
nb::class_<BasePortChannel>(m, "CppBasePortChannel")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>>(),
|
||||
nb::arg("semaphore_id"), nb::arg("semaphore"), nb::arg("proxy"))
|
||||
.def("device_handle", &BasePortChannel::deviceHandle);
|
||||
|
||||
nb::class_<BasePortChannel::DeviceHandle>(m, "BasePortChannelDeviceHandle")
|
||||
nb::class_<BasePortChannel::DeviceHandle>(m, "CppBasePortChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphore_id_", &BasePortChannel::DeviceHandle::semaphoreId_)
|
||||
.def_rw("semaphore_", &BasePortChannel::DeviceHandle::semaphore_)
|
||||
@@ -46,13 +46,13 @@ void register_port_channel(nb::module_& m) {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<PortChannel>(m, "PortChannel")
|
||||
nb::class_<PortChannel>(m, "CppPortChannel")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>, MemoryId, MemoryId>(),
|
||||
nb::arg("semaphore_id"), nb::arg("semaphore"), nb::arg("proxy"), nb::arg("dst"), nb::arg("src"))
|
||||
.def("device_handle", &PortChannel::deviceHandle);
|
||||
|
||||
nb::class_<PortChannel::DeviceHandle>(m, "PortChannelDeviceHandle")
|
||||
nb::class_<PortChannel::DeviceHandle>(m, "CppPortChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphore_id_", &PortChannel::DeviceHandle::semaphoreId_)
|
||||
.def_rw("semaphore_", &PortChannel::DeviceHandle::semaphore_)
|
||||
|
||||
@@ -10,7 +10,7 @@ namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_semaphore(nb::module_& m) {
|
||||
nb::class_<Host2DeviceSemaphore> host2DeviceSemaphore(m, "Host2DeviceSemaphore");
|
||||
nb::class_<Host2DeviceSemaphore> host2DeviceSemaphore(m, "CppHost2DeviceSemaphore");
|
||||
host2DeviceSemaphore.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def(nb::init<Communicator&, const Connection&>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("connection", &Host2DeviceSemaphore::connection)
|
||||
@@ -25,7 +25,7 @@ void register_semaphore(nb::module_& m) {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<Host2HostSemaphore>(m, "Host2HostSemaphore")
|
||||
nb::class_<Host2HostSemaphore>(m, "CppHost2HostSemaphore")
|
||||
.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def(nb::init<Communicator&, const Connection&>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("connection", &Host2HostSemaphore::connection)
|
||||
@@ -34,7 +34,7 @@ void register_semaphore(nb::module_& m) {
|
||||
.def("wait", &Host2HostSemaphore::wait, nb::call_guard<nb::gil_scoped_release>(),
|
||||
nb::arg("max_spin_count") = 10000000);
|
||||
|
||||
nb::class_<MemoryDevice2DeviceSemaphore> memoryDevice2DeviceSemaphore(m, "MemoryDevice2DeviceSemaphore");
|
||||
nb::class_<MemoryDevice2DeviceSemaphore> memoryDevice2DeviceSemaphore(m, "CppMemoryDevice2DeviceSemaphore");
|
||||
memoryDevice2DeviceSemaphore.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def(nb::init<Communicator&, const Connection&>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("connection", &MemoryDevice2DeviceSemaphore::connection)
|
||||
@@ -43,7 +43,6 @@ void register_semaphore(nb::module_& m) {
|
||||
nb::class_<MemoryDevice2DeviceSemaphore::DeviceHandle>(memoryDevice2DeviceSemaphore, "DeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::inboundToken)
|
||||
.def_rw("outbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::outboundToken)
|
||||
.def_rw("remote_inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::remoteInboundToken)
|
||||
.def_rw("expected_inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::expectedInboundToken)
|
||||
.def_prop_ro("raw", [](const MemoryDevice2DeviceSemaphore::DeviceHandle& self) -> nb::bytes {
|
||||
|
||||
@@ -15,11 +15,11 @@ namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_nvls(nb::module_& m) {
|
||||
nb::class_<SwitchChannel>(m, "SwitchChannel")
|
||||
nb::class_<SwitchChannel>(m, "CppSwitchChannel")
|
||||
.def("get_device_ptr", [](SwitchChannel* self) { return (uintptr_t)self->getDevicePtr(); })
|
||||
.def("device_handle", &SwitchChannel::deviceHandle);
|
||||
|
||||
nb::class_<SwitchChannel::DeviceHandle>(m, "DeviceHandle")
|
||||
nb::class_<SwitchChannel::DeviceHandle>(m, "CppSwitchChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("device_ptr", &SwitchChannel::DeviceHandle::devicePtr)
|
||||
.def_rw("mc_ptr", &SwitchChannel::DeviceHandle::mcPtr)
|
||||
@@ -28,7 +28,7 @@ void register_nvls(nb::module_& m) {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<NvlsConnection>(m, "NvlsConnection")
|
||||
nb::class_<NvlsConnection>(m, "CppNvlsConnection")
|
||||
.def("bind_allocated_memory", &NvlsConnection::bindAllocatedMemory, nb::arg("device_ptr"), nb::arg("size"));
|
||||
|
||||
m.def("connect_nvls_collective", &connectNvlsCollective, nb::arg("communicator"), nb::arg("all_ranks"),
|
||||
|
||||
@@ -23,35 +23,37 @@ version = {
|
||||
from ._core import *
|
||||
|
||||
from ._mscclpp import (
|
||||
Device,
|
||||
DeviceType,
|
||||
Communicator,
|
||||
Connection,
|
||||
CppDevice as Device,
|
||||
CppDeviceType as DeviceType,
|
||||
CppCommunicator as Communicator,
|
||||
CppConnection as Connection,
|
||||
connect_nvls_collective,
|
||||
EndpointConfig,
|
||||
Fifo,
|
||||
Semaphore,
|
||||
Host2DeviceSemaphore,
|
||||
Host2HostSemaphore,
|
||||
numa,
|
||||
ProxyService,
|
||||
RegisteredMemory,
|
||||
PortChannel,
|
||||
MemoryChannel,
|
||||
MemoryDevice2DeviceSemaphore,
|
||||
TcpBootstrap,
|
||||
Transport,
|
||||
TransportFlags,
|
||||
DataType,
|
||||
ErrorCode,
|
||||
Executor,
|
||||
ExecutionPlan,
|
||||
PacketType,
|
||||
RawGpuBuffer,
|
||||
ReduceOp,
|
||||
CppEndpointConfig as EndpointConfig,
|
||||
CppEndpointConfigIb as EndpointConfigIb,
|
||||
CppIbMode as IbMode,
|
||||
CppFifo as Fifo,
|
||||
CppSemaphore as Semaphore,
|
||||
CppHost2DeviceSemaphore as Host2DeviceSemaphore,
|
||||
CppHost2HostSemaphore as Host2HostSemaphore,
|
||||
cpp_numa as numa,
|
||||
CppProxyService as ProxyService,
|
||||
CppRegisteredMemory as RegisteredMemory,
|
||||
CppPortChannel as PortChannel,
|
||||
CppMemoryChannel as MemoryChannel,
|
||||
CppMemoryDevice2DeviceSemaphore as MemoryDevice2DeviceSemaphore,
|
||||
CppTcpBootstrap as TcpBootstrap,
|
||||
CppTransport as Transport,
|
||||
CppTransportFlags as TransportFlags,
|
||||
CppDataType as DataType,
|
||||
CppErrorCode as ErrorCode,
|
||||
CppExecutor as Executor,
|
||||
CppExecutionPlan as ExecutionPlan,
|
||||
CppPacketType as PacketType,
|
||||
CppRawGpuBuffer as RawGpuBuffer,
|
||||
CppReduceOp as ReduceOp,
|
||||
env,
|
||||
is_nvls_supported,
|
||||
npkit,
|
||||
cpp_npkit as npkit,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -61,6 +63,8 @@ __all__ = [
|
||||
"Connection",
|
||||
"connect_nvls_collective",
|
||||
"EndpointConfig",
|
||||
"EndpointConfigIb",
|
||||
"IbMode",
|
||||
"ErrorCode",
|
||||
"Fifo",
|
||||
"Semaphore",
|
||||
|
||||
@@ -6,7 +6,7 @@ import shutil
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from mscclpp.language import default_algos as def_algo
|
||||
from mscclpp import default_algos as def_algo
|
||||
from mscclpp.language.collectives import *
|
||||
from mscclpp.language.utils import AlgoSpec
|
||||
|
||||
@@ -57,7 +57,7 @@ default_algo_configs = [
|
||||
|
||||
|
||||
def create_default_plans():
|
||||
plan_dir = os.environ.get("MSCCLPP_EXECUTION_PLAN_DIR", Path.home() / ".cache/mscclpp_default")
|
||||
plan_dir = os.path.join(os.environ.get("MSCCLPP_CACHE_DIR", Path.home() / ".cache/mscclpp"), "default")
|
||||
plan_path = Path(plan_dir)
|
||||
if plan_path.exists():
|
||||
shutil.rmtree(plan_path)
|
||||
|
||||
@@ -5,9 +5,3 @@ from .algorithm import *
|
||||
from .comm import *
|
||||
from .compiler import *
|
||||
from .buffer import *
|
||||
|
||||
__all__ = []
|
||||
__all__ += algorithm.__all__
|
||||
__all__ += comm.__all__
|
||||
__all__ += compiler.__all__
|
||||
__all__ += buffer.__all__
|
||||
|
||||
@@ -4,18 +4,22 @@
|
||||
from __future__ import annotations
|
||||
from typing import Optional, Tuple, Dict
|
||||
from functools import cached_property
|
||||
import cupy as cp
|
||||
|
||||
|
||||
from mscclpp._mscclpp import (
|
||||
Algorithm as _Algorithm,
|
||||
DslAlgorithm as _DslAlgorithm,
|
||||
AlgorithmType as _AlgorithmType,
|
||||
Communicator,
|
||||
CollectiveBufferMode,
|
||||
DataType,
|
||||
Executor,
|
||||
ExecutionPlan,
|
||||
ReduceOp,
|
||||
CppAlgorithm,
|
||||
CppDslAlgorithm,
|
||||
CppAlgorithmType,
|
||||
CppCommunicator,
|
||||
CppCollectiveBufferMode,
|
||||
CppDataType,
|
||||
CppExecutor,
|
||||
CppExecutionPlan,
|
||||
CppReduceOp,
|
||||
CppAlgorithmBuilder,
|
||||
CppAlgorithmCollection,
|
||||
cpp_get_flag_buffer,
|
||||
)
|
||||
|
||||
__all__ = ["Algorithm", "AlgorithmBuilder", "AlgorithmCollection"]
|
||||
@@ -45,7 +49,7 @@ class Algorithm:
|
||||
"""
|
||||
|
||||
def __init__(self, world_size: int = 0, n_ranks_per_node: int = 0):
|
||||
self._constraint = _Algorithm.Constraint(world_size, n_ranks_per_node)
|
||||
self._constraint = CppAlgorithm.Constraint(world_size, n_ranks_per_node)
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
@@ -58,23 +62,23 @@ class Algorithm:
|
||||
def __init__(
|
||||
self,
|
||||
id: Optional[str] = None,
|
||||
execution_plan: Optional[ExecutionPlan] = None,
|
||||
native_handle: Optional[_Algorithm] = None,
|
||||
execution_plan: Optional[CppExecutionPlan] = None,
|
||||
native_handle: Optional[CppAlgorithm] = None,
|
||||
tags: Optional[Dict[str, int]] = None,
|
||||
constraint: Optional[Constraint] = None,
|
||||
):
|
||||
if execution_plan is not None:
|
||||
self._algorithm = _DslAlgorithm(
|
||||
self._algorithm = CppDslAlgorithm(
|
||||
id,
|
||||
execution_plan,
|
||||
tags=tags if tags is not None else {},
|
||||
constraint=constraint._constraint if constraint is not None else _Algorithm.Constraint(),
|
||||
constraint=constraint._constraint if constraint is not None else CppAlgorithm.Constraint(),
|
||||
)
|
||||
elif native_handle is not None:
|
||||
self._algorithm = native_handle
|
||||
|
||||
@classmethod
|
||||
def create_from_native_handle(cls, handle: _Algorithm):
|
||||
def create_from_native_handle(cls, handle: CppAlgorithm):
|
||||
"""Create an Algorithm instance from a native C++ algorithm handle.
|
||||
|
||||
Args:
|
||||
@@ -97,7 +101,7 @@ class Algorithm:
|
||||
Returns:
|
||||
A new Algorithm instance wrapping the algorithm from the capsule.
|
||||
"""
|
||||
handle = _Algorithm.from_native_capsule(obj)
|
||||
handle = CppAlgorithm.from_native_capsule(obj)
|
||||
return cls(native_handle=handle)
|
||||
|
||||
@cached_property
|
||||
@@ -110,18 +114,31 @@ class Algorithm:
|
||||
"""The collective operation this algorithm implements (e.g., "allreduce", "allgather")."""
|
||||
return self._algorithm.collective
|
||||
|
||||
@cached_property
|
||||
@property
|
||||
def message_size_range(self) -> Tuple[int, int]:
|
||||
"""The valid message size range (min_size, max_size) in bytes."""
|
||||
return (self._algorithm.message_range[0], self._algorithm.message_range[1])
|
||||
|
||||
def set_message_size_range(self, min_message_size: int, max_message_size: int):
|
||||
"""Set the valid message size range in bytes.
|
||||
|
||||
Args:
|
||||
min_message_size: Minimum supported message size in bytes.
|
||||
max_message_size: Maximum supported message size in bytes.
|
||||
|
||||
Only supported for native algorithms. Raises TypeError for DSL algorithms.
|
||||
"""
|
||||
if self.is_dsl_algorithm():
|
||||
raise TypeError("set_message_size_range is only supported for native algorithms")
|
||||
self._algorithm.set_message_size_range(min_message_size, max_message_size)
|
||||
|
||||
@cached_property
|
||||
def tags(self) -> Dict[str, int]:
|
||||
"""Dictionary of tag names to tag values for algorithm selection hints."""
|
||||
return self._algorithm.tags
|
||||
|
||||
@cached_property
|
||||
def buffer_mode(self) -> CollectiveBufferMode:
|
||||
def buffer_mode(self) -> CppCollectiveBufferMode:
|
||||
"""The buffer mode supported by this algorithm (IN_PLACE, OUT_OF_PLACE, or ANY)."""
|
||||
return self._algorithm.buffer_mode
|
||||
|
||||
@@ -131,7 +148,7 @@ class Algorithm:
|
||||
Returns:
|
||||
True if this algorithm is defined using DSL/execution plan, False otherwise.
|
||||
"""
|
||||
if self._algorithm.type == _AlgorithmType.DSL:
|
||||
if self._algorithm.type == CppAlgorithmType.DSL:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -141,24 +158,26 @@ class Algorithm:
|
||||
Returns:
|
||||
True if this algorithm is implemented natively, False otherwise.
|
||||
"""
|
||||
if self._algorithm.type == _AlgorithmType.NATIVE:
|
||||
if self._algorithm.type == CppAlgorithmType.NATIVE:
|
||||
return True
|
||||
return False
|
||||
|
||||
def execute(
|
||||
self,
|
||||
comm: Communicator,
|
||||
comm: CppCommunicator,
|
||||
input_buffer: int,
|
||||
output_buffer: int,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
dtype: DataType,
|
||||
op: ReduceOp = ReduceOp.NOP,
|
||||
dtype: CppDataType,
|
||||
op: CppReduceOp = CppReduceOp.NOP,
|
||||
stream: int = 0,
|
||||
executor: Optional[Executor] = None,
|
||||
executor: Optional[CppExecutor] = None,
|
||||
nblocks=0,
|
||||
nthreads_per_block=0,
|
||||
symmetric_memory: bool = False,
|
||||
extras: Optional[Dict[str, int]] = None,
|
||||
accum_dtype: Optional[CppDataType] = None,
|
||||
) -> int:
|
||||
"""Execute the collective algorithm.
|
||||
|
||||
@@ -174,11 +193,16 @@ class Algorithm:
|
||||
executor: The executor for DSL algorithms (required for DSL, optional for native).
|
||||
nblocks: Number of CUDA blocks (0 for auto-selection).
|
||||
nthreads_per_block: Number of threads per block (0 for auto-selection).
|
||||
symmetric_memory: Whether to use symmetric memory optimization (default: False).
|
||||
extras: Additional algorithm-specific parameters.
|
||||
accum_dtype: Data type for accumulation during reduction. If None, defaults to
|
||||
the same as dtype. Use DataType.float32 for high-precision FP8 accumulation.
|
||||
|
||||
Returns:
|
||||
The result code (0 for success).
|
||||
"""
|
||||
merged_extras = dict(extras) if extras is not None else {}
|
||||
accum_dtype = accum_dtype if accum_dtype is not None else dtype
|
||||
return self._algorithm.execute(
|
||||
comm,
|
||||
int(input_buffer),
|
||||
@@ -191,12 +215,18 @@ class Algorithm:
|
||||
executor,
|
||||
nblocks,
|
||||
nthreads_per_block,
|
||||
extras if extras is not None else {},
|
||||
symmetric_memory,
|
||||
merged_extras,
|
||||
int(accum_dtype),
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
"""Reset the internal state of the algorithm, if applicable."""
|
||||
self._algorithm.reset()
|
||||
|
||||
|
||||
class AlgorithmBuilder:
|
||||
def __init__(self, algorithm_builder: _AlgorithmBuilder):
|
||||
def __init__(self, algorithm_builder: CppAlgorithmBuilder):
|
||||
self._algorithm_builder = algorithm_builder
|
||||
|
||||
def build(self) -> Algorithm:
|
||||
@@ -204,7 +234,7 @@ class AlgorithmBuilder:
|
||||
|
||||
|
||||
class AlgorithmCollection:
|
||||
def __init__(self, native_collection: _AlgorithmCollection):
|
||||
def __init__(self, native_collection: CppAlgorithmCollection):
|
||||
self._native_collection = native_collection
|
||||
self._algorithms = [Algorithm.create_from_native_handle(algo) for algo in self._native_collection.to_list()]
|
||||
|
||||
@@ -228,3 +258,24 @@ class AlgorithmCollection:
|
||||
"""Register an algorithm for a collective operation."""
|
||||
self._native_collection.register_algorithm(collective, algo_name, algorithm._algorithm)
|
||||
self._algorithms.append(algorithm)
|
||||
|
||||
|
||||
_flag_buffer_cache = None
|
||||
|
||||
|
||||
def get_flag_buffer() -> cp.ndarray:
|
||||
"""Get the default flag buffer for algorithm selection.
|
||||
|
||||
This buffer is used internally by default algorithms to store selection flags.
|
||||
It is allocated as a shared GPU buffer and can be accessed from Python.
|
||||
The result is cached so all callers share the same buffer.
|
||||
|
||||
Returns:
|
||||
A CuPy array representing the flag buffer on the GPU.
|
||||
"""
|
||||
global _flag_buffer_cache
|
||||
if _flag_buffer_cache is None:
|
||||
buffer_ptr, buffer_size, owner = cpp_get_flag_buffer()
|
||||
memptr = cp.cuda.MemoryPointer(cp.cuda.UnownedMemory(buffer_ptr, buffer_size, owner), 0)
|
||||
_flag_buffer_cache = cp.ndarray((buffer_size // 4,), dtype=cp.uint32, memptr=memptr)
|
||||
return _flag_buffer_cache
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Union, Tuple
|
||||
|
||||
import cupy as cp
|
||||
import numpy as np
|
||||
from mscclpp._mscclpp import RawGpuBuffer
|
||||
from mscclpp._mscclpp import CppRawGpuBuffer
|
||||
|
||||
__all__ = ["GpuBuffer"]
|
||||
|
||||
@@ -25,6 +25,6 @@ class GpuBuffer(cp.ndarray):
|
||||
if any(s <= 0 for s in shape):
|
||||
raise ValueError("Shape must be positive.")
|
||||
# Create the buffer
|
||||
buffer = RawGpuBuffer(np.prod(shape) * np.dtype(dtype).itemsize)
|
||||
buffer = CppRawGpuBuffer(np.prod(shape) * np.dtype(dtype).itemsize)
|
||||
memptr = cp.cuda.MemoryPointer(cp.cuda.UnownedMemory(buffer.data(), buffer.bytes(), buffer), 0)
|
||||
return cp.ndarray(shape, dtype=dtype, strides=strides, order=order, memptr=memptr)
|
||||
|
||||
@@ -6,21 +6,21 @@ from typing import Type
|
||||
|
||||
import cupy as cp
|
||||
from mscclpp._mscclpp import (
|
||||
Communicator,
|
||||
Connection,
|
||||
CppCommunicator,
|
||||
CppConnection,
|
||||
connect_nvls_collective,
|
||||
EndpointConfig,
|
||||
Semaphore,
|
||||
ProxyService,
|
||||
RegisteredMemory,
|
||||
PortChannel,
|
||||
MemoryChannel,
|
||||
TcpBootstrap,
|
||||
Transport,
|
||||
TransportFlags,
|
||||
CppEndpointConfig,
|
||||
CppSemaphore,
|
||||
CppProxyService,
|
||||
CppRegisteredMemory,
|
||||
CppPortChannel,
|
||||
CppMemoryChannel,
|
||||
CppTcpBootstrap,
|
||||
CppTransport,
|
||||
CppTransportFlags,
|
||||
)
|
||||
import mpi4py
|
||||
import numpy as np
|
||||
import pickle
|
||||
|
||||
from mscclpp.utils import is_torch_tensor
|
||||
|
||||
@@ -29,27 +29,47 @@ __all__ = ["CommGroup"]
|
||||
|
||||
class CommGroup:
|
||||
def __init__(
|
||||
self, mpi_comm: mpi4py.MPI.Comm = None, interfaceIpPortTrio: str = "", rank: int = None, size: int = None
|
||||
self,
|
||||
mpi_comm: "mpi4py.MPI.Comm" = None,
|
||||
torch_group: "dist.ProcessGroup" = None,
|
||||
interfaceIpPortTrio: str = "",
|
||||
rank: int = None,
|
||||
size: int = None,
|
||||
):
|
||||
if interfaceIpPortTrio == "":
|
||||
self.bootstrap = TcpBootstrap.create(mpi_comm.rank, mpi_comm.size)
|
||||
if interfaceIpPortTrio == "" and (mpi_comm is not None or torch_group is not None):
|
||||
uniq_id = None
|
||||
if mpi_comm.rank == 0:
|
||||
# similar to NCCL's unique id
|
||||
rank, size = (
|
||||
(mpi_comm.Get_rank(), mpi_comm.Get_size())
|
||||
if mpi_comm is not None
|
||||
else (torch_group.rank(), torch_group.size())
|
||||
)
|
||||
self.bootstrap = CppTcpBootstrap.create(rank, size)
|
||||
if rank == 0:
|
||||
uniq_id = self.bootstrap.create_unique_id()
|
||||
uniq_id_global = mpi_comm.bcast(uniq_id, 0)
|
||||
if mpi_comm is not None:
|
||||
import mpi4py
|
||||
|
||||
uniq_id_global = mpi_comm.bcast(uniq_id, 0)
|
||||
else:
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
if rank == 0:
|
||||
uniq_id_global = uniq_id
|
||||
pickled_data = pickle.dumps(uniq_id)
|
||||
data_tensor = torch.frombuffer(bytearray(pickled_data), dtype=torch.uint8).clone()
|
||||
else:
|
||||
data_tensor = torch.zeros(256, dtype=torch.uint8)
|
||||
dist.broadcast(data_tensor, src=0, group=torch_group)
|
||||
uniq_id_global = pickle.loads(data_tensor.numpy().tobytes())
|
||||
self.bootstrap.initialize(uniq_id_global)
|
||||
elif mpi_comm:
|
||||
# use this instead
|
||||
self.bootstrap = TcpBootstrap.create(mpi_comm.rank, mpi_comm.size)
|
||||
self.bootstrap.initialize(interfaceIpPortTrio)
|
||||
elif not interfaceIpPortTrio == "":
|
||||
assert rank >= 0 and size >= 1
|
||||
self.bootstrap = TcpBootstrap.create(rank, size)
|
||||
self.bootstrap = CppTcpBootstrap.create(rank, size)
|
||||
self.bootstrap.initialize(interfaceIpPortTrio)
|
||||
else:
|
||||
raise RuntimeError("Either the interface or mpi_group need to be specified")
|
||||
self.communicator = Communicator(self.bootstrap)
|
||||
self.communicator = CppCommunicator(self.bootstrap)
|
||||
self.my_rank = self.bootstrap.get_rank()
|
||||
self.nranks = self.bootstrap.get_n_ranks()
|
||||
self.nranks_per_node = self.bootstrap.get_n_ranks_per_node()
|
||||
@@ -63,43 +83,43 @@ class CommGroup:
|
||||
def recv(self, tensor: np.ndarray, peer: int, tag: int):
|
||||
self.bootstrap.recv(tensor.ctypes.data, tensor.size * tensor.itemsize, peer, tag)
|
||||
|
||||
def my_ib_device(self, local_rank: int) -> Transport:
|
||||
def my_ib_device(self, local_rank: int) -> CppTransport:
|
||||
if local_rank == 0:
|
||||
return Transport.IB0
|
||||
return CppTransport.IB0
|
||||
if local_rank == 1:
|
||||
return Transport.IB1
|
||||
return CppTransport.IB1
|
||||
if local_rank == 2:
|
||||
return Transport.IB2
|
||||
return CppTransport.IB2
|
||||
if local_rank == 3:
|
||||
return Transport.IB3
|
||||
return CppTransport.IB3
|
||||
if local_rank == 4:
|
||||
return Transport.IB4
|
||||
return CppTransport.IB4
|
||||
if local_rank == 5:
|
||||
return Transport.IB5
|
||||
return CppTransport.IB5
|
||||
if local_rank == 6:
|
||||
return Transport.IB6
|
||||
return CppTransport.IB6
|
||||
if local_rank == 7:
|
||||
return Transport.IB7
|
||||
return CppTransport.IB7
|
||||
else:
|
||||
assert False # only 8 IBs are supported
|
||||
|
||||
def make_connection(
|
||||
self,
|
||||
all_ranks: list[int],
|
||||
endpoints: EndpointConfig | Transport | dict[int, EndpointConfig] | dict[int, Transport],
|
||||
endpoints: CppEndpointConfig | CppTransport | dict[int, CppEndpointConfig] | dict[int, CppTransport],
|
||||
use_switch: bool = False,
|
||||
) -> dict[int, Connection]:
|
||||
if type(endpoints) is Transport:
|
||||
endpoints = EndpointConfig(endpoints)
|
||||
) -> dict[int, CppConnection]:
|
||||
if type(endpoints) is CppTransport:
|
||||
endpoints = CppEndpointConfig(endpoints)
|
||||
elif type(endpoints) is dict:
|
||||
endpoints = {k: EndpointConfig(v) if type(v) is Transport else v for k, v in endpoints.items()}
|
||||
endpoints = {k: CppEndpointConfig(v) if type(v) is CppTransport else v for k, v in endpoints.items()}
|
||||
connections = {}
|
||||
for rank in all_ranks:
|
||||
if type(endpoints) is dict:
|
||||
endpoint = endpoints[rank]
|
||||
else:
|
||||
endpoint = endpoints
|
||||
if endpoint.transport == Transport.CudaIpc and use_switch:
|
||||
if endpoint.transport == CppTransport.CudaIpc and use_switch:
|
||||
return connect_nvls_collective(self.communicator, all_ranks, 2**30)
|
||||
else:
|
||||
connections[rank] = self.communicator.connect(endpoint, rank)
|
||||
@@ -107,8 +127,8 @@ class CommGroup:
|
||||
return connections
|
||||
|
||||
def register_tensor_with_connections(
|
||||
self, tensor: Type[cp.ndarray] | Type[np.ndarray], connections: dict[int, Connection]
|
||||
) -> dict[int, RegisteredMemory]:
|
||||
self, tensor: Type[cp.ndarray] | Type[np.ndarray], connections: dict[int, CppConnection]
|
||||
) -> dict[int, CppRegisteredMemory]:
|
||||
local_reg_memory = self.register_local_memory(tensor, connections)
|
||||
all_registered_memories = {}
|
||||
all_registered_memories[self.my_rank] = local_reg_memory
|
||||
@@ -121,8 +141,8 @@ class CommGroup:
|
||||
return all_registered_memories
|
||||
|
||||
def _register_memory_with_connections(
|
||||
self, memory: RegisteredMemory, connections: dict[int, Connection]
|
||||
) -> dict[int, RegisteredMemory]:
|
||||
self, memory: CppRegisteredMemory, connections: dict[int, CppConnection]
|
||||
) -> dict[int, CppRegisteredMemory]:
|
||||
all_registered_memories = {}
|
||||
all_registered_memories[self.my_rank] = memory
|
||||
future_memories = {}
|
||||
@@ -133,18 +153,20 @@ class CommGroup:
|
||||
all_registered_memories[rank] = future_memories[rank].get()
|
||||
return all_registered_memories
|
||||
|
||||
def make_semaphores(self, connections: dict[int, Connection]) -> dict[int, Semaphore]:
|
||||
def make_semaphores(self, connections: dict[int, CppConnection]) -> dict[int, CppSemaphore]:
|
||||
future_semaphores = {}
|
||||
for rank in connections:
|
||||
future_semaphores[rank] = self.communicator.build_semaphore(connections[rank], rank)
|
||||
return {rank: future.get() for rank, future in future_semaphores.items()}
|
||||
|
||||
def make_memory_channels(self, tensor: cp.ndarray, connections: dict[int, Connection]) -> dict[int, MemoryChannel]:
|
||||
def make_memory_channels(
|
||||
self, tensor: cp.ndarray, connections: dict[int, CppConnection]
|
||||
) -> dict[int, CppMemoryChannel]:
|
||||
semaphores = self.make_semaphores(connections)
|
||||
registered_memories = self.register_tensor_with_connections(tensor, connections)
|
||||
channels = {}
|
||||
for rank in connections:
|
||||
channels[rank] = MemoryChannel(
|
||||
channels[rank] = CppMemoryChannel(
|
||||
semaphores[rank], registered_memories[rank], registered_memories[self.my_rank]
|
||||
)
|
||||
return channels
|
||||
@@ -152,9 +174,9 @@ class CommGroup:
|
||||
def make_memory_channels_with_scratch(
|
||||
self,
|
||||
tensor: cp.ndarray,
|
||||
registeredScratchBuffer: RegisteredMemory,
|
||||
connections: dict[int, Connection],
|
||||
) -> dict[int, MemoryChannel]:
|
||||
registeredScratchBuffer: CppRegisteredMemory,
|
||||
connections: dict[int, CppConnection],
|
||||
) -> dict[int, CppMemoryChannel]:
|
||||
semaphores = self.make_semaphores(connections)
|
||||
registered_memories = self._register_memory_with_connections(registeredScratchBuffer, connections)
|
||||
channels = {}
|
||||
@@ -162,17 +184,17 @@ class CommGroup:
|
||||
tensor_size = (
|
||||
tensor.numel() * tensor.element_size() if is_torch_tensor(tensor) else tensor.size * tensor.itemsize
|
||||
)
|
||||
local_registered_memory = self.communicator.register_memory(tensor_data_ptr, tensor_size, TransportFlags())
|
||||
local_registered_memory = self.communicator.register_memory(tensor_data_ptr, tensor_size, CppTransportFlags())
|
||||
scratch_data_ptr = registeredScratchBuffer.data()
|
||||
for rank in connections:
|
||||
channels[rank] = MemoryChannel(
|
||||
channels[rank] = CppMemoryChannel(
|
||||
semaphores[rank], registered_memories[rank], local_registered_memory, scratch_data_ptr
|
||||
)
|
||||
return channels
|
||||
|
||||
def make_port_channels(
|
||||
self, proxy_service: ProxyService, tensor: cp.ndarray, connections: dict[int, Connection]
|
||||
) -> dict[int, PortChannel]:
|
||||
self, proxy_service: CppProxyService, tensor: cp.ndarray, connections: dict[int, CppConnection]
|
||||
) -> dict[int, CppPortChannel]:
|
||||
semaphores = self.make_semaphores(connections)
|
||||
registered_memories = self.register_tensor_with_connections(tensor, connections)
|
||||
memory_ids = {}
|
||||
@@ -188,12 +210,12 @@ class CommGroup:
|
||||
|
||||
def make_port_channels_with_scratch(
|
||||
self,
|
||||
proxy_service: ProxyService,
|
||||
proxy_service: CppProxyService,
|
||||
tensor: cp.ndarray,
|
||||
registeredScratchBuffer: RegisteredMemory,
|
||||
connections: dict[int, Connection],
|
||||
) -> dict[int, PortChannel]:
|
||||
transport_flags = TransportFlags()
|
||||
registeredScratchBuffer: CppRegisteredMemory,
|
||||
connections: dict[int, CppConnection],
|
||||
) -> dict[int, CppPortChannel]:
|
||||
transport_flags = CppTransportFlags()
|
||||
for rank in connections:
|
||||
transport_flags |= connections[rank].transport()
|
||||
data_ptr = (
|
||||
@@ -223,8 +245,8 @@ class CommGroup:
|
||||
return channels
|
||||
|
||||
def register_semaphore_with_proxy(
|
||||
self, proxy_service: ProxyService, connections: dict[int, Connection]
|
||||
) -> dict[int, PortChannel]:
|
||||
self, proxy_service: CppProxyService, connections: dict[int, CppConnection]
|
||||
) -> dict[int, CppPortChannel]:
|
||||
semaphores = self.make_semaphores(connections)
|
||||
semaphore_ids = {}
|
||||
for rank in semaphores:
|
||||
@@ -235,7 +257,7 @@ class CommGroup:
|
||||
return channels
|
||||
|
||||
def register_memory_with_proxy(
|
||||
self, proxy_service: ProxyService, tensor: cp.ndarray, connections: dict[int, Connection]
|
||||
self, proxy_service: CppProxyService, tensor: cp.ndarray, connections: dict[int, CppConnection]
|
||||
) -> dict[int, int]:
|
||||
registered_memories = self.register_tensor_with_connections(tensor, connections)
|
||||
memory_ids = {}
|
||||
@@ -243,8 +265,8 @@ class CommGroup:
|
||||
memory_ids[rank] = proxy_service.add_memory(registered_memories[rank])
|
||||
return memory_ids
|
||||
|
||||
def register_local_memory(self, tensor: cp.ndarray, connections: dict[int, Connection]) -> RegisteredMemory:
|
||||
transport_flags = TransportFlags()
|
||||
def register_local_memory(self, tensor: cp.ndarray, connections: dict[int, CppConnection]) -> CppRegisteredMemory:
|
||||
transport_flags = CppTransportFlags()
|
||||
for rank in connections:
|
||||
transport_flags |= connections[rank].transport()
|
||||
data_ptr = (
|
||||
|
||||
@@ -26,9 +26,7 @@ from mscclpp.language.program import CollectiveProgram
|
||||
from mscclpp.language.utils import AlgoSpec
|
||||
from mscclpp.utils import get_device_arch
|
||||
|
||||
from mscclpp._mscclpp import (
|
||||
ExecutionPlan,
|
||||
)
|
||||
from mscclpp._mscclpp import CppExecutionPlan, env
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
@@ -51,7 +49,7 @@ class DslCompiler:
|
||||
into execution plans that can be run on GPUs. The compiled plans are cached
|
||||
to disk for reuse.
|
||||
|
||||
The cache location can be configured via the `MSCCLPP_EXECUTION_PLAN_DIR`
|
||||
The cache location can be configured via the `MSCCLPP_CACHE_DIR`
|
||||
environment variable (defaults to `~/.cache/mscclpp`).
|
||||
|
||||
Example:
|
||||
@@ -138,7 +136,7 @@ class DslCompiler:
|
||||
)
|
||||
).hexdigest()
|
||||
|
||||
plan_dir = os.environ.get("MSCCLPP_EXECUTION_PLAN_DIR", Path.home() / ".cache/mscclpp")
|
||||
plan_dir = Path(env().cache_dir)
|
||||
os.makedirs(plan_dir, exist_ok=True)
|
||||
filename = f"{plan_id}.json"
|
||||
plan_path = os.path.join(plan_dir, filename)
|
||||
@@ -157,7 +155,7 @@ class DslCompiler:
|
||||
os.remove(tmp_path)
|
||||
except Exception:
|
||||
Path(plan_path).unlink(missing_ok=True)
|
||||
execution_plan = ExecutionPlan(plan_path, rank)
|
||||
execution_plan = CppExecutionPlan(plan_path, rank)
|
||||
return Algorithm(
|
||||
id=plan_id,
|
||||
execution_plan=execution_plan,
|
||||
@@ -179,8 +177,8 @@ class NativeCodeCompiler:
|
||||
based on the runtime environment. Compiled modules are cached to avoid
|
||||
recompilation.
|
||||
|
||||
The cache location can be configured via the `MSCCLPP_NATIVE_CACHE_DIR`
|
||||
environment variable (defaults to `~/.cache/mscclpp/native`).
|
||||
The cache location can be configured via the `MSCCLPP_CACHE_DIR`
|
||||
environment variable (defaults to `~/.cache/mscclpp`).
|
||||
|
||||
Attributes:
|
||||
_is_hip: True if running on AMD/ROCm, False for NVIDIA/CUDA.
|
||||
@@ -226,8 +224,7 @@ class NativeCodeCompiler:
|
||||
"-L" + os.path.join(self._lib_home, "lib"),
|
||||
"-lmscclpp",
|
||||
]
|
||||
cache_root = os.environ.get("MSCCLPP_NATIVE_CACHE_DIR", Path.home() / ".cache/mscclpp/native")
|
||||
self._cache_dir = Path(cache_root)
|
||||
self._cache_dir = Path(env().cache_dir) / "native"
|
||||
self._cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _get_compiler(self) -> str:
|
||||
@@ -283,7 +280,7 @@ class NativeCodeCompiler:
|
||||
Note:
|
||||
- The source file should include pybind11 bindings to expose functions.
|
||||
- MSCCLPP headers are automatically included in the compilation.
|
||||
- The module is cached in `MSCCLPP_NATIVE_CACHE_DIR` (default: ~/.cache/mscclpp/native).
|
||||
- The module is cached in `MSCCLPP_CACHE_DIR` (default: ~/.cache/mscclpp).
|
||||
- File locking is used to prevent race conditions during parallel compilation.
|
||||
|
||||
Example:
|
||||
|
||||
@@ -3,12 +3,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Union
|
||||
from mscclpp._core.algorithm import Algorithm, AlgorithmBuilder, AlgorithmCollection
|
||||
from mscclpp._core.algorithm import Algorithm, AlgorithmBuilder, AlgorithmCollection, get_flag_buffer
|
||||
import atexit
|
||||
|
||||
from mscclpp._mscclpp import (
|
||||
AlgorithmCollectionBuilder as _AlgorithmCollectionBuilder,
|
||||
)
|
||||
from mscclpp._mscclpp import CppAlgorithmCollectionBuilder
|
||||
|
||||
__all__ = ["AlgorithmCollectionBuilder"]
|
||||
|
||||
@@ -24,13 +22,14 @@ class AlgorithmCollectionBuilder:
|
||||
@classmethod
|
||||
def reset(cls):
|
||||
if cls._instance is not None:
|
||||
_AlgorithmCollectionBuilder.reset()
|
||||
CppAlgorithmCollectionBuilder.reset()
|
||||
cls._instance = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, "_initialized"):
|
||||
self._builder = _AlgorithmCollectionBuilder.get_instance()
|
||||
self._builder = CppAlgorithmCollectionBuilder.get_instance()
|
||||
self._initialized = True
|
||||
self._flag_buffer = None
|
||||
|
||||
def add_algorithm_builder(self, algorithm_builder: Union[AlgorithmBuilder, Algorithm]):
|
||||
if isinstance(algorithm_builder, AlgorithmBuilder):
|
||||
@@ -52,8 +51,17 @@ class AlgorithmCollectionBuilder:
|
||||
collection = self._builder.build()
|
||||
return AlgorithmCollection(collection)
|
||||
|
||||
def build_default_algorithms(self, scratch_buffer: int, scratch_buffer_size: int, rank: int) -> AlgorithmCollection:
|
||||
native_collection = self._builder.build_default_algorithms(int(scratch_buffer), scratch_buffer_size, rank)
|
||||
def build_default_algorithms(
|
||||
self,
|
||||
scratch_buffer: int,
|
||||
scratch_buffer_size: int,
|
||||
rank: int,
|
||||
) -> AlgorithmCollection:
|
||||
if self._flag_buffer is None:
|
||||
self._flag_buffer = get_flag_buffer()
|
||||
native_collection = self._builder.build_default_algorithms(
|
||||
int(scratch_buffer), scratch_buffer_size, self._flag_buffer.data.ptr, self._flag_buffer.nbytes, rank
|
||||
)
|
||||
return AlgorithmCollection(native_collection)
|
||||
|
||||
|
||||
|
||||
@@ -140,7 +140,7 @@ class MemoryChannel:
|
||||
|
||||
for tb_id in tb_list:
|
||||
tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type)
|
||||
tb_channel_ids = get_program().setup_channel(tb, self)
|
||||
tb_channel_ids = get_program().setup_channel(tb_id, self)
|
||||
op = GetOperation(
|
||||
src_buff=[RemoteChunk(src_chunk.buffer, src_chunk.index, src_chunk.size, tb_chunk_id)],
|
||||
dst_buff=[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)],
|
||||
|
||||
@@ -534,6 +534,7 @@ class PutOperation(BaseOperation):
|
||||
self.dst_buff = dst_buff
|
||||
self.channel_ids = channel_ids
|
||||
self.channel_type = channel_type
|
||||
self.from_packet = from_packet
|
||||
self.to_packet = to_packet
|
||||
self.with_signal = with_signal
|
||||
self.with_signal_and_flush = with_signal_and_flush
|
||||
@@ -579,6 +580,25 @@ class PutOperation(BaseOperation):
|
||||
with_signal=self.with_signal,
|
||||
with_signal_and_flush=self.with_signal_and_flush,
|
||||
)
|
||||
elif (
|
||||
isinstance(other, PutOperation)
|
||||
and self.name == Instruction.read_put_packet
|
||||
and self.name == other.name
|
||||
and self.src_buff == other.src_buff
|
||||
and self.channel_type == other.channel_type
|
||||
and self.tbg_info == other.tbg_info
|
||||
):
|
||||
fused_operation = PutOperation(
|
||||
src_buff=self.src_buff,
|
||||
dst_buff=self.dst_buff + other.dst_buff,
|
||||
channel_ids=self.channel_ids + other.channel_ids,
|
||||
channel_type=self.channel_type,
|
||||
tbg_info=self.tbg_info,
|
||||
from_packet=self.from_packet,
|
||||
to_packet=self.to_packet,
|
||||
with_signal=self.with_signal,
|
||||
with_signal_and_flush=self.with_signal_and_flush,
|
||||
)
|
||||
|
||||
return fused_operation
|
||||
|
||||
@@ -725,7 +745,7 @@ class ReduceOperation(BaseOperation):
|
||||
remote_dst_buff=self.remote_dst_buff + other.dst_buff,
|
||||
channel_ids=self.channel_ids,
|
||||
put_channel_ids=self.put_channel_ids + other.channel_ids,
|
||||
channel_type=self.channel_type,
|
||||
channel_type=other.channel_type,
|
||||
reduce_operation=self.reduce_operation,
|
||||
tbg_info=self.tbg_info,
|
||||
packet=self.packet,
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import argparse
|
||||
from mscclpp.language.channel import *
|
||||
from mscclpp.language.rank import *
|
||||
from mscclpp.language.general import *
|
||||
from mscclpp.language.program import *
|
||||
from mscclpp.language.collectives import *
|
||||
|
||||
|
||||
def allgather_example(name, gpu_size, num_threads_per_block, min_message_size, max_message_size):
|
||||
chunksperloop = 1
|
||||
collective = AllGather(gpu_size, chunksperloop, True)
|
||||
with CollectiveProgram(
|
||||
name,
|
||||
collective,
|
||||
gpu_size,
|
||||
protocol="LL",
|
||||
num_threads_per_block=num_threads_per_block,
|
||||
use_double_scratch_buffer=True,
|
||||
min_message_size=min_message_size,
|
||||
max_message_size=max_message_size,
|
||||
):
|
||||
# Creating Scratch Buffers
|
||||
scratch_buffer = []
|
||||
for gpu in range(gpu_size):
|
||||
scratch_buffer.append(Buffer(gpu, 2 * gpu_size))
|
||||
|
||||
# Copying it to scratch buffer
|
||||
for gpu in range(gpu_size):
|
||||
rank = Rank(gpu)
|
||||
scratch_offset = gpu_size
|
||||
input_buffer = rank.get_input_buffer()
|
||||
rank.copy_packets(
|
||||
scratch_buffer[gpu][scratch_offset + gpu : scratch_offset + gpu + 1], input_buffer[0:1], tb=0
|
||||
)
|
||||
|
||||
# Putting packets in the remote scratch buffer
|
||||
for gpu in range(gpu_size):
|
||||
rank = Rank(gpu)
|
||||
output_buffer = rank.get_output_buffer()
|
||||
for peer in range(1, gpu_size):
|
||||
dst_rank = (gpu + peer) % gpu_size
|
||||
ch = MemoryChannel(dst_rank, gpu)
|
||||
tb = 0
|
||||
ch.read_put_packets(
|
||||
scratch_buffer[dst_rank][gpu : gpu + 1],
|
||||
scratch_buffer[gpu][scratch_offset + gpu : scratch_offset + gpu + 1],
|
||||
tb,
|
||||
)
|
||||
|
||||
# Copying packets from local scratch buffer to local buffer
|
||||
for gpu in range(gpu_size):
|
||||
rank = Rank(gpu)
|
||||
output_buffer = rank.get_output_buffer()
|
||||
for peer in range(1, gpu_size):
|
||||
dst_rank = (gpu + peer) % gpu_size
|
||||
rank.unpack_packets(
|
||||
output_buffer[dst_rank : dst_rank + 1],
|
||||
scratch_buffer[gpu][dst_rank : dst_rank + 1],
|
||||
tb=0,
|
||||
)
|
||||
|
||||
print(JSON())
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--name", type=str, help="name of the program")
|
||||
parser.add_argument("--num_gpus", type=int, help="number of gpus")
|
||||
parser.add_argument("--num_threads_per_block", type=int, default=1024, help="number of threads per block")
|
||||
parser.add_argument("--min_message_size", type=int, default=0, help="minimum message size")
|
||||
parser.add_argument("--max_message_size", type=int, default=2**64 - 1, help="maximum message size")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
allgather_example(args.name, args.num_gpus, args.num_threads_per_block, args.min_message_size, args.max_message_size)
|
||||
@@ -11,7 +11,7 @@ from typing import Any, Type, Union
|
||||
import cupy as cp
|
||||
import numpy as np
|
||||
|
||||
from mscclpp._mscclpp import DataType
|
||||
from mscclpp._mscclpp import CppDataType as DataType
|
||||
|
||||
try:
|
||||
import torch
|
||||
@@ -192,5 +192,13 @@ def torch_dtype_to_mscclpp_dtype(dtype: "torch.dtype") -> DataType:
|
||||
return DataType.int32
|
||||
elif dtype == torch.bfloat16:
|
||||
return DataType.bfloat16
|
||||
# Hardware supports either OCP format or FNUZ format for float8.
|
||||
# Mapping both to the same MSCClPP data type.
|
||||
elif dtype == torch.float8_e5m2 or dtype == torch.float8_e5m2fnuz:
|
||||
return DataType.float8_e5m2
|
||||
elif dtype == torch.float8_e4m3fn or dtype == torch.float8_e4m3fnuz:
|
||||
return DataType.float8_e4m3
|
||||
elif dtype == torch.uint8:
|
||||
return DataType.uint8
|
||||
else:
|
||||
raise ValueError(f"Unknown data type: {dtype}")
|
||||
|
||||
@@ -6,4 +6,5 @@ pytest
|
||||
numpy
|
||||
matplotlib
|
||||
sortedcontainers @ git+https://github.com/grantjenks/python-sortedcontainers.git@3ac358631f58c1347f1d6d2d92784117db0f38ed
|
||||
blake3
|
||||
blake3
|
||||
pybind11
|
||||
@@ -0,0 +1,10 @@
|
||||
mpi4py
|
||||
cupy
|
||||
prettytable
|
||||
netifaces
|
||||
pytest
|
||||
numpy
|
||||
matplotlib
|
||||
sortedcontainers @ git+https://github.com/grantjenks/python-sortedcontainers.git@3ac358631f58c1347f1d6d2d92784117db0f38ed
|
||||
blake3
|
||||
pybind11
|
||||
@@ -63,10 +63,13 @@ class MyProxyService {
|
||||
};
|
||||
|
||||
NB_MODULE(_ext, m) {
|
||||
#ifdef MSCCLPP_DISABLE_NB_LEAK_WARNINGS
|
||||
nb::set_leak_warnings(false);
|
||||
#endif
|
||||
nb::class_<MyProxyService>(m, "MyProxyService")
|
||||
.def(nb::init<int, int, int, nb::list, nb::list>(), nb::arg("rank"), nb::arg("nranks"), nb::arg("data_size"),
|
||||
nb::arg("reg_mem_list"), nb::arg("sem_list"))
|
||||
.def("fifo_device_handle", &MyProxyService::fifoDeviceHandle)
|
||||
.def("start", &MyProxyService::start)
|
||||
.def("stop", &MyProxyService::stop);
|
||||
}
|
||||
}
|
||||
@@ -11,7 +11,7 @@ from mscclpp import (
|
||||
env,
|
||||
)
|
||||
from mscclpp import CommGroup, GpuBuffer
|
||||
from mscclpp.utils import KernelBuilder, GpuBuffer, pack
|
||||
from mscclpp.utils import KernelBuilder, pack
|
||||
import os
|
||||
import struct
|
||||
|
||||
|
||||
397
python/test/test_fp8_accum.py
Normal file
397
python/test/test_fp8_accum.py
Normal file
@@ -0,0 +1,397 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# Correctness test for FP8 allreduce with different accumulation types.
|
||||
#
|
||||
# Verifies that FP8 allreduce with higher-precision accumulation produces
|
||||
# results at least as accurate as native FP8 accumulation, by comparing
|
||||
# against a float32 reference.
|
||||
#
|
||||
# Usage:
|
||||
# mpirun -np 8 pytest python/test/test_fp8_accum.py -v
|
||||
|
||||
import cupy as cp
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mscclpp import CommGroup, GpuBuffer, DataType, ReduceOp, is_nvls_supported
|
||||
from mscclpp.ext import AlgorithmCollectionBuilder
|
||||
from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group
|
||||
|
||||
# FP8 E4M3 (hardware) requires SM >= 89 (Ada / Hopper) on NVIDIA GPUs.
|
||||
# On AMD/ROCm (e.g. MI300X), FP8 is supported natively — no skip needed.
|
||||
_is_hip = hasattr(cp.cuda.runtime, "is_hip") and cp.cuda.runtime.is_hip
|
||||
_skip_fp8 = not _is_hip and int(cp.cuda.Device().compute_capability) < 89
|
||||
pytestmark = pytest.mark.skipif(_skip_fp8, reason="FP8 accum tests require SM >= 89 on CUDA")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FP8 E4M3FN helpers (bias=7, no infinity, NaN = exp=15 & mant=7)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def e4m3fn_to_float(uint8_array):
|
||||
"""Decode a cupy uint8 array of E4M3FN bit patterns to float32."""
|
||||
bits = uint8_array.astype(cp.int32)
|
||||
sign = (bits >> 7) & 1
|
||||
exp = (bits >> 3) & 0xF
|
||||
mant = bits & 0x7
|
||||
|
||||
# Normal: (-1)^s * 2^(exp-7) * (1 + mant/8)
|
||||
normal_val = cp.ldexp(cp.float32(1.0) + mant.astype(cp.float32) / cp.float32(8.0), (exp - 7).astype(cp.int32))
|
||||
# Subnormal (exp==0): (-1)^s * 2^(-6) * (mant/8)
|
||||
subnormal_val = cp.ldexp(mant.astype(cp.float32) / cp.float32(8.0), cp.int32(-6))
|
||||
|
||||
result = cp.where(exp == 0, subnormal_val, normal_val)
|
||||
result = cp.where(sign == 1, -result, result)
|
||||
# Zero
|
||||
result = cp.where((exp == 0) & (mant == 0), cp.float32(0.0), result)
|
||||
# NaN: exp==15 & mant==7
|
||||
nan_mask = (exp == 15) & (mant == 7)
|
||||
result = cp.where(nan_mask, cp.float32(float("nan")), result)
|
||||
return result
|
||||
|
||||
|
||||
def float_to_e4m3fn(f32_array, chunk_size=65536):
|
||||
"""Encode a cupy float32 array to uint8 E4M3FN bit patterns.
|
||||
|
||||
Uses a lookup-table approach: precompute all 128 positive E4M3FN values,
|
||||
then find nearest match per element via chunked broadcast comparison.
|
||||
"""
|
||||
# Build lookup table of all 128 positive E4M3FN values (0x00..0x7F)
|
||||
all_bytes = cp.arange(128, dtype=cp.uint8)
|
||||
all_floats = e4m3fn_to_float(all_bytes) # (128,) float32
|
||||
# Mark NaN entries as inf so they're never selected as nearest
|
||||
all_floats = cp.where(cp.isnan(all_floats), cp.float32(float("inf")), all_floats)
|
||||
|
||||
# Clamp input and extract sign
|
||||
clamped = f32_array.astype(cp.float32)
|
||||
clamped = cp.clip(clamped, -448.0, 448.0)
|
||||
signs = (clamped < 0).astype(cp.uint8)
|
||||
absval = cp.abs(clamped)
|
||||
|
||||
result = cp.zeros(absval.shape, dtype=cp.uint8)
|
||||
n = absval.size
|
||||
absval_flat = absval.ravel()
|
||||
result_flat = result.ravel()
|
||||
|
||||
for start in range(0, n, chunk_size):
|
||||
end = min(start + chunk_size, n)
|
||||
chunk = absval_flat[start:end]
|
||||
# (chunk_size, 128) difference matrix
|
||||
diffs = cp.abs(chunk[:, None] - all_floats[None, :])
|
||||
result_flat[start:end] = cp.argmin(diffs, axis=1).astype(cp.uint8)
|
||||
|
||||
# Combine with sign bit
|
||||
result = result_flat.reshape(absval.shape)
|
||||
result = result | (signs << 7)
|
||||
# Handle exact zero
|
||||
result = cp.where(absval == 0, cp.uint8(0), result)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FP8 E4M3B15 helpers (bias=15, max=0.9375, NaN = exp==15 or bits==0x80)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def e4m3b15_to_float(uint8_array):
|
||||
"""Decode a cupy uint8 array of E4M3B15 bit patterns to float32."""
|
||||
bits = uint8_array.astype(cp.int32)
|
||||
sign = (bits >> 7) & 1
|
||||
exp = (bits >> 3) & 0xF
|
||||
mant = bits & 0x7
|
||||
|
||||
# Normal: (-1)^s * 2^(exp-15) * (1 + mant/8)
|
||||
normal_val = cp.ldexp(cp.float32(1.0) + mant.astype(cp.float32) / cp.float32(8.0), (exp - 15).astype(cp.int32))
|
||||
# Subnormal (exp==0): (-1)^s * 2^(-14) * (mant/8)
|
||||
subnormal_val = cp.ldexp(mant.astype(cp.float32) / cp.float32(8.0), cp.int32(-14))
|
||||
|
||||
result = cp.where(exp == 0, subnormal_val, normal_val)
|
||||
result = cp.where(sign == 1, -result, result)
|
||||
# Zero
|
||||
result = cp.where((exp == 0) & (mant == 0), cp.float32(0.0), result)
|
||||
# NaN: exp==15 or negative zero (0x80)
|
||||
nan_mask = (exp == 15) | (uint8_array.astype(cp.int32) == 0x80)
|
||||
result = cp.where(nan_mask, cp.float32(float("nan")), result)
|
||||
return result
|
||||
|
||||
|
||||
def float_to_e4m3b15(f32_array, chunk_size=65536):
|
||||
"""Encode a cupy float32 array to uint8 E4M3B15 bit patterns.
|
||||
|
||||
Same lookup-table approach as float_to_e4m3fn.
|
||||
"""
|
||||
# Build lookup table of all 128 positive E4M3B15 values (0x00..0x7F)
|
||||
all_bytes = cp.arange(128, dtype=cp.uint8)
|
||||
all_floats = e4m3b15_to_float(all_bytes) # (128,) float32
|
||||
# Mark NaN entries as inf so they're never selected as nearest
|
||||
all_floats = cp.where(cp.isnan(all_floats), cp.float32(float("inf")), all_floats)
|
||||
|
||||
# Clamp input and extract sign
|
||||
clamped = f32_array.astype(cp.float32)
|
||||
clamped = cp.clip(clamped, -0.9375, 0.9375)
|
||||
signs = (clamped < 0).astype(cp.uint8)
|
||||
absval = cp.abs(clamped)
|
||||
|
||||
result = cp.zeros(absval.shape, dtype=cp.uint8)
|
||||
n = absval.size
|
||||
absval_flat = absval.ravel()
|
||||
result_flat = result.ravel()
|
||||
|
||||
for start in range(0, n, chunk_size):
|
||||
end = min(start + chunk_size, n)
|
||||
chunk = absval_flat[start:end]
|
||||
# (chunk_size, 128) difference matrix
|
||||
diffs = cp.abs(chunk[:, None] - all_floats[None, :])
|
||||
result_flat[start:end] = cp.argmin(diffs, axis=1).astype(cp.uint8)
|
||||
|
||||
# Combine with sign bit
|
||||
result = result_flat.reshape(absval.shape)
|
||||
result = result | (signs << 7)
|
||||
# Handle exact zero
|
||||
result = cp.where(absval == 0, cp.uint8(0), result)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared test helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def setup_algorithms(mpi_group):
|
||||
"""Build default algorithms and return (comm_group, algo_map, scratch_buf)."""
|
||||
comm_group = CommGroup(mpi_group.comm)
|
||||
scratch = GpuBuffer(1 << 27, dtype=cp.uint8) # 128 MB
|
||||
AlgorithmCollectionBuilder.reset()
|
||||
builder = AlgorithmCollectionBuilder()
|
||||
algorithms = builder.build_default_algorithms(
|
||||
scratch_buffer=scratch.data.ptr,
|
||||
scratch_buffer_size=scratch.nbytes,
|
||||
rank=comm_group.my_rank,
|
||||
)
|
||||
algo_map = {a.name: a for a in algorithms}
|
||||
return comm_group, algo_map, scratch
|
||||
|
||||
|
||||
def run_allreduce(algo, comm_group, buffer, dtype, accum_dtype=None, nblocks=0, nthreads_per_block=0):
|
||||
"""Run allreduce in-place on buffer and return a copy of the result."""
|
||||
ret = algo.execute(
|
||||
comm=comm_group.communicator,
|
||||
input_buffer=buffer.data.ptr,
|
||||
output_buffer=buffer.data.ptr,
|
||||
input_size=buffer.nbytes,
|
||||
output_size=buffer.nbytes,
|
||||
dtype=dtype,
|
||||
op=ReduceOp.SUM,
|
||||
stream=cp.cuda.get_current_stream().ptr,
|
||||
nblocks=nblocks,
|
||||
nthreads_per_block=nthreads_per_block,
|
||||
symmetric_memory=True,
|
||||
accum_dtype=accum_dtype,
|
||||
)
|
||||
cp.cuda.Device().synchronize()
|
||||
assert ret == 0, f"Allreduce failed with error code {ret}"
|
||||
return buffer.copy()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: FP8 E4M3 accumulation correctness
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@parametrize_mpi_groups(8)
|
||||
@pytest.mark.parametrize(
|
||||
"algo_name",
|
||||
[
|
||||
"default_allreduce_packet",
|
||||
"default_allreduce_nvls_packet",
|
||||
"default_allreduce_fullmesh",
|
||||
"default_allreduce_rsag_zero_copy",
|
||||
"default_allreduce_allpair_packet",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("size", [1024, 4096, 16384, 65536, 262144, 1048576])
|
||||
def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
|
||||
"""Verify that FP8 E4M3 allreduce with higher-precision accumulation is at
|
||||
least as accurate as native FP8 accumulation, across all algorithm variants."""
|
||||
rank = mpi_group.comm.rank
|
||||
world_size = mpi_group.comm.size
|
||||
|
||||
comm_group, algo_map, scratch = setup_algorithms(mpi_group)
|
||||
if algo_name not in algo_map:
|
||||
pytest.skip(f"{algo_name} not available")
|
||||
if "nvls" in algo_name and not is_nvls_supported():
|
||||
pytest.skip(f"{algo_name} requires NVLS which is not supported on this platform")
|
||||
algo = algo_map[algo_name]
|
||||
|
||||
buf = GpuBuffer(size, dtype=cp.uint8)
|
||||
|
||||
accum_configs = [
|
||||
("fp8_native", DataType.float8_e4m3),
|
||||
("float16", DataType.float16),
|
||||
("float32", DataType.float32),
|
||||
]
|
||||
|
||||
# rsag_zero_copy and fullmesh need explicit block/thread counts
|
||||
if "rsag" in algo_name:
|
||||
nb = max(1, min(32, size // (world_size * 32)))
|
||||
nt = 1024
|
||||
elif "fullmesh" in algo_name:
|
||||
nb = 35
|
||||
nt = 512
|
||||
else:
|
||||
nb = 0
|
||||
nt = 0
|
||||
|
||||
errors = {}
|
||||
for accum_label, accum_dtype in accum_configs:
|
||||
# Generate deterministic per-rank data (use numpy to avoid hipRAND issues on ROCm)
|
||||
rng = np.random.RandomState(42 + rank)
|
||||
src_f32 = cp.asarray(rng.randn(size).astype(np.float32))
|
||||
src_f32 = cp.clip(src_f32, -240.0, 240.0)
|
||||
src_fp8 = float_to_e4m3fn(src_f32)
|
||||
|
||||
# Copy into symmetric buffer
|
||||
buf[:] = src_fp8
|
||||
cp.cuda.Device().synchronize()
|
||||
|
||||
# Run allreduce
|
||||
result = run_allreduce(
|
||||
algo,
|
||||
comm_group,
|
||||
buf,
|
||||
dtype=DataType.float8_e4m3,
|
||||
accum_dtype=accum_dtype,
|
||||
nblocks=nb,
|
||||
nthreads_per_block=nt,
|
||||
)
|
||||
result_f32 = e4m3fn_to_float(result)
|
||||
|
||||
# Compute float32 reference: sum all ranks' quantized FP8 inputs in float32
|
||||
ref_f32 = cp.zeros(size, dtype=cp.float32)
|
||||
for r in range(world_size):
|
||||
rng_r = np.random.RandomState(42 + r)
|
||||
rank_data = cp.asarray(rng_r.randn(size).astype(np.float32))
|
||||
rank_data = cp.clip(rank_data, -240.0, 240.0)
|
||||
rank_data_fp8 = float_to_e4m3fn(rank_data)
|
||||
ref_f32 += e4m3fn_to_float(rank_data_fp8)
|
||||
|
||||
# Compute errors
|
||||
abs_err = cp.abs(result_f32 - ref_f32)
|
||||
mean_abs_err = float(cp.mean(abs_err))
|
||||
errors[accum_label] = mean_abs_err
|
||||
|
||||
# Reset between runs
|
||||
algo.reset()
|
||||
|
||||
# Higher-precision accumulation should be at least as accurate as native fp8
|
||||
assert (
|
||||
errors["float16"] <= errors["fp8_native"] + 1e-6
|
||||
), f"float16 accum ({errors['float16']:.6f}) worse than native ({errors['fp8_native']:.6f})"
|
||||
assert (
|
||||
errors["float32"] <= errors["fp8_native"] + 1e-6
|
||||
), f"float32 accum ({errors['float32']:.6f}) worse than native ({errors['fp8_native']:.6f})"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: FP8 E4M3B15 accumulation correctness
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@parametrize_mpi_groups(8)
|
||||
@pytest.mark.parametrize(
|
||||
"algo_name",
|
||||
[
|
||||
"default_allreduce_packet",
|
||||
"default_allreduce_nvls_packet",
|
||||
"default_allreduce_rsag_zero_copy",
|
||||
"default_allreduce_fullmesh",
|
||||
"default_allreduce_allpair_packet",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("size", [1024, 4096, 65536])
|
||||
def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int):
|
||||
"""Verify that FP8 E4M3B15 allreduce with higher-precision accumulation is at
|
||||
least as accurate as native E4M3B15 accumulation."""
|
||||
rank = mpi_group.comm.rank
|
||||
world_size = mpi_group.comm.size
|
||||
|
||||
comm_group, algo_map, scratch = setup_algorithms(mpi_group)
|
||||
if algo_name not in algo_map:
|
||||
pytest.skip(f"{algo_name} not available")
|
||||
if "nvls" in algo_name and not is_nvls_supported():
|
||||
pytest.skip(f"{algo_name} requires NVLS which is not supported on this platform")
|
||||
|
||||
algo = algo_map[algo_name]
|
||||
buf = GpuBuffer(size, dtype=cp.uint8)
|
||||
|
||||
accum_configs = [
|
||||
("e4m3b15_native", DataType.float8_e4m3b15),
|
||||
("float16", DataType.float16),
|
||||
("float32", DataType.float32),
|
||||
]
|
||||
|
||||
# rsag_zero_copy needs explicit block/thread counts, scaled to data size
|
||||
if "rsag" in algo_name:
|
||||
nb = max(1, min(32, size // (world_size * 32)))
|
||||
nt = 1024
|
||||
else:
|
||||
nb = 0
|
||||
nt = 0
|
||||
|
||||
errors = {}
|
||||
for accum_label, accum_dtype in accum_configs:
|
||||
# Generate deterministic per-rank random uint8 values in valid e4m3b15 range
|
||||
rng = np.random.RandomState(42 + rank)
|
||||
raw = cp.asarray(rng.randint(0, 0x78, (size,)).astype(np.uint8))
|
||||
signs = cp.asarray(rng.randint(0, 2, (size,)).astype(np.uint8)) << 7
|
||||
src_uint8 = raw | signs
|
||||
# Fix negative zero -> positive zero
|
||||
src_uint8 = cp.where(src_uint8 == 0x80, cp.uint8(0), src_uint8)
|
||||
|
||||
# Copy into symmetric buffer
|
||||
buf[:] = src_uint8
|
||||
cp.cuda.Device().synchronize()
|
||||
|
||||
# Run allreduce
|
||||
result = run_allreduce(
|
||||
algo,
|
||||
comm_group,
|
||||
buf,
|
||||
dtype=DataType.float8_e4m3b15,
|
||||
accum_dtype=accum_dtype,
|
||||
nblocks=nb,
|
||||
nthreads_per_block=nt,
|
||||
)
|
||||
|
||||
# Decode result
|
||||
result_f32 = e4m3b15_to_float(result)
|
||||
|
||||
# Compute float32 reference
|
||||
ref_f32 = cp.zeros(size, dtype=cp.float32)
|
||||
for r in range(world_size):
|
||||
rng_r = np.random.RandomState(42 + r)
|
||||
raw_r = cp.asarray(rng_r.randint(0, 0x78, (size,)).astype(np.uint8))
|
||||
signs_r = cp.asarray(rng_r.randint(0, 2, (size,)).astype(np.uint8)) << 7
|
||||
bits_r = raw_r | signs_r
|
||||
bits_r = cp.where(bits_r == 0x80, cp.uint8(0), bits_r)
|
||||
ref_f32 += e4m3b15_to_float(bits_r)
|
||||
|
||||
# Clamp reference to e4m3b15 representable range
|
||||
ref_f32 = cp.clip(ref_f32, -0.9375, 0.9375)
|
||||
|
||||
# Compute errors (only on valid entries)
|
||||
valid = ~cp.isnan(result_f32) & ~cp.isnan(ref_f32)
|
||||
abs_err = cp.abs(result_f32[valid] - ref_f32[valid])
|
||||
mean_abs_err = float(cp.mean(abs_err)) if abs_err.size > 0 else 0.0
|
||||
errors[accum_label] = mean_abs_err
|
||||
|
||||
algo.reset()
|
||||
|
||||
# Higher-precision accumulation should be at least as accurate as native
|
||||
assert (
|
||||
errors["float16"] <= errors["e4m3b15_native"] + 1e-8
|
||||
), f"float16 accum ({errors['float16']:.8f}) worse than native ({errors['e4m3b15_native']:.8f})"
|
||||
assert (
|
||||
errors["float32"] <= errors["e4m3b15_native"] + 1e-8
|
||||
), f"float32 accum ({errors['float32']:.8f}) worse than native ({errors['e4m3b15_native']:.8f})"
|
||||
@@ -162,13 +162,10 @@ def create_connection(group: CommGroup, connection_type: str):
|
||||
def create_group_and_connection(mpi_group: MpiGroup, connection_type: str):
|
||||
if (connection_type == "NVLink" or connection_type == "NVLS") and all_ranks_on_the_same_node(mpi_group) is False:
|
||||
pytest.skip("cannot use nvlink/nvls for cross node")
|
||||
if connection_type == "IB" and os.environ.get("MSCCLPP_DISABLE_IB_TESTS", "0") != "0":
|
||||
pytest.skip("IB tests are disabled via MSCCLPP_DISABLE_IB_TESTS=1")
|
||||
group = CommGroup(mpi_group.comm)
|
||||
try:
|
||||
connection = create_connection(group, connection_type)
|
||||
except Error as e:
|
||||
if connection_type == "IB" and e.args[0] == ErrorCode.InvalidUsage:
|
||||
pytest.skip("IB not supported on this node")
|
||||
raise
|
||||
connection = create_connection(group, connection_type)
|
||||
return group, connection
|
||||
|
||||
|
||||
@@ -281,6 +278,8 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, connection_type: str,
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
def test_h2h_semaphores(mpi_group: MpiGroup):
|
||||
if os.environ.get("MSCCLPP_DISABLE_IB_TESTS", "0") != "0":
|
||||
pytest.skip("IB tests are disabled via MSCCLPP_DISABLE_IB_TESTS=1")
|
||||
group = CommGroup(mpi_group.comm)
|
||||
tran = group.my_ib_device(group.my_rank % 8)
|
||||
endpoint = EndpointConfig(tran, Device(DeviceType.CPU))
|
||||
@@ -301,6 +300,8 @@ def test_h2h_semaphores(mpi_group: MpiGroup):
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
def test_h2h_semaphores_gil_release(mpi_group: MpiGroup):
|
||||
if os.environ.get("MSCCLPP_DISABLE_IB_TESTS", "0") != "0":
|
||||
pytest.skip("IB tests are disabled via MSCCLPP_DISABLE_IB_TESTS=1")
|
||||
group = CommGroup(mpi_group.comm)
|
||||
tran = group.my_ib_device(group.my_rank % 8)
|
||||
endpoint = EndpointConfig(tran, Device(DeviceType.CPU))
|
||||
|
||||
Reference in New Issue
Block a user