Python bindings (#125)

Co-authored-by: Olli Saarikivi <olsaarik@microsoft.com>
Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
Co-authored-by: Binyang Li <binyli@microsoft.com>
This commit is contained in:
Saeed Maleki
2023-07-19 00:35:54 -07:00
committed by GitHub
parent 2e1645782e
commit e7d5e652df
58 changed files with 785 additions and 1263 deletions

5
.black
View File

@@ -1,5 +0,0 @@
[tool.black]
line-length = 120
target-version = ['py38']
include = '\.pyi?$'
extend-exclude = 'python/'

View File

@@ -44,10 +44,14 @@ jobs:
tar xzf /tmp/cmake-3.26.4-linux-x86_64.tar.gz -C /tmp
sudo ln -s /tmp/cmake-3.26.4-linux-x86_64/bin/cmake /usr/bin/cmake
- name: Dubious ownership exception
run: |
git config --global --add safe.directory /__w/mscclpp/mscclpp
- name: Autobuild
uses: github/codeql-action/autobuild@v2
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v2
with:
category: "/language:${{matrix.language}}"
category: "/language:${{matrix.language}}/cuda-version:${{matrix.cuda-version}}"

View File

@@ -20,10 +20,8 @@ jobs:
- name: Run cpplint
run: |
CPPSOURCES=$(find ./ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)' -not -path "./build/*" -not -path "./python/*")
PYTHONCPPSOURCES=$(find ./python/src/ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)')
CPPSOURCES=$(find ./ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)' -not -path "./build/*")
clang-format -style=file --verbose --Werror --dry-run ${CPPSOURCES}
clang-format --dry-run ${PYTHONCPPSOURCES}
pylint:
runs-on: ubuntu-20.04
@@ -45,7 +43,7 @@ jobs:
with:
black: true
black_auto_fix: false
black_args: "--config .black --check"
black_args: "--config pyproject.toml --check"
spelling:
runs-on: ubuntu-20.04

View File

@@ -8,7 +8,7 @@ set(MSCCLPP_PATCH "0")
set(MSCCLPP_SOVERSION ${MSCCLPP_MAJOR})
set(MSCCLPP_VERSION "${MSCCLPP_MAJOR}.${MSCCLPP_MINOR}.${MSCCLPP_PATCH}")
cmake_minimum_required(VERSION 3.26)
cmake_minimum_required(VERSION 3.25)
project(mscclpp LANGUAGES CUDA CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
@@ -23,6 +23,8 @@ include(${PROJECT_SOURCE_DIR}/cmake/AddClangFormatTargets.cmake)
# Options
option(ENABLE_TRACE "Enable tracing" OFF)
option(USE_NPKIT "Use NPKIT" ON)
option(BUILD_TESTS "Build tests" ON)
option(BUILD_PYTHON_BINDINGS "Build Python bindings" ON)
option(ALLOW_GDRCOPY "Use GDRCopy, if available" OFF)
# Find CUDAToolkit. Set CUDA flags based on the detected CUDA version
@@ -51,25 +53,48 @@ if(ALLOW_GDRCOPY)
find_package(GDRCopy)
endif()
# libmscclpp
add_library(mscclpp SHARED)
target_include_directories(mscclpp
add_library(mscclpp_obj OBJECT)
target_include_directories(mscclpp_obj
PRIVATE
${CUDAToolkit_INCLUDE_DIRS}
${IBVERBS_INCLUDE_DIRS}
${NUMA_INCLUDE_DIRS}
${GDRCOPY_INCLUDE_DIRS})
target_link_libraries(mscclpp PRIVATE ${CUDA_LIBRARIES} ${NUMA_LIBRARIES} ${IBVERBS_LIBRARIES} ${GDRCOPY_LIBRARIES})
set_target_properties(mscclpp PROPERTIES LINKER_LANGUAGE CXX VERSION ${MSCCLPP_VERSION} SOVERSION ${MSCCLPP_SOVERSION})
target_link_libraries(mscclpp_obj PRIVATE ${CUDA_LIBRARIES} ${NUMA_LIBRARIES} ${IBVERBS_LIBRARIES} ${GDRCOPY_LIBRARIES})
set_target_properties(mscclpp_obj PROPERTIES LINKER_LANGUAGE CXX POSITION_INDEPENDENT_CODE 1 VERSION ${MSCCLPP_VERSION} SOVERSION ${MSCCLPP_SOVERSION})
if(ENABLE_TRACE)
target_compile_definitions(mscclpp PRIVATE ENABLE_TRACE)
target_compile_definitions(mscclpp_obj PRIVATE ENABLE_TRACE)
endif()
if(USE_NPKIT)
target_compile_definitions(mscclpp PRIVATE ENABLE_NPKIT)
target_compile_definitions(mscclpp_obj PRIVATE ENABLE_NPKIT)
endif()
if(ALLOW_GDRCOPY AND GDRCOPY_FOUND)
target_compile_definitions(mscclpp_obj PRIVATE MSCCLPP_USE_GDRCOPY)
target_link_libraries(mscclpp_obj PRIVATE MSCCLPP::gdrcopy)
endif()
# libmscclpp
add_library(mscclpp SHARED)
target_link_libraries(mscclpp PUBLIC mscclpp_obj)
set_target_properties(mscclpp PROPERTIES VERSION ${MSCCLPP_VERSION} SOVERSION ${MSCCLPP_SOVERSION})
add_library(mscclpp_static STATIC)
target_link_libraries(mscclpp_static PUBLIC mscclpp_obj)
set_target_properties(mscclpp_static PROPERTIES VERSION ${MSCCLPP_VERSION} SOVERSION ${MSCCLPP_SOVERSION})
add_subdirectory(include)
add_subdirectory(src)
install(TARGETS mscclpp LIBRARY DESTINATION lib)
install(TARGETS mscclpp mscclpp_static
LIBRARY DESTINATION lib
ARCHIVE DESTINATION lib)
install(DIRECTORY ${PROJECT_SOURCE_DIR}/include/ DESTINATION include FILES_MATCHING PATTERN "*.hpp")
# Tests
add_subdirectory(test)
if (BUILD_TESTS)
add_subdirectory(test)
endif()
# Python bindings
if(BUILD_PYTHON_BINDINGS)
add_subdirectory(python)
endif()

View File

@@ -61,7 +61,7 @@ Some in-kernel communication interfaces of MSCCL++ send requests (called trigger
```cpp
// Bootstrap: initialize control-plane connections between all ranks
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(rank, world_size);
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, world_size);
// Create a communicator for connection setup
mscclpp::Communicator comm(bootstrap);
// Setup connections here using `comm`

View File

@@ -6,7 +6,7 @@
find_program(CLANG_FORMAT clang-format)
if(CLANG_FORMAT)
message(STATUS "Found clang-format: ${CLANG_FORMAT}")
set(FIND_DIRS ${PROJECT_SOURCE_DIR}/src ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/test)
set(FIND_DIRS ${PROJECT_SOURCE_DIR}/src ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/python ${PROJECT_SOURCE_DIR}/test)
add_custom_target(check-format ALL
COMMAND ${CLANG_FORMAT} -style=file --dry-run `find ${FIND_DIRS} -type f -name *.h -o -name *.hpp -o -name *.c -o -name *.cc -o -name *.cpp -o -name *.cu`
)

View File

@@ -20,10 +20,14 @@ RUN rm -rf build && \
mkdir build && \
cd build && \
${CMAKE_HOME}/bin/cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=${MSCCLPP_HOME} .. && \
make -j mscclpp && \
make -j mscclpp mscclpp_static && \
make install/fast && \
strip ${MSCCLPP_HOME}/lib/libmscclpp.so.[0-9]*.[0-9]*.[0-9]*
# Install MSCCL++ Python bindings
WORKDIR ${MSCCLPP_SRC_DIR}
RUN python3.8 -m pip install .
ENV LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${MSCCLPP_HOME}/lib"
RUN echo LD_LIBRARY_PATH="${LD_LIBRARY_PATH}" >> /etc/environment

View File

@@ -2,5 +2,7 @@
# Licensed under the MIT license.
file(GLOB_RECURSE HEADERS CONFIGURE_DEPENDS *.hpp)
target_sources(mscclpp PUBLIC FILE_SET HEADERS FILES ${HEADERS})
install(TARGETS mscclpp FILE_SET HEADERS)
target_sources(mscclpp_obj PUBLIC FILE_SET HEADERS FILES ${HEADERS})
if(NOT SKBUILD)
install(TARGETS mscclpp FILE_SET HEADERS)
endif()

View File

@@ -9,6 +9,7 @@
#define MSCCLPP_PATCH 0
#define MSCCLPP_VERSION (MSCCLPP_MAJOR * 10000 + MSCCLPP_MINOR * 100 + MSCCLPP_PATCH)
#include <array>
#include <bitset>
#include <future>
#include <memory>
@@ -21,15 +22,13 @@ namespace mscclpp {
#define MSCCLPP_UNIQUE_ID_BYTES 128
/// Unique ID for a process. This is a MSCCLPP_UNIQUE_ID_BYTES byte array that uniquely identifies a process.
struct UniqueId {
char internal[MSCCLPP_UNIQUE_ID_BYTES];
};
using UniqueId = std::array<uint8_t, MSCCLPP_UNIQUE_ID_BYTES>;
/// Base class for bootstrappers.
class BaseBootstrap {
/// Base class for bootstraps.
class Bootstrap {
public:
BaseBootstrap(){};
virtual ~BaseBootstrap() = default;
Bootstrap(){};
virtual ~Bootstrap() = default;
virtual int getRank() = 0;
virtual int getNranks() = 0;
virtual void send(void* data, int size, int peer, int tag) = 0;
@@ -41,32 +40,32 @@ class BaseBootstrap {
void recv(std::vector<char>& data, int peer, int tag);
};
/// A native implementation of the bootstrapper.
class Bootstrap : public BaseBootstrap {
/// A native implementation of the bootstrap using TCP sockets.
class TcpBootstrap : public Bootstrap {
public:
/// Construct a Bootstrap.
/// Constructor.
/// @param rank The rank of the process.
/// @param nRanks The total number of ranks.
Bootstrap(int rank, int nRanks);
TcpBootstrap(int rank, int nRanks);
/// Destroy the Bootstrap.
~Bootstrap();
/// Destructor.
~TcpBootstrap();
/// Create a random unique ID and store it in the Bootstrap.
/// Create a random unique ID and store it in the @ref TcpBootstrap.
/// @return The created unique ID.
UniqueId createUniqueId();
/// Return the unique ID stored in the Bootstrap.
/// @return The unique ID stored in the Bootstrap.
/// Return the unique ID stored in the @ref TcpBootstrap.
/// @return The unique ID stored in the @ref TcpBootstrap.
UniqueId getUniqueId() const;
/// Initialize the Bootstrap with a given unique ID.
/// @param uniqueId The unique ID to initialize the Bootstrap with.
/// Initialize the @ref TcpBootstrap with a given unique ID.
/// @param uniqueId The unique ID to initialize the @ref TcpBootstrap with.
void initialize(UniqueId uniqueId);
/// Initialize the Bootstrap with a string formatted as "ip:port" or "interface:ip:port".
/// Initialize the @ref TcpBootstrap with a string formatted as "ip:port" or "interface:ip:port".
/// @param ifIpPortTrio The string formatted as "ip:port" or "interface:ip:port".
void initialize(std::string ifIpPortTrio);
void initialize(const std::string& ifIpPortTrio);
/// Return the rank of the process.
int getRank() override;
@@ -109,9 +108,9 @@ class Bootstrap : public BaseBootstrap {
void barrier() override;
private:
/// Implementation class for Bootstrap.
/// Implementation class for @ref TcpBootstrap.
class Impl;
/// Pointer to the implementation class for Bootstrap.
/// Pointer to the implementation class for @ref TcpBootstrap.
std::unique_ptr<Impl> pimpl_;
};
@@ -421,13 +420,13 @@ struct Setuppable {
/// being set up within the same @ref Communicator::setup() call.
///
/// @param bootstrap A shared pointer to the bootstrap implementation.
virtual void beginSetup(std::shared_ptr<BaseBootstrap> bootstrap);
virtual void beginSetup(std::shared_ptr<Bootstrap> bootstrap);
/// Called inside @ref Communicator::setup() after all calls to @ref beginSetup() of all @ref Setuppable objects that
/// are being set up within the same @ref Communicator::setup() call.
///
/// @param bootstrap A shared pointer to the bootstrap implementation.
virtual void endSetup(std::shared_ptr<BaseBootstrap> bootstrap);
virtual void endSetup(std::shared_ptr<Bootstrap> bootstrap);
};
/// A non-blocking future that can be used to check if a value is ready and retrieve it.
@@ -484,16 +483,16 @@ class Communicator {
public:
/// Initializes the communicator with a given bootstrap implementation.
///
/// @param bootstrap An implementation of the BaseBootstrap that the communicator will use.
Communicator(std::shared_ptr<BaseBootstrap> bootstrap);
/// @param bootstrap An implementation of the Bootstrap that the communicator will use.
Communicator(std::shared_ptr<Bootstrap> bootstrap);
/// Destroy the communicator.
~Communicator();
/// Returns the bootstrapper held by this communicator.
/// Returns the bootstrap held by this communicator.
///
/// @return std::shared_ptr<BaseBootstrap> The bootstrapper held by this communicator.
std::shared_ptr<BaseBootstrap> bootstrapper();
/// @return std::shared_ptr<Bootstrap> The bootstrap held by this communicator.
std::shared_ptr<Bootstrap> bootstrap();
/// Register a region of GPU memory for use in this communicator.
///

View File

@@ -134,7 +134,7 @@ union ChannelTrigger {
struct ProxyChannel {
ProxyChannel() = default;
ProxyChannel(SemaphoreId SemaphoreId, Host2DeviceSemaphore::DeviceHandle semaphore, DeviceProxyFifo fifo);
ProxyChannel(SemaphoreId semaphoreId, Host2DeviceSemaphore::DeviceHandle semaphore, DeviceProxyFifo fifo);
ProxyChannel(const ProxyChannel& other) = default;

20
pyproject.toml Normal file
View File

@@ -0,0 +1,20 @@
[build-system]
requires = ["scikit-build-core"]
build-backend = "scikit_build_core.build"
[project]
name = "mscclpp"
version = "0.2.0"
[tool.scikit-build]
cmake.minimum-version = "3.25.0"
build-dir = "build/{wheel_tag}"
[tool.scikit-build.cmake.define]
BUILD_PYTHON_BINDINGS = "ON"
BUILD_TESTS = "OFF"
[tool.black]
line-length = 120
target-version = ['py38']
include = '\.pyi?$'

2
python/.gitignore vendored
View File

@@ -1,2 +0,0 @@
.*.swp
.venv/

View File

@@ -1,61 +1,16 @@
project(mscclpp)
cmake_minimum_required(VERSION 3.18...3.22)
find_package(Python 3.9 COMPONENTS Interpreter Development.Module REQUIRED)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build." FORCE)
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
endif ()
# Create CMake targets for all Python components needed by nanobind
if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.26)
find_package(Python 3.8 COMPONENTS Interpreter Development.Module Development.SABIModule REQUIRED)
else ()
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
endif ()
# Detect the installed nanobind package and import it into CMake
execute_process(
COMMAND "${Python_EXECUTABLE}" -c "import nanobind; print(nanobind.cmake_dir())"
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)
set(CUDA_DIR "/usr/local/cuda")
set(MSCCLPP_DIR ${CMAKE_CURRENT_LIST_DIR}/../build)
nanobind_add_module(
_py_mscclpp
NOSTRIP
NB_STATIC
src/_py_mscclpp.cpp
)
target_include_directories(
_py_mscclpp
PUBLIC
${CUDA_DIR}/include
${MSCCLPP_DIR}/include
)
target_link_directories(
_py_mscclpp
PUBLIC
${CUDA_DIR}/lib
${MSCCLPP_DIR}/lib
)
target_link_libraries(
_py_mscclpp
PUBLIC
mscclpp
)
add_custom_target(build-package ALL DEPENDS _py_mscclpp)
add_custom_command(
TARGET build-package POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
${MSCCLPP_DIR}/lib/libmscclpp.so
${CMAKE_CURRENT_BINARY_DIR})
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
include(FetchContent)
FetchContent_Declare(nanobind GIT_REPOSITORY https://github.com/wjakob/nanobind.git GIT_TAG v1.4.0)
FetchContent_MakeAvailable(nanobind)
nanobind_add_module(mscclpp_py core_py.cpp error_py.cpp proxy_channel_py.cpp fifo_py.cpp semaphore_py.cpp
config_py.cpp utils_py.cpp)
set_target_properties(mscclpp_py PROPERTIES OUTPUT_NAME mscclpp)
target_link_libraries(mscclpp_py PRIVATE mscclpp_static)
target_include_directories(mscclpp_py PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
if (SKBUILD)
install(TARGETS mscclpp_py LIBRARY DESTINATION ${SKBUILD_PLATLIB_DIR})
endif()

View File

@@ -1,4 +0,0 @@
test:
./test.sh

View File

@@ -1,30 +0,0 @@
# Python bindings
Test instructions:
* Compile the `libmscclpp.so` library.
* Install `cmake` verion >= 3.18
* setup a python virtual env
* `pip install -r dev-requirements.txt`
* `./tesh.sh`
## Run CI:
```bash
./ci.sh
```
## Build a wheel:
Setup dev environment, then:
```bash
python setup.py bdist_wheel
```
## Installing mpi and numa libs.
```
## numctl
apt install -y numactl libnuma-dev libnuma1
```

View File

@@ -1,24 +0,0 @@
#!/bin/bash
# CI hook script.
set -ex
# CD to this directory.
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
cd $SCRIPT_DIR
# clean env
rm -rf .venv build
# setup a python virtual env
python -m venv .venv
# activate the virtual env
source .venv/bin/activate
# install venv deps.
pip install -r dev-requirements.txt
# run the build and test.
./test.sh

16
python/config_py.cpp Normal file
View File

@@ -0,0 +1,16 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <nanobind/nanobind.h>
#include <mscclpp/config.hpp>
namespace nb = nanobind;
using namespace mscclpp;
void register_config(nb::module_& m) {
nb::class_<Config>(m, "Config")
.def_static("get_instance", &Config::getInstance, nb::rv_policy::reference)
.def("get_bootstrap_connection_timeout_config", &Config::getBootstrapConnectionTimeoutConfig)
.def("set_bootstrap_connection_timeout_config", &Config::setBootstrapConnectionTimeoutConfig);
}

153
python/core_py.cpp Normal file
View File

@@ -0,0 +1,153 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <nanobind/nanobind.h>
#include <nanobind/operators.h>
#include <nanobind/stl/array.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>
#include <mscclpp/core.hpp>
namespace nb = nanobind;
using namespace mscclpp;
extern void register_error(nb::module_& m);
extern void register_proxy_channel(nb::module_& m);
extern void register_fifo(nb::module_& m);
extern void register_semaphore(nb::module_& m);
extern void register_config(nb::module_& m);
extern void register_utils(nb::module_& m);
template <typename T>
void def_nonblocking_future(nb::handle& m, const std::string& typestr) {
std::string pyclass_name = std::string("NonblockingFuture") + typestr;
nb::class_<NonblockingFuture<T>>(m, pyclass_name.c_str())
.def("ready", &NonblockingFuture<T>::ready)
.def("get", &NonblockingFuture<T>::get);
}
void register_core(nb::module_& m) {
nb::class_<Bootstrap>(m, "Bootstrap")
.def("get_rank", &Bootstrap::getRank)
.def("get_n_ranks", &Bootstrap::getNranks)
.def(
"send",
[](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) {
void* data = reinterpret_cast<void*>(ptr);
self->send(data, size, peer, tag);
},
nb::arg("data"), nb::arg("size"), nb::arg("peer"), nb::arg("tag"))
.def(
"recv",
[](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) {
void* data = reinterpret_cast<void*>(ptr);
self->recv(data, size, peer, tag);
},
nb::arg("data"), nb::arg("size"), nb::arg("peer"), nb::arg("tag"))
.def("all_gather", &Bootstrap::allGather, nb::arg("allData"), nb::arg("size"))
.def("barrier", &Bootstrap::barrier)
.def("send", (void (Bootstrap::*)(const std::vector<char>&, int, int)) & Bootstrap::send, nb::arg("data"),
nb::arg("peer"), nb::arg("tag"))
.def("recv", (void (Bootstrap::*)(std::vector<char>&, int, int)) & Bootstrap::recv, nb::arg("data"),
nb::arg("peer"), nb::arg("tag"));
nb::class_<UniqueId>(m, "UniqueId");
nb::class_<TcpBootstrap, Bootstrap>(m, "TcpBootstrap")
.def(nb::init<int, int>(), "Do not use this constructor. Use create instead.")
.def_static(
"create", [](int rank, int nRanks) { return std::make_shared<TcpBootstrap>(rank, nRanks); }, nb::arg("rank"),
nb::arg("nRanks"))
.def("create_unique_id", &TcpBootstrap::createUniqueId)
.def("get_unique_id", &TcpBootstrap::getUniqueId)
.def("initialize", (void (TcpBootstrap::*)(UniqueId)) & TcpBootstrap::initialize, nb::arg("uniqueId"))
.def("initialize", (void (TcpBootstrap::*)(const std::string&)) & TcpBootstrap::initialize,
nb::arg("ifIpPortTrio"));
nb::enum_<Transport>(m, "Transport")
.value("Unknown", Transport::Unknown)
.value("CudaIpc", Transport::CudaIpc)
.value("IB0", Transport::IB0)
.value("IB1", Transport::IB1)
.value("IB2", Transport::IB2)
.value("IB3", Transport::IB3)
.value("IB4", Transport::IB4)
.value("IB5", Transport::IB5)
.value("IB6", Transport::IB6)
.value("IB7", Transport::IB7)
.value("NumTransports", Transport::NumTransports);
nb::class_<TransportFlags>(m, "TransportFlags")
.def(nb::init<>())
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
.def("has", &TransportFlags::has, nb::arg("transport"))
.def("none", &TransportFlags::none)
.def("any", &TransportFlags::any)
.def("all", &TransportFlags::all)
.def("count", &TransportFlags::count)
.def(nb::self |= nb::self)
.def(nb::self | nb::self)
.def(nb::self | Transport())
.def(nb::self &= nb::self)
.def(nb::self & nb::self)
.def(nb::self & Transport())
.def(nb::self ^= nb::self)
.def(nb::self ^ nb::self)
.def(nb::self ^ Transport())
.def(~nb::self)
.def(nb::self == nb::self)
.def(nb::self != nb::self);
nb::class_<RegisteredMemory>(m, "RegisteredMemory")
.def(nb::init<>())
.def("data", &RegisteredMemory::data)
.def("size", &RegisteredMemory::size)
.def("rank", &RegisteredMemory::rank)
.def("transports", &RegisteredMemory::transports)
.def("serialize", &RegisteredMemory::serialize)
.def_static("deserialize", &RegisteredMemory::deserialize, nb::arg("data"));
nb::class_<Connection>(m, "Connection")
.def("write", &Connection::write, nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("srcOffset"),
nb::arg("size"))
.def(
"update_and_sync",
[](Connection* self, RegisteredMemory dst, uint64_t dstOffset, uintptr_t src, uint64_t newValue) {
self->updateAndSync(dst, dstOffset, (uint64_t*)src, newValue);
},
nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("newValue"))
.def("flush", &Connection::flush)
.def("remote_rank", &Connection::remoteRank)
.def("tag", &Connection::tag)
.def("transport", &Connection::transport)
.def("remote_transport", &Connection::remoteTransport);
def_nonblocking_future<RegisteredMemory>(m, "RegisteredMemory");
nb::class_<Communicator>(m, "Communicator")
.def(nb::init<std::shared_ptr<Bootstrap>>(), nb::arg("bootstrap"))
.def("bootstrap", &Communicator::bootstrap)
.def(
"register_memory",
[](Communicator* self, uintptr_t ptr, size_t size, TransportFlags transports) {
return self->registerMemory((void*)ptr, size, transports);
},
nb::arg("ptr"), nb::arg("size"), nb::arg("transports"))
.def("send_memory_on_setup", &Communicator::sendMemoryOnSetup, nb::arg("memory"), nb::arg("remoteRank"),
nb::arg("tag"))
.def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
.def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
nb::arg("transport"))
.def("setup", &Communicator::setup);
}
NB_MODULE(mscclpp, m) {
register_error(m);
register_proxy_channel(m);
register_fifo(m);
register_semaphore(m);
register_config(m);
register_utils(m);
register_core(m);
}

View File

@@ -1,11 +0,0 @@
black
isort
wheel
pytest
PyHamcrest
nanobind
torch

41
python/error_py.cpp Normal file
View File

@@ -0,0 +1,41 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>
#include <mscclpp/errors.hpp>
namespace nb = nanobind;
using namespace mscclpp;
void register_error(nb::module_& m) {
nb::enum_<ErrorCode>(m, "ErrorCode")
.value("SystemError", ErrorCode::SystemError)
.value("InternalError", ErrorCode::InternalError)
.value("RemoteError", ErrorCode::RemoteError)
.value("InvalidUsage", ErrorCode::InvalidUsage)
.value("Timeout", ErrorCode::Timeout)
.value("Aborted", ErrorCode::Aborted);
nb::class_<BaseError>(m, "BaseError")
.def(nb::init<std::string&, int>(), nb::arg("message"), nb::arg("errorCode"))
.def("get_error_code", &BaseError::getErrorCode)
.def("what", &BaseError::what);
nb::class_<Error, BaseError>(m, "Error")
.def(nb::init<const std::string&, ErrorCode>(), nb::arg("message"), nb::arg("errorCode"))
.def("get_error_code", &Error::getErrorCode);
nb::class_<SysError, BaseError>(m, "SysError")
.def(nb::init<const std::string&, int>(), nb::arg("message"), nb::arg("errorCode"));
nb::class_<CudaError, BaseError>(m, "CudaError")
.def(nb::init<const std::string&, cudaError_t>(), nb::arg("message"), nb::arg("errorCode"));
nb::class_<CuError, BaseError>(m, "CuError")
.def(nb::init<const std::string&, CUresult>(), nb::arg("message"), nb::arg("errorCode"));
nb::class_<IbError, BaseError>(m, "IbError")
.def(nb::init<const std::string&, int>(), nb::arg("message"), nb::arg("errorCode"));
}

View File

@@ -0,0 +1,91 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import mscclpp
import argparse
import multiprocessing as mp
import logging
import torch
import sys
IB_TRANSPORTS = [
mscclpp.Transport.IB0,
mscclpp.Transport.IB1,
mscclpp.Transport.IB2,
mscclpp.Transport.IB3,
mscclpp.Transport.IB4,
mscclpp.Transport.IB5,
mscclpp.Transport.IB6,
mscclpp.Transport.IB7,
]
def setup_connections(comm, rank, world_size, element_size, proxy_service):
simple_proxy_channels = []
connections = []
remote_memories = []
memory = torch.zeros(element_size, dtype=torch.int32)
memory = memory.to("cuda")
transport_flag = IB_TRANSPORTS[rank] or mscclpp.Transport.CudaIpc
ptr = memory.data_ptr()
size = memory.numel() * memory.element_size()
reg_mem = comm.register_memory(ptr, size, transport_flag)
for r in range(world_size):
if r == rank:
continue
conn = comm.connect_on_setup(r, 0, mscclpp.Transport.CudaIpc)
connections.append(conn)
comm.send_memory_on_setup(reg_mem, r, 0)
remote_mem = comm.recv_memory_on_setup(r, 0)
remote_memories.append(remote_mem)
comm.setup()
for i, conn in enumerate(connections):
proxy_channel = mscclpp.SimpleProxyChannel(
proxy_service.device_channel(proxy_service.add_semaphore(conn)),
proxy_service.add_memory(remote_memories[i].get()),
proxy_service.add_memory(reg_mem),
)
simple_proxy_channels.append(proxy_channel)
comm.setup()
return simple_proxy_channels
def run(rank, args):
world_size = args.gpu_number
torch.cuda.set_device(rank)
boot = mscclpp.TcpBootstrap.create(rank, world_size)
boot.initialize(args.if_ip_port_trio)
comm = mscclpp.Communicator(boot)
proxy_service = mscclpp.ProxyService(comm)
logging.info("Rank: %d, setting up connections", rank)
setup_connections(comm, rank, world_size, args.num_elements, proxy_service)
logging.info("Rank: %d, starting proxy service", rank)
proxy_service.start_proxy()
def main():
logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)
parser = argparse.ArgumentParser()
parser.add_argument("if_ip_port_trio", type=str)
parser.add_argument("-n", "--num-elements", type=int, default=10)
parser.add_argument("-g", "--gpu_number", type=int, default=2)
args = parser.parse_args()
processes = []
for rank in range(args.gpu_number):
p = mp.Process(target=run, args=(rank, args))
p.start()
processes.append(p)
for p in processes:
p.join()
if __name__ == "__main__":
main()

12
python/examples/config.py Normal file
View File

@@ -0,0 +1,12 @@
import mscclpp
def main():
config = mscclpp.Config.get_instance()
config.set_bootstrap_connection_timeout_config(15)
timeout = config.get_bootstrap_connection_timeout_config()
assert timeout == 15
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,81 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import mscclpp
import argparse
import time
def main(args):
if args.root:
rank = 0
else:
rank = 1
boot = mscclpp.TcpBootstrap.create(rank, 2)
boot.initialize(args.if_ip_port_trio)
comm = mscclpp.Communicator(boot)
if args.gpu:
import torch
print("Allocating GPU memory")
memory = torch.zeros(args.num_elements, dtype=torch.int32)
memory = memory.to("cuda")
ptr = memory.data_ptr()
size = memory.numel() * memory.element_size()
else:
from array import array
print("Allocating host memory")
memory = array("i", [0] * args.num_elements)
ptr, elements = memory.buffer_info()
size = elements * memory.itemsize
my_reg_mem = comm.register_memory(ptr, size, mscclpp.Transport.IB0)
conn = comm.connect_on_setup((rank + 1) % 2, 0, mscclpp.Transport.IB0)
other_reg_mem = None
if rank == 0:
other_reg_mem = comm.recv_memory_on_setup((rank + 1) % 2, 0)
else:
comm.send_memory_on_setup(my_reg_mem, (rank + 1) % 2, 0)
comm.setup()
if rank == 0:
other_reg_mem = other_reg_mem.get()
if rank == 0:
for i in range(args.num_elements):
memory[i] = i + 1
conn.write(other_reg_mem, 0, my_reg_mem, 0, size)
print("Done sending")
else:
print("Checking for correctness")
# polling
for _ in range(args.polling_num):
all_correct = True
for i in range(args.num_elements):
if memory[i] != i + 1:
all_correct = False
print(f"Error: Mismatch at index {i}: expected {i + 1}, got {memory[i]}")
break
if all_correct:
print("All data matched expected values")
break
else:
time.sleep(0.1)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("if_ip_port_trio", type=str)
parser.add_argument("-r", "--root", action="store_true")
parser.add_argument("-n", "--num-elements", type=int, default=10)
parser.add_argument("--gpu", action="store_true")
parser.add_argument("--polling_num", type=int, default=100)
args = parser.parse_args()
main(args)

14
python/examples/utils.py Normal file
View File

@@ -0,0 +1,14 @@
import mscclpp
import time
def main():
timer = mscclpp.Timer()
timer.reset()
time.sleep(2)
assert timer.elapsed() >= 2000000
if __name__ == "__main__":
main()

25
python/fifo_py.cpp Normal file
View File

@@ -0,0 +1,25 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <nanobind/nanobind.h>
#include <mscclpp/fifo.hpp>
namespace nb = nanobind;
using namespace mscclpp;
void register_fifo(nb::module_& m) {
nb::class_<ProxyTrigger>(m, "ProxyTrigger").def_rw("fst", &ProxyTrigger::fst).def_rw("snd", &ProxyTrigger::snd);
nb::class_<DeviceProxyFifo>(m, "DeviceProxyFifo")
.def_rw("triggers", &DeviceProxyFifo::triggers)
.def_rw("tail_replica", &DeviceProxyFifo::tailReplica)
.def_rw("head", &DeviceProxyFifo::head);
nb::class_<HostProxyFifo>(m, "HostProxyFifo")
.def(nb::init<>())
.def("poll", &HostProxyFifo::poll, nb::arg("trigger"))
.def("pop", &HostProxyFifo::pop)
.def("flush_tail", &HostProxyFifo::flushTail, nb::arg("sync") = false)
.def("device_fifo", &HostProxyFifo::deviceFifo);
}

View File

@@ -1,9 +0,0 @@
#!/bin/bash
set -ex
isort src
black src
clang-format -i $(find src -name '*.cpp' -or -name '*.h')

View File

@@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <nanobind/nanobind.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>
#include <mscclpp/proxy_channel.hpp>
namespace nb = nanobind;
using namespace mscclpp;
void register_proxy_channel(nb::module_& m) {
nb::class_<BaseProxyService>(m, "BaseProxyService")
.def("start_proxy", &BaseProxyService::startProxy)
.def("stop_proxy", &BaseProxyService::stopProxy);
nb::class_<ProxyService, BaseProxyService>(m, "ProxyService")
.def(nb::init<Communicator&>(), nb::arg("comm"))
.def("start_proxy", &ProxyService::startProxy)
.def("stop_proxy", &ProxyService::stopProxy)
.def("add_semaphore", &ProxyService::addSemaphore, nb::arg("connection"))
.def("add_memory", &ProxyService::addMemory, nb::arg("memory"))
.def("semaphore", &ProxyService::semaphore, nb::arg("id"))
.def("device_channel", &ProxyService::deviceChannel, nb::arg("id"));
nb::class_<ProxyChannel>(m, "ProxyChannel")
.def(nb::init<SemaphoreId, Host2DeviceSemaphore::DeviceHandle, DeviceProxyFifo>(), nb::arg("semaphoreId"),
nb::arg("semaphore"), nb::arg("fifo"));
nb::class_<SimpleProxyChannel>(m, "SimpleProxyChannel")
.def(nb::init<ProxyChannel, MemoryId, MemoryId>(), nb::arg("proxy_chan"), nb::arg("dst"), nb::arg("src"))
.def(nb::init<SimpleProxyChannel>(), nb::arg("proxy_chan"));
};

42
python/semaphore_py.cpp Normal file
View File

@@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <nanobind/nanobind.h>
#include <nanobind/stl/shared_ptr.h>
#include <mscclpp/semaphore.hpp>
namespace nb = nanobind;
using namespace mscclpp;
void register_semaphore(nb::module_& m) {
nb::class_<Host2DeviceSemaphore> host2DeviceSemaphore(m, "Host2DeviceSemaphore");
host2DeviceSemaphore
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
.def("connection", &Host2DeviceSemaphore::connection)
.def("signal", &Host2DeviceSemaphore::signal)
.def("device_handle", &Host2DeviceSemaphore::deviceHandle);
nb::class_<Host2DeviceSemaphore::DeviceHandle>(host2DeviceSemaphore, "DeviceHandle")
.def(nb::init<>())
.def_rw("inbound_semaphore_id", &Host2DeviceSemaphore::DeviceHandle::inboundSemaphoreId)
.def_rw("expected_inbound_semaphore_id", &Host2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId);
nb::class_<Host2HostSemaphore>(m, "Host2HostSemaphore")
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
.def("connection", &Host2HostSemaphore::connection)
.def("signal", &Host2HostSemaphore::signal)
.def("wait", &Host2HostSemaphore::wait);
nb::class_<SmDevice2DeviceSemaphore> smDevice2DeviceSemaphore(m, "SmDevice2DeviceSemaphore");
smDevice2DeviceSemaphore
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
.def("device_handle", &SmDevice2DeviceSemaphore::deviceHandle);
nb::class_<SmDevice2DeviceSemaphore::DeviceHandle>(smDevice2DeviceSemaphore, "DeviceHandle")
.def(nb::init<>())
.def_rw("inboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::inboundSemaphoreId)
.def_rw("outboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::outboundSemaphoreId)
.def_rw("remoteInboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::remoteInboundSemaphoreId)
.def_rw("expectedInboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId);
}

View File

@@ -1,83 +0,0 @@
#!/usr/bin/env python
import os
import shutil
import sys
import logging
from setuptools import Extension, find_packages, setup
from setuptools.command.build_ext import build_ext
import subprocess
THIS_DIR = os.path.abspath(os.path.dirname(__file__))
class CustomExt(Extension):
def __init__(self, name):
# don't invoke the original build_ext for this special extension
super().__init__(name, sources=[])
class custom_build_ext(build_ext):
def run(self):
for ext in self.extensions:
if isinstance(ext, CustomExt):
self.build_extension(ext)
else:
super().run()
def build_extension(self, ext):
if sys.platform == "darwin":
return
# these dirs will be created in build_py, so if you don't have
# any python sources to bundle, the dirs will be missing
build_temp = os.path.abspath(self.build_temp)
os.makedirs(build_temp, exist_ok=True)
try:
subprocess.check_output(
["cmake", "-S", THIS_DIR, "-B", build_temp],
stderr=subprocess.STDOUT,
)
subprocess.check_output(
["cmake", "--build", build_temp],
stderr=subprocess.STDOUT,
)
except subprocess.CalledProcessError as e:
logging.error(e.output.decode())
raise
libname = os.path.basename(self.get_ext_fullpath(ext.name))
target_dir = os.path.join(
os.path.dirname(self.get_ext_fullpath(ext.name)),
"mscclpp",
)
shutil.copy(
os.path.join(build_temp, "libmscclpp.so"),
target_dir,
)
shutil.copy(
os.path.join(build_temp, libname),
target_dir,
)
setup(
name='mscclpp',
version='0.1.0',
description='Python bindings for mscclpp',
# packages=['mscclpp'],
package_dir={'': 'src'},
packages=find_packages(where='./src'),
ext_modules=[CustomExt('_py_mscclpp')],
cmdclass={
"build_ext": custom_build_ext,
},
install_requires=[
'torch',
'nanobind',
],
)

View File

@@ -1,7 +0,0 @@
---
Language : Cpp
BasedOnStyle : google
BinPackParameters: false
BinPackArguments : false
AlignAfterOpenBracket : AlwaysBreak
...

View File

@@ -1,405 +0,0 @@
#include <cuda_runtime.h>
#include <mscclpp.h>
#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
namespace nb = nanobind;
using namespace nb::literals;
// This is a poorman's substitute for std::format, which is a C++20 feature.
template <typename... Args>
std::string string_format(const std::string& format, Args... args) {
// Shutup format warning.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wformat-security"
// Dry-run to the get the buffer size:
// Extra space for '\0'
int size_s = std::snprintf(nullptr, 0, format.c_str(), args...) + 1;
if (size_s <= 0) {
throw std::runtime_error("Error during formatting.");
}
// allocate buffer
auto size = static_cast<size_t>(size_s);
std::unique_ptr<char[]> buf(new char[size]);
// actually format
std::snprintf(buf.get(), size, format.c_str(), args...);
// Bulid the return string.
// We don't want the '\0' inside
return std::string(buf.get(), buf.get() + size - 1);
#pragma GCC diagnostic pop
}
// Maybe return the value, maybe throw an exception.
template <typename... Args>
void checkResult(
mscclppResult_t status, const std::string& format, Args... args) {
switch (status) {
case mscclppSuccess:
return;
case mscclppUnhandledCudaError:
case mscclppSystemError:
case mscclppInternalError:
case mscclppRemoteError:
case mscclppInProgress:
case mscclppNumResults:
throw std::runtime_error(
string_format(format, args...) + " : " +
std::string(mscclppGetErrorString(status)));
case mscclppInvalidArgument:
case mscclppInvalidUsage:
default:
throw std::invalid_argument(
string_format(format, args...) + " : " +
std::string(mscclppGetErrorString(status)));
}
}
#define RETRY(C, ...) \
{ \
mscclppResult_t res; \
do { \
res = (C); \
} while (res == mscclppInProgress); \
checkResult(res, __VA_ARGS__); \
}
// Maybe return the value, maybe throw an exception.
template <typename Val, typename... Args>
Val maybe(
mscclppResult_t status, Val val, const std::string& format, Args... args) {
checkResult(status, format, args...);
return val;
}
// Wrapper around connection state.
struct _Comm {
int _rank;
int _world_size;
mscclppComm_t _handle;
bool _is_open;
bool _proxies_running;
public:
_Comm(int rank, int world_size, mscclppComm_t handle)
: _rank(rank),
_world_size(world_size),
_handle(handle),
_is_open(true),
_proxies_running(false) {}
~_Comm() { close(); }
// Close should be safe to call on a closed handle.
void close() {
if (_is_open) {
if (_proxies_running) {
mscclppProxyStop(_handle);
_proxies_running = false;
}
checkResult(mscclppCommDestroy(_handle), "Failed to close comm channel");
_handle = NULL;
_is_open = false;
_rank = -1;
_world_size = -1;
}
}
void check_open() {
if (!_is_open) {
throw std::invalid_argument("_Comm is not open");
}
}
};
struct _P2PHandle {
struct mscclppRegisteredMemoryP2P _rmP2P;
struct mscclppIbMr _ibmr;
_P2PHandle() : _rmP2P({0}), _ibmr({0}) {}
_P2PHandle(const mscclppRegisteredMemoryP2P& p2p) : _ibmr({0}) {
_rmP2P = p2p;
if (_rmP2P.IbMr != nullptr) {
_ibmr = *_rmP2P.IbMr;
_rmP2P.IbMr = &_ibmr;
}
}
};
nb::callable _log_callback;
void _LogHandler(const char* msg) {
if (_log_callback) {
nb::gil_scoped_acquire guard;
_log_callback(msg);
}
}
static const std::string DOC_MscclppUniqueId =
"MSCCLPP Unique Id; used by the MPI Interface";
static const std::string DOC__Comm = "MSCCLPP Communications Handle";
static const std::string DOC__P2PHandle = "MSCCLPP P2P MR Handle";
NB_MODULE(_py_mscclpp, m) {
m.doc() = "Python bindings for MSCCLPP: which is not NCCL";
m.attr("MSCCLPP_UNIQUE_ID_BYTES") = MSCCLPP_UNIQUE_ID_BYTES;
m.def("_bind_log_handler", [](nb::callable cb) -> void {
_log_callback = nb::borrow<nb::callable>(cb);
mscclppSetLogHandler(_LogHandler);
});
m.def("_release_log_handler", []() -> void {
_log_callback.reset();
mscclppSetLogHandler(mscclppDefaultLogHandler);
});
nb::enum_<mscclppTransport_t>(m, "TransportType")
.value("P2P", mscclppTransport_t::mscclppTransportP2P)
.value("SHM", mscclppTransport_t::mscclppTransportSHM)
.value("IB", mscclppTransport_t::mscclppTransportIB);
nb::class_<mscclppUniqueId>(m, "MscclppUniqueId")
.def_ro_static("__doc__", &DOC_MscclppUniqueId)
.def_static(
"from_context",
[]() -> mscclppUniqueId {
mscclppUniqueId uniqueId;
return maybe(
mscclppGetUniqueId(&uniqueId),
uniqueId,
"Failed to get MSCCLP Unique Id.");
},
nb::call_guard<nb::gil_scoped_release>())
.def_static(
"from_bytes",
[](nb::bytes source) -> mscclppUniqueId {
if (source.size() != MSCCLPP_UNIQUE_ID_BYTES) {
throw std::invalid_argument(string_format(
"Requires exactly %d bytes; found %d",
MSCCLPP_UNIQUE_ID_BYTES,
source.size()));
}
mscclppUniqueId uniqueId;
std::memcpy(
uniqueId.internal, source.c_str(), sizeof(uniqueId.internal));
return uniqueId;
})
.def("bytes", [](mscclppUniqueId id) {
return nb::bytes(id.internal, sizeof(id.internal));
});
nb::class_<_P2PHandle>(m, "_P2PHandle")
.def_ro_static("__doc__", &DOC__P2PHandle);
nb::class_<_Comm>(m, "_Comm")
.def_ro_static("__doc__", &DOC__Comm)
.def_static(
"init_rank_from_address",
[](const std::string& address, int rank, int world_size) -> _Comm* {
mscclppComm_t handle;
checkResult(
mscclppCommInitRank(&handle, world_size, address.c_str(), rank),
"Failed to initialize comms: %s rank=%d world_size=%d",
address,
rank,
world_size);
return new _Comm(rank, world_size, handle);
},
nb::rv_policy::take_ownership,
nb::call_guard<nb::gil_scoped_release>(),
"address"_a,
"rank"_a,
"world_size"_a,
"Initialize comms given an IP address, rank, and world_size")
.def_static(
"init_rank_from_id",
[](const mscclppUniqueId& id, int rank, int world_size) -> _Comm* {
mscclppComm_t handle;
checkResult(
mscclppCommInitRankFromId(&handle, world_size, id, rank),
"Failed to initialize comms: %02X%s rank=%d world_size=%d",
id.internal,
rank,
world_size);
return new _Comm(rank, world_size, handle);
},
nb::rv_policy::take_ownership,
nb::call_guard<nb::gil_scoped_release>(),
"id"_a,
"rank"_a,
"world_size"_a,
"Initialize comms given u UniqueID, rank, and world_size")
.def(
"opened",
[](_Comm& comm) -> bool { return comm._is_open; },
"Is this comm object opened?")
.def(
"closed",
[](_Comm& comm) -> bool { return !comm._is_open; },
"Is this comm object closed?")
.def_ro("rank", &_Comm::_rank)
.def_ro("world_size", &_Comm::_world_size)
.def(
"register_buffer",
[](_Comm& comm,
uint64_t local_buff,
uint64_t buff_size) -> std::vector<_P2PHandle> {
comm.check_open();
mscclppRegisteredMemory regMem;
checkResult(
mscclppRegisterBuffer(
comm._handle,
reinterpret_cast<void*>(local_buff),
buff_size,
&regMem),
"Registering buffer failed");
std::vector<_P2PHandle> handles;
for (const auto& p2p : regMem.p2p) {
handles.push_back(_P2PHandle(p2p));
}
return handles;
},
"local_buf"_a,
"buff_size"_a,
nb::call_guard<nb::gil_scoped_release>(),
"Register a buffer for P2P transfers.")
.def(
"connect",
[](_Comm& comm,
int remote_rank,
int tag,
uint64_t local_buff,
uint64_t buff_size,
mscclppTransport_t transport_type) -> void {
comm.check_open();
RETRY(
mscclppConnect(
comm._handle,
remote_rank,
tag,
reinterpret_cast<void*>(local_buff),
buff_size,
transport_type,
NULL // ibDev
),
"Connect failed");
},
"remote_rank"_a,
"tag"_a,
"local_buf"_a,
"buff_size"_a,
"transport_type"_a,
nb::call_guard<nb::gil_scoped_release>(),
"Attach a local buffer to a remote connection.")
.def(
"connection_setup",
[](_Comm& comm) -> void {
comm.check_open();
RETRY(
mscclppConnectionSetup(comm._handle),
"Failed to setup MSCCLPP connection");
},
nb::call_guard<nb::gil_scoped_release>(),
"Run connection setup for MSCCLPP.")
.def(
"launch_proxies",
[](_Comm& comm) -> void {
comm.check_open();
if (comm._proxies_running) {
throw std::invalid_argument("Proxy Threads Already Running");
}
checkResult(
mscclppProxyLaunch(comm._handle),
"Failed to launch MSCCLPP proxy");
comm._proxies_running = true;
},
nb::call_guard<nb::gil_scoped_release>(),
"Start the MSCCLPP proxy.")
.def(
"stop_proxies",
[](_Comm& comm) -> void {
comm.check_open();
if (comm._proxies_running) {
checkResult(
mscclppProxyStop(comm._handle),
"Failed to stop MSCCLPP proxy");
comm._proxies_running = false;
}
},
nb::call_guard<nb::gil_scoped_release>(),
"Start the MSCCLPP proxy.")
.def("close", &_Comm::close, nb::call_guard<nb::gil_scoped_release>())
.def("__del__", &_Comm::close, nb::call_guard<nb::gil_scoped_release>())
.def(
"bootstrap_all_gather_int",
[](_Comm& comm, int val) -> std::vector<int> {
std::vector<int> buf(comm._world_size);
buf[comm._rank] = val;
mscclppBootstrapAllGather(comm._handle, buf.data(), sizeof(int));
return buf;
},
nb::call_guard<nb::gil_scoped_release>(),
"val"_a,
"all-gather ints over the bootstrap connection.")
.def(
"all_gather_bytes",
[](_Comm& comm, nb::bytes& item) -> std::vector<nb::bytes> {
// First, all-gather the sizes of all bytes.
std::vector<size_t> sizes(comm._world_size);
sizes[comm._rank] = item.size();
checkResult(
mscclppBootstrapAllGather(
comm._handle, sizes.data(), sizeof(size_t)),
"bootstrapAllGather failed.");
// Next, find the largest message to send.
size_t max_size = *std::max_element(sizes.begin(), sizes.end());
// Allocate an all-gather buffer large enough for max * world_size.
std::shared_ptr<char[]> data_buf(
new char[max_size * comm._world_size]);
// Copy the local item into the buffer.
std::memcpy(
&data_buf[comm._rank * max_size], item.c_str(), item.size());
// all-gather the data buffer.
checkResult(
mscclppBootstrapAllGather(
comm._handle, data_buf.get(), max_size),
"bootstrapAllGather failed.");
// Build a response vector.
std::vector<nb::bytes> ret;
for (int i = 0; i < comm._world_size; ++i) {
// Copy out the relevant range of each item.
ret.push_back(nb::bytes(&data_buf[i * max_size], sizes[i]));
}
return ret;
},
nb::call_guard<nb::gil_scoped_release>(),
"item"_a,
"all-gather bytes over the bootstrap connection; sizes do not need "
"to match.");
}

View File

@@ -1,203 +0,0 @@
import atexit
import json
import logging
import os
import pickle
import re
from typing import Any, Optional, final
logger = logging.getLogger(__file__)
from . import _py_mscclpp
__all__ = (
"Comm",
"MscclppUniqueId",
"MSCCLPP_UNIQUE_ID_BYTES",
"TransportType",
)
_Comm = _py_mscclpp._Comm
_P2PHandle = _py_mscclpp._P2PHandle
TransportType = _py_mscclpp.TransportType
MscclppUniqueId = _py_mscclpp.MscclppUniqueId
MSCCLPP_UNIQUE_ID_BYTES = _py_mscclpp.MSCCLPP_UNIQUE_ID_BYTES
def _mscclpp_log_cb(msg: str) -> None:
"""Log callback hook called from inside _py_mscclpp."""
# Attempt to parse out the original log level:
level = logging.INFO
if match := re.search(r"MSCCLPP (\w+)", msg):
level = logging._nameToLevel.get(match.group(1), logging.INFO)
# actually log the event.
logger.log(level, msg)
# The known log levels used by MSCCLPP.
# Set in os.environ['MSCCLPP_DEBUG'] and only parsed on first init.
MSCCLPP_LOG_LEVELS: set[str] = {
"DEBUG",
"INFO",
"WARN",
"ABORT",
"TRACE",
}
def _setup_logging(level: str = "INFO"):
"""Setup log hooks for the C library."""
level = level.upper()
if level not in MSCCLPP_LOG_LEVELS:
level = "INFO"
os.environ["MSCCLPP_DEBUG"] = level
_py_mscclpp._bind_log_handler(_mscclpp_log_cb)
# needed to prevent a segfault at exit.
atexit.register(_py_mscclpp._release_log_handler)
_setup_logging()
@final
class Comm:
"""Comm object; represents a mscclpp connection."""
_comm: _Comm
@staticmethod
def init_rank_from_address(
address: str,
rank: int,
world_size: int,
*,
port: Optional[int] = None,
):
"""Initialize a Comm from an address.
:param address: the address as a string, with optional port.
:param rank: this Comm's rank.
:param world_size: the total world size.
:param port: (optional) port, appended to address.
:return: a newly initialized Comm.
"""
if port is not None:
address = f"{address}:{port}"
return Comm(
_comm=_Comm.init_rank_from_address(
address=address,
rank=rank,
world_size=world_size,
),
)
def __init__(self, *, _comm: _Comm):
"""Construct a Comm object wrapping an internal _Comm handle."""
self._comm = _comm
def __del__(self) -> None:
self.close()
def close(self) -> None:
"""Close the connection."""
if self._comm:
self._comm.close()
self._comm = None
@property
def rank(self) -> int:
"""Return the rank of the Comm.
Assumes the Comm is open.
"""
return self._comm.rank
@property
def world_size(self) -> int:
"""Return the world_size of the Comm.
Assumes the Comm is open.
"""
return self._comm.world_size
def bootstrap_all_gather_int(self, val: int) -> list[int]:
"""AllGather an int value through the bootstrap interface."""
return self._comm.bootstrap_all_gather_int(val)
def all_gather_bytes(self, item: bytes) -> list[bytes]:
"""AllGather bytes (of different sizes) through the bootstrap interface.
:param item: the bytes object for this rank.
:return: a list of bytes objects; the ret[rank] object will be a new copy.
"""
return self._comm.all_gather_bytes(item)
def all_gather_json(self, item: Any) -> list[Any]:
"""AllGather JSON objects through the bootstrap interface.
:param item: the JSON object for this rank.
:return: a list of JSON objects; the ret[rank] object will be a new copy.
"""
return [
json.loads(b.decode("utf-8"))
for b in self.all_gather_bytes(bytes(json.dumps(item), "utf-8"))
]
def all_gather_pickle(self, item: Any) -> list[Any]:
"""AllGather pickle-able objects through the bootstrap interface.
:param item: the object for this rank.
:return: a list of de-pickled objects. Note, the ret[rank] item will be a new copy.
"""
return [pickle.loads(b) for b in self.all_gather_bytes(pickle.dumps(item))]
def connect(
self,
remote_rank: int,
tag: int,
data_ptr,
data_size: int,
transport: int,
) -> None:
self._comm.connect(
remote_rank,
tag,
data_ptr,
data_size,
transport,
)
def connection_setup(self) -> None:
self._comm.connection_setup()
def launch_proxies(self) -> None:
self._comm.launch_proxies()
def stop_proxies(self) -> None:
self._comm.stop_proxies()
def register_buffer(
self,
data_ptr,
data_size: int,
) -> list[_P2PHandle]:
return [
P2PHandle(self, h)
for h in self._comm.register_buffer(
data_ptr,
data_size,
)
]
class P2PHandle:
_comm: Comm
_handle: _P2PHandle
def __init__(self, comm: Comm, handle: _P2PHandle):
self._comm = comm
self._handle = handle

View File

@@ -1,86 +0,0 @@
import concurrent.futures
import os
import subprocess
import sys
import unittest
import hamcrest
import mscclpp
MOD_DIR = os.path.dirname(__file__)
TESTS_DIR = os.path.join(MOD_DIR, "tests")
class UniqueIdTest(unittest.TestCase):
def test_no_constructor(self) -> None:
hamcrest.assert_that(
hamcrest.calling(mscclpp.MscclppUniqueId).with_args(),
hamcrest.raises(
TypeError,
"no constructor",
),
)
def test_getUniqueId(self) -> None:
myId = mscclpp.MscclppUniqueId.from_context()
hamcrest.assert_that(
myId.bytes(),
hamcrest.has_length(mscclpp.MSCCLPP_UNIQUE_ID_BYTES),
)
# from_bytes should work
copy = mscclpp.MscclppUniqueId.from_bytes(myId.bytes())
hamcrest.assert_that(
copy.bytes(),
hamcrest.equal_to(myId.bytes()),
)
# bad size
hamcrest.assert_that(
hamcrest.calling(mscclpp.MscclppUniqueId.from_bytes).with_args(b"abc"),
hamcrest.raises(
ValueError,
f"Requires exactly {mscclpp.MSCCLPP_UNIQUE_ID_BYTES} bytes; found 3",
),
)
class CommsTest(unittest.TestCase):
def test_all_gather(self) -> None:
world_size = 2
tasks: list[concurrent.futures.Future[None]] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=world_size) as pool:
for rank in range(world_size):
tasks.append(
pool.submit(
subprocess.check_output,
[
"python",
"-m",
"mscclpp.tests.bootstrap_test",
f"--rank={rank}",
f"--world_size={world_size}",
],
stderr=subprocess.STDOUT,
)
)
errors = []
for rank, f in enumerate(tasks):
try:
f.result()
except subprocess.CalledProcessError as e:
errors.append((rank, e.output))
if errors:
parts = []
for rank, content in errors:
parts.append(
f"[rank {rank}]: " + content.decode("utf-8", errors="ignore")
)
raise AssertionError("\n\n".join(parts))

View File

@@ -1,145 +0,0 @@
import argparse
import os
from dataclasses import dataclass
import hamcrest
import torch
import mscclpp
@dataclass
class Example:
rank: int
def _test_bootstrap_allgather_int(options: argparse.Namespace, comm: mscclpp.Comm):
hamcrest.assert_that(
comm.bootstrap_all_gather_int(options.rank + 42),
hamcrest.equal_to(
[
42,
43,
]
),
)
def _test_bootstrap_allgather_bytes(options: argparse.Namespace, comm: mscclpp.Comm):
hamcrest.assert_that(
comm.all_gather_bytes(b"abc" * (1 + options.rank)),
hamcrest.equal_to(
[
b"abc",
b"abcabc",
]
),
)
def _test_bootstrap_allgather_json(options: argparse.Namespace, comm: mscclpp.Comm):
hamcrest.assert_that(
comm.all_gather_json({"rank": options.rank}),
hamcrest.equal_to(
[
{"rank": 0},
{"rank": 1},
]
),
)
hamcrest.assert_that(
comm.all_gather_json([options.rank, 42]),
hamcrest.equal_to(
[
[0, 42],
[1, 42],
]
),
)
def _test_bootstrap_allgather_pickle(options: argparse.Namespace, comm: mscclpp.Comm):
hamcrest.assert_that(
comm.all_gather_pickle(Example(rank=options.rank)),
hamcrest.equal_to(
[
Example(rank=0),
Example(rank=1),
]
),
)
comm.connection_setup()
def _test_rm(options: argparse.Namespace, comm: mscclpp.Comm):
rank = options.rank
buf = torch.zeros([options.world_size], dtype=torch.int64)
buf[rank] = 42 + rank
buf = buf.cuda().contiguous()
tag = 0
if rank:
remote_rank = 0
else:
remote_rank = 1
comm.connect(
remote_rank,
tag,
buf.data_ptr(),
buf.element_size() * buf.numel(),
mscclpp.TransportType.P2P,
)
handles = comm.register_buffer(buf.data_ptr(), buf.element_size() * buf.numel())
hamcrest.assert_that(
handles,
hamcrest.has_length(options.world_size - 1),
)
torch.cuda.synchronize()
comm.connection_setup()
comm.launch_proxies()
comm.stop_proxies()
def main():
p = argparse.ArgumentParser()
p.add_argument("--rank", type=int, required=True)
p.add_argument("--world_size", type=int, required=True)
p.add_argument("--port", default=50000)
options = p.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = str(options.rank)
comm_options = dict(
address=f"127.0.0.1:{options.port}",
rank=options.rank,
world_size=options.world_size,
)
print(f"{comm_options=}", flush=True)
comm = mscclpp.Comm.init_rank_from_address(**comm_options)
# comm.connection_setup()
hamcrest.assert_that(comm.rank, hamcrest.equal_to(options.rank))
hamcrest.assert_that(comm.world_size, hamcrest.equal_to(options.world_size))
try:
_test_bootstrap_allgather_int(options, comm)
_test_bootstrap_allgather_bytes(options, comm)
_test_bootstrap_allgather_json(options, comm)
_test_bootstrap_allgather_pickle(options, comm)
_test_rm(options, comm)
finally:
comm.close()
if __name__ == "__main__":
main()

View File

@@ -1,8 +0,0 @@
#!/bin/bash
set -ex
pip install -e .
cd src
pytest -vs mscclpp

23
python/utils_py.cpp Normal file
View File

@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>
#include <mscclpp/utils.hpp>
namespace nb = nanobind;
using namespace mscclpp;
void register_utils(nb::module_& m) {
nb::class_<Timer>(m, "Timer")
.def(nb::init<int>(), nb::arg("timeout") = -1)
.def("elapsed", &Timer::elapsed)
.def("set", &Timer::set, nb::arg("timeout"))
.def("reset", &Timer::reset)
.def("print", &Timer::print, nb::arg("name"));
nb::class_<ScopedTimer, Timer>(m, "ScopedTimer").def(nb::init<std::string>(), nb::arg("name"));
m.def("get_host_name", &getHostName, nb::arg("maxlen"), nb::arg("delim"));
}

View File

@@ -2,5 +2,5 @@
# Licensed under the MIT license.
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cc)
target_sources(mscclpp PRIVATE ${SOURCES})
target_include_directories(mscclpp PRIVATE include)
target_sources(mscclpp_obj PRIVATE ${SOURCES})
target_include_directories(mscclpp_obj PRIVATE include)

View File

@@ -4,6 +4,7 @@
#include <sys/resource.h>
#include <cstring>
#include <mscclpp/config.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/errors.hpp>
#include <sstream>
@@ -12,7 +13,6 @@
#include <vector>
#include "api.h"
#include "config.hpp"
#include "debug.h"
#include "socket.h"
#include "utils_internal.hpp"
@@ -36,13 +36,13 @@ struct ExtInfo {
SocketAddress extAddressListen;
};
MSCCLPP_API_CPP void BaseBootstrap::send(const std::vector<char>& data, int peer, int tag) {
MSCCLPP_API_CPP void Bootstrap::send(const std::vector<char>& data, int peer, int tag) {
size_t size = data.size();
send((void*)&size, sizeof(size_t), peer, tag);
send((void*)data.data(), data.size(), peer, tag + 1);
}
MSCCLPP_API_CPP void BaseBootstrap::recv(std::vector<char>& data, int peer, int tag) {
MSCCLPP_API_CPP void Bootstrap::recv(std::vector<char>& data, int peer, int tag) {
size_t size;
recv((void*)&size, sizeof(size_t), peer, tag);
data.resize(size);
@@ -55,12 +55,12 @@ struct UniqueIdInternal {
};
static_assert(sizeof(UniqueIdInternal) <= sizeof(UniqueId), "UniqueIdInternal is too large to fit into UniqueId");
class Bootstrap::Impl {
class TcpBootstrap::Impl {
public:
Impl(int rank, int nRanks);
~Impl();
void initialize(const UniqueId uniqueId);
void initialize(std::string ifIpPortTrio);
void initialize(const UniqueId& uniqueId);
void initialize(const std::string& ifIpPortTrio);
void establishConnections();
UniqueId createUniqueId();
UniqueId getUniqueId() const;
@@ -106,7 +106,7 @@ class Bootstrap::Impl {
void netInit(std::string ipPortPair, std::string interface);
};
Bootstrap::Impl::Impl(int rank, int nRanks)
TcpBootstrap::Impl::Impl(int rank, int nRanks)
: rank_(rank),
nRanks_(nRanks),
netInitialized(false),
@@ -115,13 +115,13 @@ Bootstrap::Impl::Impl(int rank, int nRanks)
abortFlagStorage_(new uint32_t(0)),
abortFlag_(abortFlagStorage_.get()) {}
UniqueId Bootstrap::Impl::getUniqueId() const {
UniqueId TcpBootstrap::Impl::getUniqueId() const {
UniqueId ret;
std::memcpy(&ret, &uniqueId_, sizeof(uniqueId_));
return ret;
}
UniqueId Bootstrap::Impl::createUniqueId() {
UniqueId TcpBootstrap::Impl::createUniqueId() {
netInit("", "");
getRandomData(&uniqueId_.magic, sizeof(uniqueId_.magic));
std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(SocketAddress));
@@ -129,11 +129,11 @@ UniqueId Bootstrap::Impl::createUniqueId() {
return getUniqueId();
}
int Bootstrap::Impl::getRank() { return rank_; }
int TcpBootstrap::Impl::getRank() { return rank_; }
int Bootstrap::Impl::getNranks() { return nRanks_; }
int TcpBootstrap::Impl::getNranks() { return nRanks_; }
void Bootstrap::Impl::initialize(const UniqueId uniqueId) {
void TcpBootstrap::Impl::initialize(const UniqueId& uniqueId) {
netInit("", "");
std::memcpy(&uniqueId_, &uniqueId, sizeof(uniqueId_));
@@ -141,7 +141,7 @@ void Bootstrap::Impl::initialize(const UniqueId uniqueId) {
establishConnections();
}
void Bootstrap::Impl::initialize(std::string ifIpPortTrio) {
void TcpBootstrap::Impl::initialize(const std::string& ifIpPortTrio) {
// first check if it is a trio
int nColons = 0;
for (auto c : ifIpPortTrio) {
@@ -170,7 +170,7 @@ void Bootstrap::Impl::initialize(std::string ifIpPortTrio) {
establishConnections();
}
Bootstrap::Impl::~Impl() {
TcpBootstrap::Impl::~Impl() {
if (abortFlag_) {
*abortFlag_ = 1;
}
@@ -179,8 +179,8 @@ Bootstrap::Impl::~Impl() {
}
}
void Bootstrap::Impl::getRemoteAddresses(Socket* listenSock, std::vector<SocketAddress>& rankAddresses,
std::vector<SocketAddress>& rankAddressesRoot, int& rank) {
void TcpBootstrap::Impl::getRemoteAddresses(Socket* listenSock, std::vector<SocketAddress>& rankAddresses,
std::vector<SocketAddress>& rankAddressesRoot, int& rank) {
ExtInfo info;
SocketAddress zero;
std::memset(&zero, 0, sizeof(SocketAddress));
@@ -209,15 +209,15 @@ void Bootstrap::Impl::getRemoteAddresses(Socket* listenSock, std::vector<SocketA
rank = info.rank;
}
void Bootstrap::Impl::sendHandleToPeer(int peer, const std::vector<SocketAddress>& rankAddresses,
const std::vector<SocketAddress>& rankAddressesRoot) {
void TcpBootstrap::Impl::sendHandleToPeer(int peer, const std::vector<SocketAddress>& rankAddresses,
const std::vector<SocketAddress>& rankAddressesRoot) {
int next = (peer + 1) % nRanks_;
Socket sock(&rankAddressesRoot[peer], uniqueId_.magic, SocketTypeBootstrap, abortFlag_);
sock.connect();
netSend(&sock, &rankAddresses[next], sizeof(SocketAddress));
}
void Bootstrap::Impl::bootstrapCreateRoot() {
void TcpBootstrap::Impl::bootstrapCreateRoot() {
listenSockRoot_ = std::make_unique<Socket>(&uniqueId_.addr, uniqueId_.magic, SocketTypeBootstrap, abortFlag_, 0);
listenSockRoot_->listen();
uniqueId_.addr = listenSockRoot_->getAddr();
@@ -232,7 +232,7 @@ void Bootstrap::Impl::bootstrapCreateRoot() {
});
}
void Bootstrap::Impl::bootstrapRoot() {
void TcpBootstrap::Impl::bootstrapRoot() {
int numCollected = 0;
std::vector<SocketAddress> rankAddresses(nRanks_, SocketAddress());
// for initial rank <-> root information exchange
@@ -266,7 +266,7 @@ void Bootstrap::Impl::bootstrapRoot() {
TRACE(MSCCLPP_INIT, "DONE");
}
void Bootstrap::Impl::netInit(std::string ipPortPair, std::string interface) {
void TcpBootstrap::Impl::netInit(std::string ipPortPair, std::string interface) {
if (netInitialized) return;
if (!ipPortPair.empty()) {
if (interface != "") {
@@ -285,30 +285,30 @@ void Bootstrap::Impl::netInit(std::string ipPortPair, std::string interface) {
} else {
int ret = FindInterfaces(netIfName_, &netIfAddr_, MAX_IF_NAME_SIZE, 1);
if (ret <= 0) {
throw Error("Bootstrap : no socket interface found", ErrorCode::InternalError);
throw Error("TcpBootstrap : no socket interface found", ErrorCode::InternalError);
}
}
char line[SOCKET_NAME_MAXLEN + MAX_IF_NAME_SIZE + 2];
std::sprintf(line, " %s:", netIfName_);
SocketToString(&netIfAddr_, line + strlen(line));
INFO(MSCCLPP_INIT, "Bootstrap : Using%s", line);
INFO(MSCCLPP_INIT, "TcpBootstrap : Using%s", line);
netInitialized = true;
}
#define TIMEOUT(__exp) \
do { \
try { \
__exp; \
} catch (const Error& e) { \
if (e.getErrorCode() == ErrorCode::Timeout) { \
throw Error("Bootstrap connection timeout", ErrorCode::Timeout); \
} \
throw; \
} \
#define TIMEOUT(__exp) \
do { \
try { \
__exp; \
} catch (const Error& e) { \
if (e.getErrorCode() == ErrorCode::Timeout) { \
throw Error("TcpBootstrap connection timeout", ErrorCode::Timeout); \
} \
throw; \
} \
} while (0);
void Bootstrap::Impl::establishConnections() {
void TcpBootstrap::Impl::establishConnections() {
const int64_t connectionTimeoutUs = (int64_t)Config::getInstance()->getBootstrapConnectionTimeoutConfig() * 1000000;
Timer timer;
SocketAddress nextAddr;
@@ -318,7 +318,7 @@ void Bootstrap::Impl::establishConnections() {
auto getLeftTime = [&]() {
int64_t timeout = connectionTimeoutUs - timer.elapsed();
if (timeout <= 0) throw Error("Bootstrap connection timeout", ErrorCode::Timeout);
if (timeout <= 0) throw Error("TcpBootstrap connection timeout", ErrorCode::Timeout);
return timeout;
};
@@ -377,7 +377,7 @@ void Bootstrap::Impl::establishConnections() {
TRACE(MSCCLPP_INIT, "rank %d nranks %d - DONE", rank_, nRanks_);
}
void Bootstrap::Impl::allGather(void* allData, int size) {
void TcpBootstrap::Impl::allGather(void* allData, int size) {
char* data = static_cast<char*>(allData);
int rank = rank_;
int nRanks = nRanks_;
@@ -401,7 +401,7 @@ void Bootstrap::Impl::allGather(void* allData, int size) {
TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - DONE", rank, nRanks, size);
}
std::shared_ptr<Socket> Bootstrap::Impl::getPeerSendSocket(int peer, int tag) {
std::shared_ptr<Socket> TcpBootstrap::Impl::getPeerSendSocket(int peer, int tag) {
auto it = peerSendSockets_.find(std::make_pair(peer, tag));
if (it != peerSendSockets_.end()) {
return it->second;
@@ -414,7 +414,7 @@ std::shared_ptr<Socket> Bootstrap::Impl::getPeerSendSocket(int peer, int tag) {
return sock;
}
std::shared_ptr<Socket> Bootstrap::Impl::getPeerRecvSocket(int peer, int tag) {
std::shared_ptr<Socket> TcpBootstrap::Impl::getPeerRecvSocket(int peer, int tag) {
auto it = peerRecvSockets_.find(std::make_pair(peer, tag));
if (it != peerRecvSockets_.end()) {
return it->second;
@@ -432,12 +432,12 @@ std::shared_ptr<Socket> Bootstrap::Impl::getPeerRecvSocket(int peer, int tag) {
}
}
void Bootstrap::Impl::netSend(Socket* sock, const void* data, int size) {
void TcpBootstrap::Impl::netSend(Socket* sock, const void* data, int size) {
sock->send(&size, sizeof(int));
sock->send(const_cast<void*>(data), size);
}
void Bootstrap::Impl::netRecv(Socket* sock, void* data, int size) {
void TcpBootstrap::Impl::netRecv(Socket* sock, void* data, int size) {
int recvSize;
sock->recv(&recvSize, sizeof(int));
if (recvSize > size) {
@@ -448,19 +448,19 @@ void Bootstrap::Impl::netRecv(Socket* sock, void* data, int size) {
sock->recv(data, std::min(recvSize, size));
}
void Bootstrap::Impl::send(void* data, int size, int peer, int tag) {
void TcpBootstrap::Impl::send(void* data, int size, int peer, int tag) {
auto sock = getPeerSendSocket(peer, tag);
netSend(sock.get(), data, size);
}
void Bootstrap::Impl::recv(void* data, int size, int peer, int tag) {
void TcpBootstrap::Impl::recv(void* data, int size, int peer, int tag) {
auto sock = getPeerRecvSocket(peer, tag);
netRecv(sock.get(), data, size);
}
void Bootstrap::Impl::barrier() { allGather(barrierArr_.data(), sizeof(int)); }
void TcpBootstrap::Impl::barrier() { allGather(barrierArr_.data(), sizeof(int)); }
void Bootstrap::Impl::close() {
void TcpBootstrap::Impl::close() {
listenSockRoot_.reset(nullptr);
listenSock_.reset(nullptr);
ringRecvSocket_.reset(nullptr);
@@ -469,28 +469,32 @@ void Bootstrap::Impl::close() {
peerRecvSockets_.clear();
}
MSCCLPP_API_CPP Bootstrap::Bootstrap(int rank, int nRanks) { pimpl_ = std::make_unique<Impl>(rank, nRanks); }
MSCCLPP_API_CPP TcpBootstrap::TcpBootstrap(int rank, int nRanks) { pimpl_ = std::make_unique<Impl>(rank, nRanks); }
MSCCLPP_API_CPP UniqueId Bootstrap::createUniqueId() { return pimpl_->createUniqueId(); }
MSCCLPP_API_CPP UniqueId TcpBootstrap::createUniqueId() { return pimpl_->createUniqueId(); }
MSCCLPP_API_CPP UniqueId Bootstrap::getUniqueId() const { return pimpl_->getUniqueId(); }
MSCCLPP_API_CPP UniqueId TcpBootstrap::getUniqueId() const { return pimpl_->getUniqueId(); }
MSCCLPP_API_CPP int Bootstrap::getRank() { return pimpl_->getRank(); }
MSCCLPP_API_CPP int TcpBootstrap::getRank() { return pimpl_->getRank(); }
MSCCLPP_API_CPP int Bootstrap::getNranks() { return pimpl_->getNranks(); }
MSCCLPP_API_CPP int TcpBootstrap::getNranks() { return pimpl_->getNranks(); }
MSCCLPP_API_CPP void Bootstrap::send(void* data, int size, int peer, int tag) { pimpl_->send(data, size, peer, tag); }
MSCCLPP_API_CPP void TcpBootstrap::send(void* data, int size, int peer, int tag) {
pimpl_->send(data, size, peer, tag);
}
MSCCLPP_API_CPP void Bootstrap::recv(void* data, int size, int peer, int tag) { pimpl_->recv(data, size, peer, tag); }
MSCCLPP_API_CPP void TcpBootstrap::recv(void* data, int size, int peer, int tag) {
pimpl_->recv(data, size, peer, tag);
}
MSCCLPP_API_CPP void Bootstrap::allGather(void* allData, int size) { pimpl_->allGather(allData, size); }
MSCCLPP_API_CPP void TcpBootstrap::allGather(void* allData, int size) { pimpl_->allGather(allData, size); }
MSCCLPP_API_CPP void Bootstrap::initialize(UniqueId uniqueId) { pimpl_->initialize(uniqueId); }
MSCCLPP_API_CPP void TcpBootstrap::initialize(UniqueId uniqueId) { pimpl_->initialize(uniqueId); }
MSCCLPP_API_CPP void Bootstrap::initialize(std::string ipPortPair) { pimpl_->initialize(ipPortPair); }
MSCCLPP_API_CPP void TcpBootstrap::initialize(const std::string& ipPortPair) { pimpl_->initialize(ipPortPair); }
MSCCLPP_API_CPP void Bootstrap::barrier() { pimpl_->barrier(); }
MSCCLPP_API_CPP void TcpBootstrap::barrier() { pimpl_->barrier(); }
MSCCLPP_API_CPP Bootstrap::~Bootstrap() { pimpl_->close(); }
MSCCLPP_API_CPP TcpBootstrap::~TcpBootstrap() { pimpl_->close(); }
} // namespace mscclpp

View File

@@ -13,11 +13,11 @@
#include <string.h>
#include <unistd.h>
#include <mscclpp/config.hpp>
#include <mscclpp/errors.hpp>
#include <mscclpp/utils.hpp>
#include <sstream>
#include "config.hpp"
#include "debug.h"
#include "utils_internal.hpp"

View File

@@ -15,7 +15,7 @@
namespace mscclpp {
Communicator::Impl::Impl(std::shared_ptr<BaseBootstrap> bootstrap) : bootstrap_(bootstrap) {
Communicator::Impl::Impl(std::shared_ptr<Bootstrap> bootstrap) : bootstrap_(bootstrap) {
rankToHash_.resize(bootstrap->getNranks());
auto hostHash = getHostHash();
INFO(MSCCLPP_INIT, "Host hash: %lx", hostHash);
@@ -47,10 +47,10 @@ cudaStream_t Communicator::Impl::getIpcStream() { return ipcStream_; }
MSCCLPP_API_CPP Communicator::~Communicator() = default;
MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr<BaseBootstrap> bootstrap)
MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr<Bootstrap> bootstrap)
: pimpl(std::make_unique<Impl>(bootstrap)) {}
MSCCLPP_API_CPP std::shared_ptr<BaseBootstrap> Communicator::bootstrapper() { return pimpl->bootstrap_; }
MSCCLPP_API_CPP std::shared_ptr<Bootstrap> Communicator::bootstrap() { return pimpl->bootstrap_; }
MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportFlags transports) {
return RegisteredMemory(
@@ -61,7 +61,7 @@ struct MemorySender : public Setuppable {
MemorySender(RegisteredMemory memory, int remoteRank, int tag)
: memory_(memory), remoteRank_(remoteRank), tag_(tag) {}
void beginSetup(std::shared_ptr<BaseBootstrap> bootstrap) override {
void beginSetup(std::shared_ptr<Bootstrap> bootstrap) override {
bootstrap->send(memory_.serialize(), remoteRank_, tag_);
}
@@ -77,7 +77,7 @@ MSCCLPP_API_CPP void Communicator::sendMemoryOnSetup(RegisteredMemory memory, in
struct MemoryReceiver : public Setuppable {
MemoryReceiver(int remoteRank, int tag) : remoteRank_(remoteRank), tag_(tag) {}
void endSetup(std::shared_ptr<BaseBootstrap> bootstrap) override {
void endSetup(std::shared_ptr<Bootstrap> bootstrap) override {
std::vector<char> data;
bootstrap->recv(data, remoteRank_, tag_);
memoryPromise_.set_value(RegisteredMemory::deserialize(data));

View File

@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include "config.hpp"
#include <mscclpp/config.hpp>
namespace mscclpp {
Config Config::instance_;

View File

@@ -171,7 +171,7 @@ void IBConnection::flush() {
// npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT);
}
void IBConnection::beginSetup(std::shared_ptr<BaseBootstrap> bootstrap) {
void IBConnection::beginSetup(std::shared_ptr<Bootstrap> bootstrap) {
std::vector<char> ibQpTransport;
std::copy_n(reinterpret_cast<char*>(&qp->getInfo()), sizeof(qp->getInfo()), std::back_inserter(ibQpTransport));
std::copy_n(reinterpret_cast<char*>(&transport_), sizeof(transport_), std::back_inserter(ibQpTransport));
@@ -179,7 +179,7 @@ void IBConnection::beginSetup(std::shared_ptr<BaseBootstrap> bootstrap) {
bootstrap->send(ibQpTransport.data(), ibQpTransport.size(), remoteRank(), tag());
}
void IBConnection::endSetup(std::shared_ptr<BaseBootstrap> bootstrap) {
void IBConnection::endSetup(std::shared_ptr<Bootstrap> bootstrap) {
std::vector<char> ibQpTransport(sizeof(IbQpInfo) + sizeof(Transport));
bootstrap->recv(ibQpTransport.data(), ibQpTransport.size(), remoteRank(), tag());

View File

@@ -82,9 +82,9 @@ const TransportFlags AllIBTransports = Transport::IB0 | Transport::IB1 | Transpo
const TransportFlags AllTransports = AllIBTransports | Transport::CudaIpc;
void Setuppable::beginSetup(std::shared_ptr<BaseBootstrap>) {}
void Setuppable::beginSetup(std::shared_ptr<Bootstrap>) {}
void Setuppable::endSetup(std::shared_ptr<BaseBootstrap>) {}
void Setuppable::endSetup(std::shared_ptr<Bootstrap>) {}
} // namespace mscclpp

View File

@@ -22,10 +22,10 @@ struct Communicator::Impl {
std::vector<std::shared_ptr<Setuppable>> toSetup_;
std::unordered_map<Transport, std::unique_ptr<IbCtx>> ibContexts_;
cudaStream_t ipcStream_;
std::shared_ptr<BaseBootstrap> bootstrap_;
std::shared_ptr<Bootstrap> bootstrap_;
std::vector<uint64_t> rankToHash_;
Impl(std::shared_ptr<BaseBootstrap> bootstrap);
Impl(std::shared_ptr<Bootstrap> bootstrap);
~Impl();

View File

@@ -71,9 +71,9 @@ class IBConnection : public ConnectionBase {
void flush() override;
void beginSetup(std::shared_ptr<BaseBootstrap> bootstrap) override;
void beginSetup(std::shared_ptr<Bootstrap> bootstrap) override;
void endSetup(std::shared_ptr<BaseBootstrap> bootstrap) override;
void endSetup(std::shared_ptr<Bootstrap> bootstrap) override;
};
} // namespace mscclpp

View File

@@ -69,12 +69,12 @@ MSCCLPP_API_CPP SmDevice2DeviceSemaphore::SmDevice2DeviceSemaphore(Communicator&
remoteInboundSemaphoreIdsRegMem_ =
setupInboundSemaphoreId(communicator, connection.get(), localInboundSemaphore_.get());
INFO(MSCCLPP_INIT, "Creating a direct semaphore for CudaIPC transport from %d to %d",
communicator.bootstrapper()->getRank(), connection->remoteRank());
communicator.bootstrap()->getRank(), connection->remoteRank());
isRemoteInboundSemaphoreIdSet_ = true;
} else if (AllIBTransports.has(connection->transport())) {
// We don't need to really with any of the IB transports, since the values will be local
INFO(MSCCLPP_INIT, "Creating a direct semaphore for IB transport from %d to %d",
communicator.bootstrapper()->getRank(), connection->remoteRank());
communicator.bootstrap()->getRank(), connection->remoteRank());
isRemoteInboundSemaphoreIdSet_ = false;
}
}

View File

@@ -391,9 +391,9 @@ int main(int argc, const char* argv[]) {
try {
if (rank == 0) printf("Initializing MSCCL++\n");
auto bootstrapper = std::make_shared<mscclpp::Bootstrap>(rank, world_size);
bootstrapper->initialize(ip_port);
mscclpp::Communicator comm(bootstrapper);
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, world_size);
bootstrap->initialize(ip_port);
mscclpp::Communicator comm(bootstrap);
mscclpp::ProxyService channelService(comm);
if (rank == 0) printf("Initializing data for allgather test\n");
@@ -422,19 +422,19 @@ int main(int argc, const char* argv[]) {
}
int tmp[16];
// A simple barrier
bootstrapper->allGather(tmp, sizeof(int));
bootstrap->allGather(tmp, sizeof(int));
if (rank == 0) printf("Successfully checked the correctness\n");
// Perf test
int iterwithoutcudagraph = 10;
if (rank == 0) printf("Running %d iterations of the kernel without CUDA graph\n", iterwithoutcudagraph);
CUDACHECK(cudaStreamSynchronize(stream));
bootstrapper->allGather(tmp, sizeof(int));
bootstrap->allGather(tmp, sizeof(int));
for (int i = 0; i < iterwithoutcudagraph; ++i) {
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
}
CUDACHECK(cudaStreamSynchronize(stream));
bootstrapper->allGather(tmp, sizeof(int));
bootstrap->allGather(tmp, sizeof(int));
// cudaGraph Capture
int cudagraphiter = 10;
@@ -462,7 +462,7 @@ int main(int argc, const char* argv[]) {
if (rank == 0)
printf("Running %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphlaunch,
cudagraphiter);
bootstrapper->allGather(tmp, sizeof(int));
bootstrap->allGather(tmp, sizeof(int));
double t0, t1, ms, time_in_us;
t0 = getTime();
for (int i = 0; i < cudagraphlaunch; ++i) {
@@ -475,7 +475,7 @@ int main(int argc, const char* argv[]) {
time_in_us = ms * 1000. / (float)cudagraphlaunch / (float)cudagraphiter;
printf("Rank %d report: size %lu time: %f us/iter algBW %f GBps\n", rank, dataSize, time_in_us,
(double)(dataSize) / 1e9 / (time_in_us / 1e6));
bootstrapper->allGather(tmp, sizeof(int));
bootstrap->allGather(tmp, sizeof(int));
if (rank == 0) printf("Stopping MSCCL++ proxy threads\n");
channelService.stopProxy();

View File

@@ -236,7 +236,7 @@ int main(int argc, char* argv[]) {
CUCHECK(cudaSetDevice(cudaNum));
if (rank == 0) printf("Initializing MSCCL++\n");
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(rank, world_size);
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, world_size);
mscclpp::UniqueId uniqueId;
if (rank == 0) uniqueId = bootstrap->createUniqueId();
MPI_Bcast(&uniqueId, sizeof(uniqueId), MPI_BYTE, 0, MPI_COMM_WORLD);

View File

@@ -3,10 +3,11 @@
#include <mpi.h>
#include "config.hpp"
#include <mscclpp/config.hpp>
#include "mp_unit_tests.hpp"
void BootstrapTest::bootstrapTestAllGather(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap) {
void BootstrapTest::bootstrapTestAllGather(std::shared_ptr<mscclpp::Bootstrap> bootstrap) {
std::vector<int> tmp(bootstrap->getNranks(), 0);
tmp[bootstrap->getRank()] = bootstrap->getRank() + 1;
bootstrap->allGather(tmp.data(), sizeof(int));
@@ -15,9 +16,9 @@ void BootstrapTest::bootstrapTestAllGather(std::shared_ptr<mscclpp::BaseBootstra
}
}
void BootstrapTest::bootstrapTestBarrier(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap) { bootstrap->barrier(); }
void BootstrapTest::bootstrapTestBarrier(std::shared_ptr<mscclpp::Bootstrap> bootstrap) { bootstrap->barrier(); }
void BootstrapTest::bootstrapTestSendRecv(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap) {
void BootstrapTest::bootstrapTestSendRecv(std::shared_ptr<mscclpp::Bootstrap> bootstrap) {
for (int i = 0; i < bootstrap->getNranks(); i++) {
if (bootstrap->getRank() == i) continue;
int msg1 = (bootstrap->getRank() + 1) * 3;
@@ -43,14 +44,14 @@ void BootstrapTest::bootstrapTestSendRecv(std::shared_ptr<mscclpp::BaseBootstrap
}
}
void BootstrapTest::bootstrapTestAll(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap) {
void BootstrapTest::bootstrapTestAll(std::shared_ptr<mscclpp::Bootstrap> bootstrap) {
bootstrapTestAllGather(bootstrap);
bootstrapTestBarrier(bootstrap);
bootstrapTestSendRecv(bootstrap);
}
TEST_F(BootstrapTest, WithId) {
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(gEnv->rank, gEnv->worldSize);
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(gEnv->rank, gEnv->worldSize);
mscclpp::UniqueId id;
if (bootstrap->getRank() == 0) id = bootstrap->createUniqueId();
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
@@ -59,14 +60,14 @@ TEST_F(BootstrapTest, WithId) {
}
TEST_F(BootstrapTest, WithIpPortPair) {
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(gEnv->rank, gEnv->worldSize);
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(gEnv->rank, gEnv->worldSize);
bootstrap->initialize(gEnv->args["ip_port"]);
bootstrapTestAll(bootstrap);
}
TEST_F(BootstrapTest, ResumeWithId) {
for (int i = 0; i < 5; ++i) {
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(gEnv->rank, gEnv->worldSize);
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(gEnv->rank, gEnv->worldSize);
mscclpp::UniqueId id;
if (bootstrap->getRank() == 0) id = bootstrap->createUniqueId();
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
@@ -76,13 +77,13 @@ TEST_F(BootstrapTest, ResumeWithId) {
TEST_F(BootstrapTest, ResumeWithIpPortPair) {
for (int i = 0; i < 5; ++i) {
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(gEnv->rank, gEnv->worldSize);
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(gEnv->rank, gEnv->worldSize);
bootstrap->initialize(gEnv->args["ip_port"]);
}
}
TEST_F(BootstrapTest, ExitBeforeConnect) {
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(gEnv->rank, gEnv->worldSize);
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(gEnv->rank, gEnv->worldSize);
bootstrap->createUniqueId();
}
@@ -94,7 +95,7 @@ TEST_F(BootstrapTest, TimeoutWithId) {
mscclpp::Timer timer;
// All ranks initialize a bootstrap with their own id (will hang)
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(gEnv->rank, gEnv->worldSize);
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(gEnv->rank, gEnv->worldSize);
mscclpp::UniqueId id = bootstrap->createUniqueId();
try {
@@ -108,9 +109,9 @@ TEST_F(BootstrapTest, TimeoutWithId) {
ASSERT_LT(timer.elapsed(), 1100000);
}
class MPIBootstrap : public mscclpp::BaseBootstrap {
class MPIBootstrap : public mscclpp::Bootstrap {
public:
MPIBootstrap() : BaseBootstrap() {}
MPIBootstrap() : Bootstrap() {}
int getRank() override {
int rank;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);

View File

@@ -19,10 +19,10 @@ void CommunicatorTestBase::SetUp() {
ibTransport = ibIdToTransport(rankToLocalRank(gEnv->rank));
MSCCLPP_CUDATHROW(cudaSetDevice(rankToLocalRank(gEnv->rank)));
std::shared_ptr<mscclpp::Bootstrap> bootstrap;
std::shared_ptr<mscclpp::TcpBootstrap> bootstrap;
mscclpp::UniqueId id;
if (gEnv->rank < numRanksToUse) {
bootstrap = std::make_shared<mscclpp::Bootstrap>(gEnv->rank, numRanksToUse);
bootstrap = std::make_shared<mscclpp::TcpBootstrap>(gEnv->rank, numRanksToUse);
if (gEnv->rank == 0) id = bootstrap->createUniqueId();
}
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
@@ -63,14 +63,14 @@ void CommunicatorTestBase::registerMemoryPairs(void* buff, size_t buffSize, mscc
localMemory = communicator->registerMemory(buff, buffSize, transport);
std::unordered_map<int, mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> futureRemoteMemories;
for (int remoteRank : remoteRanks) {
if (remoteRank != communicator->bootstrapper()->getRank()) {
if (remoteRank != communicator->bootstrap()->getRank()) {
communicator->sendMemoryOnSetup(localMemory, remoteRank, tag);
futureRemoteMemories[remoteRank] = communicator->recvMemoryOnSetup(remoteRank, tag);
}
}
communicator->setup();
for (int remoteRank : remoteRanks) {
if (remoteRank != communicator->bootstrapper()->getRank()) {
if (remoteRank != communicator->bootstrap()->getRank()) {
remoteMemories[remoteRank] = futureRemoteMemories[remoteRank].get();
}
}
@@ -166,10 +166,10 @@ TEST_F(CommunicatorTest, BasicWrite) {
if (gEnv->rank >= numRanksToUse) return;
deviceBufferInit();
communicator->bootstrapper()->barrier();
communicator->bootstrap()->barrier();
writeToRemote(deviceBufferSize / sizeof(int) / gEnv->worldSize);
communicator->bootstrapper()->barrier();
communicator->bootstrap()->barrier();
// polling until it becomes ready
bool ready = false;
@@ -181,7 +181,7 @@ TEST_F(CommunicatorTest, BasicWrite) {
FAIL() << "Polling is stuck.";
}
} while (!ready);
communicator->bootstrapper()->barrier();
communicator->bootstrap()->barrier();
}
__global__ void kernelWaitSemaphores(mscclpp::Host2DeviceSemaphore::DeviceHandle* deviceSemaphores, int rank,
@@ -201,10 +201,10 @@ TEST_F(CommunicatorTest, WriteWithDeviceSemaphores) {
semaphores.insert({entry.first, std::make_shared<mscclpp::Host2DeviceSemaphore>(*communicator.get(), conn)});
}
communicator->setup();
communicator->bootstrapper()->barrier();
communicator->bootstrap()->barrier();
deviceBufferInit();
communicator->bootstrapper()->barrier();
communicator->bootstrap()->barrier();
auto deviceSemaphoreHandles = mscclpp::allocSharedCuda<mscclpp::Host2DeviceSemaphore::DeviceHandle>(gEnv->worldSize);
for (int i = 0; i < gEnv->worldSize; i++) {
@@ -214,7 +214,7 @@ TEST_F(CommunicatorTest, WriteWithDeviceSemaphores) {
1, cudaMemcpyHostToDevice);
}
}
communicator->bootstrapper()->barrier();
communicator->bootstrap()->barrier();
writeToRemote(deviceBufferSize / sizeof(int) / gEnv->worldSize);
@@ -228,7 +228,7 @@ TEST_F(CommunicatorTest, WriteWithDeviceSemaphores) {
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
ASSERT_TRUE(testWriteCorrectness());
communicator->bootstrapper()->barrier();
communicator->bootstrap()->barrier();
}
TEST_F(CommunicatorTest, WriteWithHostSemaphores) {
@@ -242,10 +242,10 @@ TEST_F(CommunicatorTest, WriteWithHostSemaphores) {
semaphores.insert({entry.first, std::make_shared<mscclpp::Host2HostSemaphore>(*communicator.get(), conn)});
}
communicator->setup();
communicator->bootstrapper()->barrier();
communicator->bootstrap()->barrier();
deviceBufferInit();
communicator->bootstrapper()->barrier();
communicator->bootstrap()->barrier();
writeToRemote(deviceBufferSize / sizeof(int) / gEnv->worldSize);
@@ -274,5 +274,5 @@ TEST_F(CommunicatorTest, WriteWithHostSemaphores) {
}
ASSERT_TRUE(testWriteCorrectness());
communicator->bootstrapper()->barrier();
communicator->bootstrap()->barrier();
}

View File

@@ -24,7 +24,7 @@ void IbPeerToPeerTest::SetUp() {
if (gEnv->rank < 2) {
// This test needs only two ranks
bootstrap = std::make_shared<mscclpp::Bootstrap>(gEnv->rank, 2);
bootstrap = std::make_shared<mscclpp::TcpBootstrap>(gEnv->rank, 2);
if (bootstrap->getRank() == 0) id = bootstrap->createUniqueId();
}
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);

View File

@@ -42,13 +42,13 @@ class MultiProcessTest : public ::testing::Test {
class BootstrapTest : public MultiProcessTest {
protected:
void bootstrapTestAllGather(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap);
void bootstrapTestAllGather(std::shared_ptr<mscclpp::Bootstrap> bootstrap);
void bootstrapTestBarrier(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap);
void bootstrapTestBarrier(std::shared_ptr<mscclpp::Bootstrap> bootstrap);
void bootstrapTestSendRecv(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap);
void bootstrapTestSendRecv(std::shared_ptr<mscclpp::Bootstrap> bootstrap);
void bootstrapTestAll(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap);
void bootstrapTestAll(std::shared_ptr<mscclpp::Bootstrap> bootstrap);
// Each test case should finish within 30 seconds.
mscclpp::Timer bootstrapTestTimer{30};
@@ -76,7 +76,7 @@ class IbPeerToPeerTest : public IbTestBase {
void stageSendWithImm(uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled,
unsigned int immData);
std::shared_ptr<mscclpp::Bootstrap> bootstrap;
std::shared_ptr<mscclpp::TcpBootstrap> bootstrap;
std::shared_ptr<mscclpp::IbCtx> ibCtx;
mscclpp::IbQp* qp;
const mscclpp::IbMr* mr;

View File

@@ -17,8 +17,8 @@ void ProxyChannelOneToOneTest::TearDown() { CommunicatorTestBase::TearDown(); }
void ProxyChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::SimpleProxyChannel>& proxyChannels,
bool useIbOnly, void* sendBuff, size_t sendBuffBytes,
void* recvBuff, size_t recvBuffBytes) {
const int rank = communicator->bootstrapper()->getRank();
const int worldSize = communicator->bootstrapper()->getNranks();
const int rank = communicator->bootstrap()->getRank();
const int worldSize = communicator->bootstrap()->getNranks();
const bool isInPlace = (recvBuff == nullptr);
mscclpp::TransportFlags transport = (useIbOnly) ? ibTransport : (mscclpp::Transport::CudaIpc | ibTransport);
@@ -274,7 +274,7 @@ void ProxyChannelOneToOneTest::testPacketPingPong(bool useIbOnly) {
EXPECT_EQ(*ret, 0);
communicator->bootstrapper()->barrier();
communicator->bootstrap()->barrier();
channelService->stopProxy();
}
@@ -312,7 +312,7 @@ void ProxyChannelOneToOneTest::testPacketPingPongPerf(bool useIbOnly) {
<<<1, 1024>>>(buff.get(), putPacketBuffer.get(), getPacketBuffer.get(), gEnv->rank, 2, nTries, nullptr);
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
communicator->bootstrapper()->barrier();
communicator->bootstrap()->barrier();
// Measure latency
mscclpp::Timer timer;
@@ -320,7 +320,7 @@ void ProxyChannelOneToOneTest::testPacketPingPongPerf(bool useIbOnly) {
<<<1, 1024>>>(buff.get(), putPacketBuffer.get(), getPacketBuffer.get(), gEnv->rank, 2, nTries, nullptr);
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
communicator->bootstrapper()->barrier();
communicator->bootstrap()->barrier();
if (gEnv->rank == 0) {
std::cout << testName << ": " << std::setprecision(4) << (float)timer.elapsed() / (float)nTries << " us/iter\n";

View File

@@ -17,8 +17,8 @@ void SmChannelOneToOneTest::TearDown() { CommunicatorTestBase::TearDown(); }
void SmChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::SmChannel>& smChannels, void* inputBuff,
size_t inputBuffBytes, void* outputBuff, size_t outputBuffBytes) {
const int rank = communicator->bootstrapper()->getRank();
const int worldSize = communicator->bootstrapper()->getNranks();
const int rank = communicator->bootstrap()->getRank();
const int worldSize = communicator->bootstrap()->getNranks();
const bool isInPlace = (outputBuff == nullptr);
mscclpp::TransportFlags transport = mscclpp::Transport::CudaIpc | ibTransport;

View File

@@ -225,7 +225,7 @@ double BaseTestEngine::benchTime() {
return deltaSec;
}
void BaseTestEngine::barrier() { this->comm_->bootstrapper()->barrier(); }
void BaseTestEngine::barrier() { this->comm_->bootstrap()->barrier(); }
void BaseTestEngine::runTest() {
// warm-up for large size
@@ -326,7 +326,7 @@ void BaseTestEngine::runTest() {
}
void BaseTestEngine::bootstrap() {
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(args_.rank, args_.totalRanks);
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(args_.rank, args_.totalRanks);
mscclpp::UniqueId id;
if (bootstrap->getRank() == 0) id = bootstrap->createUniqueId();
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);

View File

@@ -9,25 +9,25 @@
class LocalCommunicatorTest : public ::testing::Test {
protected:
void SetUp() override {
bootstrap = std::make_shared<mscclpp::Bootstrap>(0, 1);
bootstrap = std::make_shared<mscclpp::TcpBootstrap>(0, 1);
comm = std::make_shared<mscclpp::Communicator>(bootstrap);
}
std::shared_ptr<mscclpp::Bootstrap> bootstrap;
std::shared_ptr<mscclpp::TcpBootstrap> bootstrap;
std::shared_ptr<mscclpp::Communicator> comm;
};
class MockSetuppable : public mscclpp::Setuppable {
public:
MOCK_METHOD(void, beginSetup, (std::shared_ptr<mscclpp::BaseBootstrap> bootstrap), (override));
MOCK_METHOD(void, endSetup, (std::shared_ptr<mscclpp::BaseBootstrap> bootstrap), (override));
MOCK_METHOD(void, beginSetup, (std::shared_ptr<mscclpp::Bootstrap> bootstrap), (override));
MOCK_METHOD(void, endSetup, (std::shared_ptr<mscclpp::Bootstrap> bootstrap), (override));
};
TEST_F(LocalCommunicatorTest, OnSetup) {
auto mockSetuppable = std::make_shared<MockSetuppable>();
comm->onSetup(mockSetuppable);
EXPECT_CALL(*mockSetuppable, beginSetup(std::dynamic_pointer_cast<mscclpp::BaseBootstrap>(bootstrap)));
EXPECT_CALL(*mockSetuppable, endSetup(std::dynamic_pointer_cast<mscclpp::BaseBootstrap>(bootstrap)));
EXPECT_CALL(*mockSetuppable, beginSetup(std::dynamic_pointer_cast<mscclpp::Bootstrap>(bootstrap)));
EXPECT_CALL(*mockSetuppable, endSetup(std::dynamic_pointer_cast<mscclpp::Bootstrap>(bootstrap)));
comm->setup();
}