mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
Python bindings (#125)
Co-authored-by: Olli Saarikivi <olsaarik@microsoft.com> Co-authored-by: Changho Hwang <changhohwang@microsoft.com> Co-authored-by: Binyang Li <binyli@microsoft.com>
This commit is contained in:
5
.black
5
.black
@@ -1,5 +0,0 @@
|
||||
[tool.black]
|
||||
line-length = 120
|
||||
target-version = ['py38']
|
||||
include = '\.pyi?$'
|
||||
extend-exclude = 'python/'
|
||||
6
.github/workflows/codeql.yml
vendored
6
.github/workflows/codeql.yml
vendored
@@ -44,10 +44,14 @@ jobs:
|
||||
tar xzf /tmp/cmake-3.26.4-linux-x86_64.tar.gz -C /tmp
|
||||
sudo ln -s /tmp/cmake-3.26.4-linux-x86_64/bin/cmake /usr/bin/cmake
|
||||
|
||||
- name: Dubious ownership exception
|
||||
run: |
|
||||
git config --global --add safe.directory /__w/mscclpp/mscclpp
|
||||
|
||||
- name: Autobuild
|
||||
uses: github/codeql-action/autobuild@v2
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v2
|
||||
with:
|
||||
category: "/language:${{matrix.language}}"
|
||||
category: "/language:${{matrix.language}}/cuda-version:${{matrix.cuda-version}}"
|
||||
|
||||
6
.github/workflows/lint.yml
vendored
6
.github/workflows/lint.yml
vendored
@@ -20,10 +20,8 @@ jobs:
|
||||
|
||||
- name: Run cpplint
|
||||
run: |
|
||||
CPPSOURCES=$(find ./ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)' -not -path "./build/*" -not -path "./python/*")
|
||||
PYTHONCPPSOURCES=$(find ./python/src/ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)')
|
||||
CPPSOURCES=$(find ./ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)' -not -path "./build/*")
|
||||
clang-format -style=file --verbose --Werror --dry-run ${CPPSOURCES}
|
||||
clang-format --dry-run ${PYTHONCPPSOURCES}
|
||||
|
||||
pylint:
|
||||
runs-on: ubuntu-20.04
|
||||
@@ -45,7 +43,7 @@ jobs:
|
||||
with:
|
||||
black: true
|
||||
black_auto_fix: false
|
||||
black_args: "--config .black --check"
|
||||
black_args: "--config pyproject.toml --check"
|
||||
|
||||
spelling:
|
||||
runs-on: ubuntu-20.04
|
||||
|
||||
@@ -8,7 +8,7 @@ set(MSCCLPP_PATCH "0")
|
||||
set(MSCCLPP_SOVERSION ${MSCCLPP_MAJOR})
|
||||
set(MSCCLPP_VERSION "${MSCCLPP_MAJOR}.${MSCCLPP_MINOR}.${MSCCLPP_PATCH}")
|
||||
|
||||
cmake_minimum_required(VERSION 3.26)
|
||||
cmake_minimum_required(VERSION 3.25)
|
||||
project(mscclpp LANGUAGES CUDA CXX)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CUDA_STANDARD 17)
|
||||
@@ -23,6 +23,8 @@ include(${PROJECT_SOURCE_DIR}/cmake/AddClangFormatTargets.cmake)
|
||||
# Options
|
||||
option(ENABLE_TRACE "Enable tracing" OFF)
|
||||
option(USE_NPKIT "Use NPKIT" ON)
|
||||
option(BUILD_TESTS "Build tests" ON)
|
||||
option(BUILD_PYTHON_BINDINGS "Build Python bindings" ON)
|
||||
option(ALLOW_GDRCOPY "Use GDRCopy, if available" OFF)
|
||||
|
||||
# Find CUDAToolkit. Set CUDA flags based on the detected CUDA version
|
||||
@@ -51,25 +53,48 @@ if(ALLOW_GDRCOPY)
|
||||
find_package(GDRCopy)
|
||||
endif()
|
||||
|
||||
# libmscclpp
|
||||
add_library(mscclpp SHARED)
|
||||
target_include_directories(mscclpp
|
||||
add_library(mscclpp_obj OBJECT)
|
||||
target_include_directories(mscclpp_obj
|
||||
PRIVATE
|
||||
${CUDAToolkit_INCLUDE_DIRS}
|
||||
${IBVERBS_INCLUDE_DIRS}
|
||||
${NUMA_INCLUDE_DIRS}
|
||||
${GDRCOPY_INCLUDE_DIRS})
|
||||
target_link_libraries(mscclpp PRIVATE ${CUDA_LIBRARIES} ${NUMA_LIBRARIES} ${IBVERBS_LIBRARIES} ${GDRCOPY_LIBRARIES})
|
||||
set_target_properties(mscclpp PROPERTIES LINKER_LANGUAGE CXX VERSION ${MSCCLPP_VERSION} SOVERSION ${MSCCLPP_SOVERSION})
|
||||
target_link_libraries(mscclpp_obj PRIVATE ${CUDA_LIBRARIES} ${NUMA_LIBRARIES} ${IBVERBS_LIBRARIES} ${GDRCOPY_LIBRARIES})
|
||||
set_target_properties(mscclpp_obj PROPERTIES LINKER_LANGUAGE CXX POSITION_INDEPENDENT_CODE 1 VERSION ${MSCCLPP_VERSION} SOVERSION ${MSCCLPP_SOVERSION})
|
||||
if(ENABLE_TRACE)
|
||||
target_compile_definitions(mscclpp PRIVATE ENABLE_TRACE)
|
||||
target_compile_definitions(mscclpp_obj PRIVATE ENABLE_TRACE)
|
||||
endif()
|
||||
if(USE_NPKIT)
|
||||
target_compile_definitions(mscclpp PRIVATE ENABLE_NPKIT)
|
||||
target_compile_definitions(mscclpp_obj PRIVATE ENABLE_NPKIT)
|
||||
endif()
|
||||
if(ALLOW_GDRCOPY AND GDRCOPY_FOUND)
|
||||
target_compile_definitions(mscclpp_obj PRIVATE MSCCLPP_USE_GDRCOPY)
|
||||
target_link_libraries(mscclpp_obj PRIVATE MSCCLPP::gdrcopy)
|
||||
endif()
|
||||
|
||||
# libmscclpp
|
||||
add_library(mscclpp SHARED)
|
||||
target_link_libraries(mscclpp PUBLIC mscclpp_obj)
|
||||
set_target_properties(mscclpp PROPERTIES VERSION ${MSCCLPP_VERSION} SOVERSION ${MSCCLPP_SOVERSION})
|
||||
add_library(mscclpp_static STATIC)
|
||||
target_link_libraries(mscclpp_static PUBLIC mscclpp_obj)
|
||||
set_target_properties(mscclpp_static PROPERTIES VERSION ${MSCCLPP_VERSION} SOVERSION ${MSCCLPP_SOVERSION})
|
||||
|
||||
add_subdirectory(include)
|
||||
add_subdirectory(src)
|
||||
install(TARGETS mscclpp LIBRARY DESTINATION lib)
|
||||
install(TARGETS mscclpp mscclpp_static
|
||||
LIBRARY DESTINATION lib
|
||||
ARCHIVE DESTINATION lib)
|
||||
|
||||
install(DIRECTORY ${PROJECT_SOURCE_DIR}/include/ DESTINATION include FILES_MATCHING PATTERN "*.hpp")
|
||||
|
||||
# Tests
|
||||
add_subdirectory(test)
|
||||
if (BUILD_TESTS)
|
||||
add_subdirectory(test)
|
||||
endif()
|
||||
|
||||
# Python bindings
|
||||
if(BUILD_PYTHON_BINDINGS)
|
||||
add_subdirectory(python)
|
||||
endif()
|
||||
|
||||
@@ -61,7 +61,7 @@ Some in-kernel communication interfaces of MSCCL++ send requests (called trigger
|
||||
|
||||
```cpp
|
||||
// Bootstrap: initialize control-plane connections between all ranks
|
||||
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(rank, world_size);
|
||||
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, world_size);
|
||||
// Create a communicator for connection setup
|
||||
mscclpp::Communicator comm(bootstrap);
|
||||
// Setup connections here using `comm`
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
find_program(CLANG_FORMAT clang-format)
|
||||
if(CLANG_FORMAT)
|
||||
message(STATUS "Found clang-format: ${CLANG_FORMAT}")
|
||||
set(FIND_DIRS ${PROJECT_SOURCE_DIR}/src ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/test)
|
||||
set(FIND_DIRS ${PROJECT_SOURCE_DIR}/src ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/python ${PROJECT_SOURCE_DIR}/test)
|
||||
add_custom_target(check-format ALL
|
||||
COMMAND ${CLANG_FORMAT} -style=file --dry-run `find ${FIND_DIRS} -type f -name *.h -o -name *.hpp -o -name *.c -o -name *.cc -o -name *.cpp -o -name *.cu`
|
||||
)
|
||||
|
||||
@@ -20,10 +20,14 @@ RUN rm -rf build && \
|
||||
mkdir build && \
|
||||
cd build && \
|
||||
${CMAKE_HOME}/bin/cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=${MSCCLPP_HOME} .. && \
|
||||
make -j mscclpp && \
|
||||
make -j mscclpp mscclpp_static && \
|
||||
make install/fast && \
|
||||
strip ${MSCCLPP_HOME}/lib/libmscclpp.so.[0-9]*.[0-9]*.[0-9]*
|
||||
|
||||
# Install MSCCL++ Python bindings
|
||||
WORKDIR ${MSCCLPP_SRC_DIR}
|
||||
RUN python3.8 -m pip install .
|
||||
|
||||
ENV LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${MSCCLPP_HOME}/lib"
|
||||
RUN echo LD_LIBRARY_PATH="${LD_LIBRARY_PATH}" >> /etc/environment
|
||||
|
||||
|
||||
@@ -2,5 +2,7 @@
|
||||
# Licensed under the MIT license.
|
||||
|
||||
file(GLOB_RECURSE HEADERS CONFIGURE_DEPENDS *.hpp)
|
||||
target_sources(mscclpp PUBLIC FILE_SET HEADERS FILES ${HEADERS})
|
||||
install(TARGETS mscclpp FILE_SET HEADERS)
|
||||
target_sources(mscclpp_obj PUBLIC FILE_SET HEADERS FILES ${HEADERS})
|
||||
if(NOT SKBUILD)
|
||||
install(TARGETS mscclpp FILE_SET HEADERS)
|
||||
endif()
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#define MSCCLPP_PATCH 0
|
||||
#define MSCCLPP_VERSION (MSCCLPP_MAJOR * 10000 + MSCCLPP_MINOR * 100 + MSCCLPP_PATCH)
|
||||
|
||||
#include <array>
|
||||
#include <bitset>
|
||||
#include <future>
|
||||
#include <memory>
|
||||
@@ -21,15 +22,13 @@ namespace mscclpp {
|
||||
#define MSCCLPP_UNIQUE_ID_BYTES 128
|
||||
|
||||
/// Unique ID for a process. This is a MSCCLPP_UNIQUE_ID_BYTES byte array that uniquely identifies a process.
|
||||
struct UniqueId {
|
||||
char internal[MSCCLPP_UNIQUE_ID_BYTES];
|
||||
};
|
||||
using UniqueId = std::array<uint8_t, MSCCLPP_UNIQUE_ID_BYTES>;
|
||||
|
||||
/// Base class for bootstrappers.
|
||||
class BaseBootstrap {
|
||||
/// Base class for bootstraps.
|
||||
class Bootstrap {
|
||||
public:
|
||||
BaseBootstrap(){};
|
||||
virtual ~BaseBootstrap() = default;
|
||||
Bootstrap(){};
|
||||
virtual ~Bootstrap() = default;
|
||||
virtual int getRank() = 0;
|
||||
virtual int getNranks() = 0;
|
||||
virtual void send(void* data, int size, int peer, int tag) = 0;
|
||||
@@ -41,32 +40,32 @@ class BaseBootstrap {
|
||||
void recv(std::vector<char>& data, int peer, int tag);
|
||||
};
|
||||
|
||||
/// A native implementation of the bootstrapper.
|
||||
class Bootstrap : public BaseBootstrap {
|
||||
/// A native implementation of the bootstrap using TCP sockets.
|
||||
class TcpBootstrap : public Bootstrap {
|
||||
public:
|
||||
/// Construct a Bootstrap.
|
||||
/// Constructor.
|
||||
/// @param rank The rank of the process.
|
||||
/// @param nRanks The total number of ranks.
|
||||
Bootstrap(int rank, int nRanks);
|
||||
TcpBootstrap(int rank, int nRanks);
|
||||
|
||||
/// Destroy the Bootstrap.
|
||||
~Bootstrap();
|
||||
/// Destructor.
|
||||
~TcpBootstrap();
|
||||
|
||||
/// Create a random unique ID and store it in the Bootstrap.
|
||||
/// Create a random unique ID and store it in the @ref TcpBootstrap.
|
||||
/// @return The created unique ID.
|
||||
UniqueId createUniqueId();
|
||||
|
||||
/// Return the unique ID stored in the Bootstrap.
|
||||
/// @return The unique ID stored in the Bootstrap.
|
||||
/// Return the unique ID stored in the @ref TcpBootstrap.
|
||||
/// @return The unique ID stored in the @ref TcpBootstrap.
|
||||
UniqueId getUniqueId() const;
|
||||
|
||||
/// Initialize the Bootstrap with a given unique ID.
|
||||
/// @param uniqueId The unique ID to initialize the Bootstrap with.
|
||||
/// Initialize the @ref TcpBootstrap with a given unique ID.
|
||||
/// @param uniqueId The unique ID to initialize the @ref TcpBootstrap with.
|
||||
void initialize(UniqueId uniqueId);
|
||||
|
||||
/// Initialize the Bootstrap with a string formatted as "ip:port" or "interface:ip:port".
|
||||
/// Initialize the @ref TcpBootstrap with a string formatted as "ip:port" or "interface:ip:port".
|
||||
/// @param ifIpPortTrio The string formatted as "ip:port" or "interface:ip:port".
|
||||
void initialize(std::string ifIpPortTrio);
|
||||
void initialize(const std::string& ifIpPortTrio);
|
||||
|
||||
/// Return the rank of the process.
|
||||
int getRank() override;
|
||||
@@ -109,9 +108,9 @@ class Bootstrap : public BaseBootstrap {
|
||||
void barrier() override;
|
||||
|
||||
private:
|
||||
/// Implementation class for Bootstrap.
|
||||
/// Implementation class for @ref TcpBootstrap.
|
||||
class Impl;
|
||||
/// Pointer to the implementation class for Bootstrap.
|
||||
/// Pointer to the implementation class for @ref TcpBootstrap.
|
||||
std::unique_ptr<Impl> pimpl_;
|
||||
};
|
||||
|
||||
@@ -421,13 +420,13 @@ struct Setuppable {
|
||||
/// being set up within the same @ref Communicator::setup() call.
|
||||
///
|
||||
/// @param bootstrap A shared pointer to the bootstrap implementation.
|
||||
virtual void beginSetup(std::shared_ptr<BaseBootstrap> bootstrap);
|
||||
virtual void beginSetup(std::shared_ptr<Bootstrap> bootstrap);
|
||||
|
||||
/// Called inside @ref Communicator::setup() after all calls to @ref beginSetup() of all @ref Setuppable objects that
|
||||
/// are being set up within the same @ref Communicator::setup() call.
|
||||
///
|
||||
/// @param bootstrap A shared pointer to the bootstrap implementation.
|
||||
virtual void endSetup(std::shared_ptr<BaseBootstrap> bootstrap);
|
||||
virtual void endSetup(std::shared_ptr<Bootstrap> bootstrap);
|
||||
};
|
||||
|
||||
/// A non-blocking future that can be used to check if a value is ready and retrieve it.
|
||||
@@ -484,16 +483,16 @@ class Communicator {
|
||||
public:
|
||||
/// Initializes the communicator with a given bootstrap implementation.
|
||||
///
|
||||
/// @param bootstrap An implementation of the BaseBootstrap that the communicator will use.
|
||||
Communicator(std::shared_ptr<BaseBootstrap> bootstrap);
|
||||
/// @param bootstrap An implementation of the Bootstrap that the communicator will use.
|
||||
Communicator(std::shared_ptr<Bootstrap> bootstrap);
|
||||
|
||||
/// Destroy the communicator.
|
||||
~Communicator();
|
||||
|
||||
/// Returns the bootstrapper held by this communicator.
|
||||
/// Returns the bootstrap held by this communicator.
|
||||
///
|
||||
/// @return std::shared_ptr<BaseBootstrap> The bootstrapper held by this communicator.
|
||||
std::shared_ptr<BaseBootstrap> bootstrapper();
|
||||
/// @return std::shared_ptr<Bootstrap> The bootstrap held by this communicator.
|
||||
std::shared_ptr<Bootstrap> bootstrap();
|
||||
|
||||
/// Register a region of GPU memory for use in this communicator.
|
||||
///
|
||||
|
||||
@@ -134,7 +134,7 @@ union ChannelTrigger {
|
||||
struct ProxyChannel {
|
||||
ProxyChannel() = default;
|
||||
|
||||
ProxyChannel(SemaphoreId SemaphoreId, Host2DeviceSemaphore::DeviceHandle semaphore, DeviceProxyFifo fifo);
|
||||
ProxyChannel(SemaphoreId semaphoreId, Host2DeviceSemaphore::DeviceHandle semaphore, DeviceProxyFifo fifo);
|
||||
|
||||
ProxyChannel(const ProxyChannel& other) = default;
|
||||
|
||||
|
||||
20
pyproject.toml
Normal file
20
pyproject.toml
Normal file
@@ -0,0 +1,20 @@
|
||||
[build-system]
|
||||
requires = ["scikit-build-core"]
|
||||
build-backend = "scikit_build_core.build"
|
||||
|
||||
[project]
|
||||
name = "mscclpp"
|
||||
version = "0.2.0"
|
||||
|
||||
[tool.scikit-build]
|
||||
cmake.minimum-version = "3.25.0"
|
||||
build-dir = "build/{wheel_tag}"
|
||||
|
||||
[tool.scikit-build.cmake.define]
|
||||
BUILD_PYTHON_BINDINGS = "ON"
|
||||
BUILD_TESTS = "OFF"
|
||||
|
||||
[tool.black]
|
||||
line-length = 120
|
||||
target-version = ['py38']
|
||||
include = '\.pyi?$'
|
||||
2
python/.gitignore
vendored
2
python/.gitignore
vendored
@@ -1,2 +0,0 @@
|
||||
.*.swp
|
||||
.venv/
|
||||
@@ -1,61 +1,16 @@
|
||||
project(mscclpp)
|
||||
cmake_minimum_required(VERSION 3.18...3.22)
|
||||
find_package(Python 3.9 COMPONENTS Interpreter Development.Module REQUIRED)
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
|
||||
set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build." FORCE)
|
||||
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
|
||||
endif ()
|
||||
|
||||
# Create CMake targets for all Python components needed by nanobind
|
||||
if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.26)
|
||||
find_package(Python 3.8 COMPONENTS Interpreter Development.Module Development.SABIModule REQUIRED)
|
||||
else ()
|
||||
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
|
||||
endif ()
|
||||
|
||||
# Detect the installed nanobind package and import it into CMake
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -c "import nanobind; print(nanobind.cmake_dir())"
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
|
||||
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
|
||||
find_package(nanobind CONFIG REQUIRED)
|
||||
|
||||
set(CUDA_DIR "/usr/local/cuda")
|
||||
|
||||
set(MSCCLPP_DIR ${CMAKE_CURRENT_LIST_DIR}/../build)
|
||||
|
||||
nanobind_add_module(
|
||||
_py_mscclpp
|
||||
NOSTRIP
|
||||
NB_STATIC
|
||||
src/_py_mscclpp.cpp
|
||||
)
|
||||
|
||||
target_include_directories(
|
||||
_py_mscclpp
|
||||
PUBLIC
|
||||
${CUDA_DIR}/include
|
||||
${MSCCLPP_DIR}/include
|
||||
)
|
||||
target_link_directories(
|
||||
_py_mscclpp
|
||||
PUBLIC
|
||||
${CUDA_DIR}/lib
|
||||
${MSCCLPP_DIR}/lib
|
||||
)
|
||||
|
||||
target_link_libraries(
|
||||
_py_mscclpp
|
||||
PUBLIC
|
||||
mscclpp
|
||||
)
|
||||
|
||||
add_custom_target(build-package ALL DEPENDS _py_mscclpp)
|
||||
|
||||
add_custom_command(
|
||||
TARGET build-package POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${MSCCLPP_DIR}/lib/libmscclpp.so
|
||||
${CMAKE_CURRENT_BINARY_DIR})
|
||||
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(nanobind GIT_REPOSITORY https://github.com/wjakob/nanobind.git GIT_TAG v1.4.0)
|
||||
FetchContent_MakeAvailable(nanobind)
|
||||
|
||||
nanobind_add_module(mscclpp_py core_py.cpp error_py.cpp proxy_channel_py.cpp fifo_py.cpp semaphore_py.cpp
|
||||
config_py.cpp utils_py.cpp)
|
||||
set_target_properties(mscclpp_py PROPERTIES OUTPUT_NAME mscclpp)
|
||||
target_link_libraries(mscclpp_py PRIVATE mscclpp_static)
|
||||
target_include_directories(mscclpp_py PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
|
||||
if (SKBUILD)
|
||||
install(TARGETS mscclpp_py LIBRARY DESTINATION ${SKBUILD_PLATLIB_DIR})
|
||||
endif()
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
|
||||
test:
|
||||
./test.sh
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
# Python bindings
|
||||
|
||||
Test instructions:
|
||||
|
||||
* Compile the `libmscclpp.so` library.
|
||||
* Install `cmake` version >= 3.18
|
||||
* setup a python virtual env
|
||||
* `pip install -r dev-requirements.txt`
|
||||
* `./test.sh`
|
||||
|
||||
## Run CI:
|
||||
|
||||
```bash
|
||||
./ci.sh
|
||||
```
|
||||
|
||||
## Build a wheel:
|
||||
|
||||
Setup dev environment, then:
|
||||
|
||||
```bash
|
||||
python setup.py bdist_wheel
|
||||
```
|
||||
|
||||
## Installing mpi and numa libs.
|
||||
|
||||
```
|
||||
## numctl
|
||||
apt install -y numactl libnuma-dev libnuma1
|
||||
```
|
||||
24
python/ci.sh
24
python/ci.sh
@@ -1,24 +0,0 @@
|
||||
#!/bin/bash
|
||||
# CI hook script.
|
||||
|
||||
set -ex
|
||||
|
||||
# CD to this directory.
|
||||
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
||||
cd $SCRIPT_DIR
|
||||
|
||||
# clean env
|
||||
rm -rf .venv build
|
||||
|
||||
# setup a python virtual env
|
||||
python -m venv .venv
|
||||
|
||||
# activate the virtual env
|
||||
source .venv/bin/activate
|
||||
|
||||
# install venv deps.
|
||||
pip install -r dev-requirements.txt
|
||||
|
||||
# run the build and test.
|
||||
./test.sh
|
||||
|
||||
16
python/config_py.cpp
Normal file
16
python/config_py.cpp
Normal file
@@ -0,0 +1,16 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
|
||||
#include <mscclpp/config.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
// Registers the Python binding for mscclpp::Config on module `m`.
//
// The Python class `Config` exposes a static `get_instance` accessor plus a getter/setter pair
// for the bootstrap connection timeout setting.
void register_config(nb::module_& m) {
  nb::class_<Config> config(m, "Config");
  // rv_policy::reference: Python receives a non-owning wrapper around the returned object.
  config.def_static("get_instance", &Config::getInstance, nb::rv_policy::reference);
  config.def("get_bootstrap_connection_timeout_config", &Config::getBootstrapConnectionTimeoutConfig);
  config.def("set_bootstrap_connection_timeout_config", &Config::setBootstrapConnectionTimeoutConfig);
}
|
||||
153
python/core_py.cpp
Normal file
153
python/core_py.cpp
Normal file
@@ -0,0 +1,153 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/operators.h>
|
||||
#include <nanobind/stl/array.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include <mscclpp/core.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
extern void register_error(nb::module_& m);
|
||||
extern void register_proxy_channel(nb::module_& m);
|
||||
extern void register_fifo(nb::module_& m);
|
||||
extern void register_semaphore(nb::module_& m);
|
||||
extern void register_config(nb::module_& m);
|
||||
extern void register_utils(nb::module_& m);
|
||||
|
||||
template <typename T>
|
||||
void def_nonblocking_future(nb::handle& m, const std::string& typestr) {
|
||||
std::string pyclass_name = std::string("NonblockingFuture") + typestr;
|
||||
nb::class_<NonblockingFuture<T>>(m, pyclass_name.c_str())
|
||||
.def("ready", &NonblockingFuture<T>::ready)
|
||||
.def("get", &NonblockingFuture<T>::get);
|
||||
}
|
||||
|
||||
/// Converts a byte count received from Python into the `int` taken by the raw-pointer
/// Bootstrap::send()/recv() overloads.
///
/// @param size Byte count supplied by the Python caller.
/// @return `size` as an `int`.
/// @throws nb::value_error (Python ValueError) if `size` cannot be represented as an `int`.
static int checked_int_size(size_t size) {
  const int narrowed = static_cast<int>(size);
  // An out-of-range value either turns negative or fails to round-trip back to the same size_t.
  if (narrowed < 0 || static_cast<size_t>(narrowed) != size) {
    throw nb::value_error("size does not fit in int");
  }
  return narrowed;
}

/// Registers the core mscclpp bindings on module `m`: Bootstrap, UniqueId, TcpBootstrap, Transport,
/// TransportFlags, RegisteredMemory, Connection, NonblockingFutureRegisteredMemory and Communicator.
///
/// Raw buffers are passed from Python as integer addresses (`uintptr_t`) and reinterpreted as pointers;
/// the caller is responsible for passing a valid address and size.
///
/// NOTE(review): several bindings below exchange std::vector<char> (the vector overloads of send/recv and
/// RegisteredMemory::serialize/deserialize). The includes visible in this file do not list a nanobind
/// caster header for std::vector - confirm these are actually callable from Python.
///
/// @param m Module that receives the bindings.
void register_core(nb::module_& m) {
  nb::class_<Bootstrap>(m, "Bootstrap")
      .def("get_rank", &Bootstrap::getRank)
      .def("get_n_ranks", &Bootstrap::getNranks)
      .def(
          "send",
          [](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) {
            // The underlying overload takes an `int` size; reject values that would be truncated.
            self->send(reinterpret_cast<void*>(ptr), checked_int_size(size), peer, tag);
          },
          nb::arg("data"), nb::arg("size"), nb::arg("peer"), nb::arg("tag"))
      .def(
          "recv",
          [](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) {
            self->recv(reinterpret_cast<void*>(ptr), checked_int_size(size), peer, tag);
          },
          nb::arg("data"), nb::arg("size"), nb::arg("peer"), nb::arg("tag"))
      .def("all_gather", &Bootstrap::allGather, nb::arg("allData"), nb::arg("size"))
      .def("barrier", &Bootstrap::barrier)
      // The casts select the std::vector<char> overloads of send/recv.
      .def("send", static_cast<void (Bootstrap::*)(const std::vector<char>&, int, int)>(&Bootstrap::send),
           nb::arg("data"), nb::arg("peer"), nb::arg("tag"))
      .def("recv", static_cast<void (Bootstrap::*)(std::vector<char>&, int, int)>(&Bootstrap::recv),
           nb::arg("data"), nb::arg("peer"), nb::arg("tag"));

  nb::class_<UniqueId>(m, "UniqueId");

  nb::class_<TcpBootstrap, Bootstrap>(m, "TcpBootstrap")
      .def(nb::init<int, int>(), "Do not use this constructor. Use create instead.")
      .def_static(
          "create", [](int rank, int nRanks) { return std::make_shared<TcpBootstrap>(rank, nRanks); },
          nb::arg("rank"), nb::arg("nRanks"))
      .def("create_unique_id", &TcpBootstrap::createUniqueId)
      .def("get_unique_id", &TcpBootstrap::getUniqueId)
      // Two `initialize` overloads: one taking a UniqueId, one taking an "ip:port"-style string.
      .def("initialize", static_cast<void (TcpBootstrap::*)(UniqueId)>(&TcpBootstrap::initialize),
           nb::arg("uniqueId"))
      .def("initialize", static_cast<void (TcpBootstrap::*)(const std::string&)>(&TcpBootstrap::initialize),
           nb::arg("ifIpPortTrio"));

  nb::enum_<Transport>(m, "Transport")
      .value("Unknown", Transport::Unknown)
      .value("CudaIpc", Transport::CudaIpc)
      .value("IB0", Transport::IB0)
      .value("IB1", Transport::IB1)
      .value("IB2", Transport::IB2)
      .value("IB3", Transport::IB3)
      .value("IB4", Transport::IB4)
      .value("IB5", Transport::IB5)
      .value("IB6", Transport::IB6)
      .value("IB7", Transport::IB7)
      .value("NumTransports", Transport::NumTransports);

  nb::class_<TransportFlags>(m, "TransportFlags")
      .def(nb::init<>())
      // Implicit conversion lets a single Transport be passed wherever TransportFlags is expected.
      .def(nb::init_implicit<Transport>(), nb::arg("transport"))
      .def("has", &TransportFlags::has, nb::arg("transport"))
      .def("none", &TransportFlags::none)
      .def("any", &TransportFlags::any)
      .def("all", &TransportFlags::all)
      .def("count", &TransportFlags::count)
      .def(nb::self |= nb::self)
      .def(nb::self | nb::self)
      .def(nb::self | Transport())
      .def(nb::self &= nb::self)
      .def(nb::self & nb::self)
      .def(nb::self & Transport())
      .def(nb::self ^= nb::self)
      .def(nb::self ^ nb::self)
      .def(nb::self ^ Transport())
      .def(~nb::self)
      .def(nb::self == nb::self)
      .def(nb::self != nb::self);

  nb::class_<RegisteredMemory>(m, "RegisteredMemory")
      .def(nb::init<>())
      .def("data", &RegisteredMemory::data)
      .def("size", &RegisteredMemory::size)
      .def("rank", &RegisteredMemory::rank)
      .def("transports", &RegisteredMemory::transports)
      .def("serialize", &RegisteredMemory::serialize)
      .def_static("deserialize", &RegisteredMemory::deserialize, nb::arg("data"));

  nb::class_<Connection>(m, "Connection")
      .def("write", &Connection::write, nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("srcOffset"),
           nb::arg("size"))
      .def(
          "update_and_sync",
          [](Connection* self, RegisteredMemory dst, uint64_t dstOffset, uintptr_t src, uint64_t newValue) {
            // `src` is an integer address supplied by Python, reinterpreted as a uint64_t pointer.
            self->updateAndSync(dst, dstOffset, reinterpret_cast<uint64_t*>(src), newValue);
          },
          nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("newValue"))
      .def("flush", &Connection::flush)
      .def("remote_rank", &Connection::remoteRank)
      .def("tag", &Connection::tag)
      .def("transport", &Connection::transport)
      .def("remote_transport", &Connection::remoteTransport);

  def_nonblocking_future<RegisteredMemory>(m, "RegisteredMemory");

  nb::class_<Communicator>(m, "Communicator")
      .def(nb::init<std::shared_ptr<Bootstrap>>(), nb::arg("bootstrap"))
      .def("bootstrap", &Communicator::bootstrap)
      .def(
          "register_memory",
          [](Communicator* self, uintptr_t ptr, size_t size, TransportFlags transports) {
            return self->registerMemory(reinterpret_cast<void*>(ptr), size, transports);
          },
          nb::arg("ptr"), nb::arg("size"), nb::arg("transports"))
      .def("send_memory_on_setup", &Communicator::sendMemoryOnSetup, nb::arg("memory"), nb::arg("remoteRank"),
           nb::arg("tag"))
      .def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
      .def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
           nb::arg("transport"))
      .def("setup", &Communicator::setup);
}
|
||||
|
||||
// Entry point of the `mscclpp` extension module: runs every register_* function on the module object.
NB_MODULE(mscclpp, m) {
  using RegisterFn = void (*)(nb::module_&);
  // Invoked in exactly this order.
  const RegisterFn registrars[] = {register_error,  register_proxy_channel, register_fifo, register_semaphore,
                                   register_config, register_utils,         register_core};
  for (RegisterFn registerFn : registrars) {
    registerFn(m);
  }
}
|
||||
@@ -1,11 +0,0 @@
|
||||
black
|
||||
isort
|
||||
|
||||
wheel
|
||||
|
||||
pytest
|
||||
PyHamcrest
|
||||
|
||||
nanobind
|
||||
|
||||
torch
|
||||
41
python/error_py.cpp
Normal file
41
python/error_py.cpp
Normal file
@@ -0,0 +1,41 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include <mscclpp/errors.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
// Registers the error-related bindings on module `m`: the ErrorCode enum, BaseError, and the
// BaseError-derived classes Error, SysError, CudaError, CuError and IbError.
void register_error(nb::module_& m) {
  nb::enum_<ErrorCode> errorCode(m, "ErrorCode");
  errorCode.value("SystemError", ErrorCode::SystemError);
  errorCode.value("InternalError", ErrorCode::InternalError);
  errorCode.value("RemoteError", ErrorCode::RemoteError);
  errorCode.value("InvalidUsage", ErrorCode::InvalidUsage);
  errorCode.value("Timeout", ErrorCode::Timeout);
  errorCode.value("Aborted", ErrorCode::Aborted);

  // NOTE(review): this constructor is bound with `std::string&` while every derived class below uses
  // `const std::string&` - confirm against the BaseError declaration whether that is intentional.
  nb::class_<BaseError> baseError(m, "BaseError");
  baseError.def(nb::init<std::string&, int>(), nb::arg("message"), nb::arg("errorCode"));
  baseError.def("get_error_code", &BaseError::getErrorCode);
  baseError.def("what", &BaseError::what);

  nb::class_<Error, BaseError> error(m, "Error");
  error.def(nb::init<const std::string&, ErrorCode>(), nb::arg("message"), nb::arg("errorCode"));
  error.def("get_error_code", &Error::getErrorCode);

  // The remaining error types only expose a (message, errorCode) constructor.
  nb::class_<SysError, BaseError> sysError(m, "SysError");
  sysError.def(nb::init<const std::string&, int>(), nb::arg("message"), nb::arg("errorCode"));

  nb::class_<CudaError, BaseError> cudaError(m, "CudaError");
  cudaError.def(nb::init<const std::string&, cudaError_t>(), nb::arg("message"), nb::arg("errorCode"));

  nb::class_<CuError, BaseError> cuError(m, "CuError");
  cuError.def(nb::init<const std::string&, CUresult>(), nb::arg("message"), nb::arg("errorCode"));

  nb::class_<IbError, BaseError> ibError(m, "IbError");
  ibError.def(nb::init<const std::string&, int>(), nb::arg("message"), nb::arg("errorCode"));
}
|
||||
91
python/examples/bootstrap.py
Normal file
91
python/examples/bootstrap.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import mscclpp
|
||||
import argparse
|
||||
import multiprocessing as mp
|
||||
import logging
|
||||
import torch
|
||||
import sys
|
||||
|
||||
# The eight InfiniBand transports exposed by the bindings (IB0 .. IB7), in index order.
IB_TRANSPORTS = [getattr(mscclpp.Transport, f"IB{index}") for index in range(8)]
|
||||
|
||||
|
||||
def setup_connections(comm, rank, world_size, element_size, proxy_service):
    """Connect this rank to every other rank and build one proxy channel per connection.

    Allocates a zeroed int32 CUDA tensor, registers it with ``comm``, connects to each peer over
    ``Transport.CudaIpc`` (tag 0), exchanges registered memories, and wraps every connection in a
    ``SimpleProxyChannel`` created through ``proxy_service``.

    Args:
        comm: The ``mscclpp.Communicator`` used for registration and connection setup.
        rank: Rank of the calling process; also selects the entry of ``IB_TRANSPORTS`` to register with.
        world_size: Total number of ranks.
        element_size: Number of int32 elements in the allocated buffer.
        proxy_service: The ``mscclpp.ProxyService`` that owns the semaphores and memories of the channels.

    Returns:
        A list with one ``mscclpp.SimpleProxyChannel`` per peer, in ascending peer-rank order.
    """
    simple_proxy_channels = []
    connections = []
    remote_memories = []
    memory = torch.zeros(element_size, dtype=torch.int32)
    memory = memory.to("cuda")

    # Register for BOTH this rank's IB transport and CudaIpc. A boolean `or` would evaluate to just one
    # of the two operands; the flags have to be combined with the bitwise `|` of TransportFlags, and the
    # connections created below use CudaIpc.
    transport_flag = mscclpp.TransportFlags(IB_TRANSPORTS[rank]) | mscclpp.Transport.CudaIpc
    ptr = memory.data_ptr()
    size = memory.numel() * memory.element_size()
    reg_mem = comm.register_memory(ptr, size, transport_flag)

    for r in range(world_size):
        if r == rank:
            continue
        conn = comm.connect_on_setup(r, 0, mscclpp.Transport.CudaIpc)
        connections.append(conn)
        comm.send_memory_on_setup(reg_mem, r, 0)
        remote_mem = comm.recv_memory_on_setup(r, 0)
        remote_memories.append(remote_mem)
    comm.setup()

    for i, conn in enumerate(connections):
        proxy_channel = mscclpp.SimpleProxyChannel(
            proxy_service.device_channel(proxy_service.add_semaphore(conn)),
            # recv_memory_on_setup returned a future; it is resolved here, after comm.setup().
            proxy_service.add_memory(remote_memories[i].get()),
            proxy_service.add_memory(reg_mem),
        )
        simple_proxy_channels.append(proxy_channel)
    # Second setup round, issued after the semaphores and memories were added to the proxy service.
    comm.setup()
    return simple_proxy_channels
|
||||
|
||||
|
||||
def run(rank, args):
    """Per-process entry point: bootstrap rank ``rank``, set up connections and start the proxy service.

    Args:
        rank: Rank of this process; also used as the CUDA device index.
        args: Parsed command-line arguments (reads ``gpu_number``, ``if_ip_port_trio`` and ``num_elements``).
    """
    world_size = args.gpu_number
    torch.cuda.set_device(rank)

    bootstrap = mscclpp.TcpBootstrap.create(rank, world_size)
    bootstrap.initialize(args.if_ip_port_trio)
    communicator = mscclpp.Communicator(bootstrap)
    service = mscclpp.ProxyService(communicator)

    logging.info("Rank: %d, setting up connections", rank)
    setup_connections(communicator, rank, world_size, args.num_elements, service)

    logging.info("Rank: %d, starting proxy service", rank)
    service.start_proxy()
|
||||
|
||||
|
||||
def main():
    """Parse the command line and run one worker process per GPU, waiting for all of them to finish."""
    logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("if_ip_port_trio", type=str)
    arg_parser.add_argument("-n", "--num-elements", type=int, default=10)
    arg_parser.add_argument("-g", "--gpu_number", type=int, default=2)
    args = arg_parser.parse_args()

    # One process per rank; each runs `run(rank, args)`.
    workers = [mp.Process(target=run, args=(rank, args)) for rank in range(args.gpu_number)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()


if __name__ == "__main__":
    main()
|
||||
12
python/examples/config.py
Normal file
12
python/examples/config.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import mscclpp
|
||||
|
||||
|
||||
def main():
    """Set the bootstrap connection timeout on the Config instance and check that it reads back."""
    expected_timeout = 15
    cfg = mscclpp.Config.get_instance()
    cfg.set_bootstrap_connection_timeout_config(expected_timeout)
    timeout = cfg.get_bootstrap_connection_timeout_config()
    assert timeout == expected_timeout


if __name__ == "__main__":
    main()
|
||||
81
python/examples/send_recv.py
Normal file
81
python/examples/send_recv.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import mscclpp
|
||||
import argparse
|
||||
import time
|
||||
|
||||
|
||||
def main(args):
    """Two-rank example: rank 0 writes 1..num_elements into rank 1's buffer over IB0; rank 1 polls for it.

    Args:
        args: Parsed command-line arguments (reads ``root``, ``if_ip_port_trio``, ``gpu``,
            ``num_elements`` and ``polling_num``).
    """
    rank = 0 if args.root else 1
    peer = (rank + 1) % 2

    boot = mscclpp.TcpBootstrap.create(rank, 2)
    boot.initialize(args.if_ip_port_trio)

    comm = mscclpp.Communicator(boot)

    if args.gpu:
        import torch

        print("Allocating GPU memory")
        memory = torch.zeros(args.num_elements, dtype=torch.int32)
        memory = memory.to("cuda")
        ptr = memory.data_ptr()
        size = memory.numel() * memory.element_size()
    else:
        from array import array

        print("Allocating host memory")
        memory = array("i", [0] * args.num_elements)
        ptr, elements = memory.buffer_info()
        size = elements * memory.itemsize
    my_reg_mem = comm.register_memory(ptr, size, mscclpp.Transport.IB0)

    conn = comm.connect_on_setup(peer, 0, mscclpp.Transport.IB0)

    # Rank 1 publishes its registered memory; rank 0 receives it as a future.
    other_reg_mem = None
    if rank == 0:
        other_reg_mem = comm.recv_memory_on_setup(peer, 0)
    else:
        comm.send_memory_on_setup(my_reg_mem, peer, 0)

    comm.setup()

    if rank == 0:
        # The future is resolved only after comm.setup() has run.
        other_reg_mem = other_reg_mem.get()

        for i in range(args.num_elements):
            memory[i] = i + 1
        # NOTE(review): no conn.flush() follows the write - confirm the write is complete before exit.
        conn.write(other_reg_mem, 0, my_reg_mem, 0, size)
        print("Done sending")
    else:
        print("Checking for correctness")
        # Poll up to polling_num times, sleeping 0.1 s after each failed attempt.
        for _ in range(args.polling_num):
            for i in range(args.num_elements):
                if memory[i] != i + 1:
                    print(f"Error: Mismatch at index {i}: expected {i + 1}, got {memory[i]}")
                    break
            else:
                # Inner loop finished without a mismatch.
                print("All data matched expected values")
                break
            time.sleep(0.1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point. Besides the positional bootstrap address,
    # the script accepts the sender flag, the buffer length, the memory kind
    # and the number of polls performed by the receiving side.
    cli = argparse.ArgumentParser()
    for flags, options in (
        (("if_ip_port_trio",), {"type": str}),
        (("-r", "--root"), {"action": "store_true"}),
        (("-n", "--num-elements"), {"type": int, "default": 10}),
        (("--gpu",), {"action": "store_true"}),
        (("--polling_num",), {"type": int, "default": 100}),
    ):
        cli.add_argument(*flags, **options)

    main(cli.parse_args())
|
||||
14
python/examples/utils.py
Normal file
14
python/examples/utils.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import mscclpp
|
||||
|
||||
import time
|
||||
|
||||
|
||||
def main():
    """Check that ``mscclpp.Timer`` reports at least the time actually slept.

    The timer is reset, the process sleeps for two seconds, and the reading
    is required to be no smaller than 2000000.
    """
    stopwatch = mscclpp.Timer()
    stopwatch.reset()
    time.sleep(2)
    # NOTE(review): 2000000 for a 2 s sleep suggests elapsed() is in microseconds - confirm.
    reading = stopwatch.elapsed()
    assert reading >= 2000000


if __name__ == "__main__":
    main()
|
||||
25
python/fifo_py.cpp
Normal file
25
python/fifo_py.cpp
Normal file
@@ -0,0 +1,25 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
|
||||
#include <mscclpp/fifo.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
// Registers the FIFO-related types (ProxyTrigger, DeviceProxyFifo and
// HostProxyFifo) on the Python module `m`.
void register_fifo(nb::module_& m) {
  // ProxyTrigger: both members are exposed as read/write attributes.
  nb::class_<ProxyTrigger> proxyTrigger(m, "ProxyTrigger");
  proxyTrigger.def_rw("fst", &ProxyTrigger::fst);
  proxyTrigger.def_rw("snd", &ProxyTrigger::snd);

  // DeviceProxyFifo: members exposed as attributes; note that `tailReplica`
  // is published under the snake_case name `tail_replica`.
  nb::class_<DeviceProxyFifo> deviceFifo(m, "DeviceProxyFifo");
  deviceFifo.def_rw("triggers", &DeviceProxyFifo::triggers);
  deviceFifo.def_rw("tail_replica", &DeviceProxyFifo::tailReplica);
  deviceFifo.def_rw("head", &DeviceProxyFifo::head);

  // HostProxyFifo: default-constructible from Python; `flush_tail` takes an
  // optional `sync` argument that defaults to false.
  nb::class_<HostProxyFifo> hostFifo(m, "HostProxyFifo");
  hostFifo.def(nb::init<>());
  hostFifo.def("poll", &HostProxyFifo::poll, nb::arg("trigger"));
  hostFifo.def("pop", &HostProxyFifo::pop);
  hostFifo.def("flush_tail", &HostProxyFifo::flushTail, nb::arg("sync") = false);
  hostFifo.def("device_fifo", &HostProxyFifo::deviceFifo);
}
|
||||
@@ -1,9 +0,0 @@
|
||||
#!/bin/bash
# Formats the sources under ./src in place: Python with isort and black,
# C++ (*.cpp / *.h) with clang-format.

# Echo every command and stop at the first failure.
set -ex

isort src
black src

# Let find hand the files to clang-format directly. Unlike an unquoted
# $(find ...) expansion this keeps paths containing whitespace intact, and
# clang-format is not invoked at all when no file matches.
find src \( -name '*.cpp' -o -name '*.h' \) -exec clang-format -i {} +
|
||||
|
||||
34
python/proxy_channel_py.cpp
Normal file
34
python/proxy_channel_py.cpp
Normal file
@@ -0,0 +1,34 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include <mscclpp/proxy_channel.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
// Registers the proxy service and proxy channel types on the Python module `m`.
void register_proxy_channel(nb::module_& m) {
  // Base class: only starting and stopping the proxy is exposed.
  nb::class_<BaseProxyService>(m, "BaseProxyService")
      .def("start_proxy", &BaseProxyService::startProxy)
      .def("stop_proxy", &BaseProxyService::stopProxy);

  // ProxyService is built from a Communicator and is registered as a
  // subclass of BaseProxyService.
  nb::class_<ProxyService, BaseProxyService>(m, "ProxyService")
      .def(nb::init<Communicator&>(), nb::arg("comm"))
      .def("start_proxy", &ProxyService::startProxy)
      .def("stop_proxy", &ProxyService::stopProxy)
      .def("add_semaphore", &ProxyService::addSemaphore, nb::arg("connection"))
      .def("add_memory", &ProxyService::addMemory, nb::arg("memory"))
      .def("semaphore", &ProxyService::semaphore, nb::arg("id"))
      .def("device_channel", &ProxyService::deviceChannel, nb::arg("id"));

  // ProxyChannel: constructor only. The keyword name "semaphoreId" is kept
  // as is because Python callers may pass it by name.
  nb::class_<ProxyChannel>(m, "ProxyChannel")
      .def(nb::init<SemaphoreId, Host2DeviceSemaphore::DeviceHandle, DeviceProxyFifo>(), nb::arg("semaphoreId"),
           nb::arg("semaphore"), nb::arg("fifo"));

  // SimpleProxyChannel: built either from a ProxyChannel plus destination and
  // source memory ids, or from another SimpleProxyChannel.
  nb::class_<SimpleProxyChannel>(m, "SimpleProxyChannel")
      .def(nb::init<ProxyChannel, MemoryId, MemoryId>(), nb::arg("proxy_chan"), nb::arg("dst"), nb::arg("src"))
      .def(nb::init<SimpleProxyChannel>(), nb::arg("proxy_chan"));
}
|
||||
42
python/semaphore_py.cpp
Normal file
42
python/semaphore_py.cpp
Normal file
@@ -0,0 +1,42 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
|
||||
#include <mscclpp/semaphore.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
// Registers the semaphore types and their nested DeviceHandle types on the
// Python module `m`.
void register_semaphore(nb::module_& m) {
  nb::class_<Host2DeviceSemaphore> host2DeviceSemaphore(m, "Host2DeviceSemaphore");
  host2DeviceSemaphore
      .def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
      .def("connection", &Host2DeviceSemaphore::connection)
      .def("signal", &Host2DeviceSemaphore::signal)
      .def("device_handle", &Host2DeviceSemaphore::deviceHandle);

  // Nested as Host2DeviceSemaphore.DeviceHandle.
  nb::class_<Host2DeviceSemaphore::DeviceHandle>(host2DeviceSemaphore, "DeviceHandle")
      .def(nb::init<>())
      .def_rw("inbound_semaphore_id", &Host2DeviceSemaphore::DeviceHandle::inboundSemaphoreId)
      .def_rw("expected_inbound_semaphore_id", &Host2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId);

  nb::class_<Host2HostSemaphore>(m, "Host2HostSemaphore")
      .def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
      .def("connection", &Host2HostSemaphore::connection)
      .def("signal", &Host2HostSemaphore::signal)
      .def("wait", &Host2HostSemaphore::wait);

  nb::class_<SmDevice2DeviceSemaphore> smDevice2DeviceSemaphore(m, "SmDevice2DeviceSemaphore");
  smDevice2DeviceSemaphore
      .def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
      .def("device_handle", &SmDevice2DeviceSemaphore::deviceHandle);

  // Nested as SmDevice2DeviceSemaphore.DeviceHandle. The camelCase attribute
  // names are the original ones and stay available; the snake_case names are
  // aliases of the same members, matching the naming used by
  // Host2DeviceSemaphore.DeviceHandle above.
  nb::class_<SmDevice2DeviceSemaphore::DeviceHandle>(smDevice2DeviceSemaphore, "DeviceHandle")
      .def(nb::init<>())
      .def_rw("inboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::inboundSemaphoreId)
      .def_rw("outboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::outboundSemaphoreId)
      .def_rw("remoteInboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::remoteInboundSemaphoreId)
      .def_rw("expectedInboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId)
      .def_rw("inbound_semaphore_id", &SmDevice2DeviceSemaphore::DeviceHandle::inboundSemaphoreId)
      .def_rw("outbound_semaphore_id", &SmDevice2DeviceSemaphore::DeviceHandle::outboundSemaphoreId)
      .def_rw("remote_inbound_semaphore_id", &SmDevice2DeviceSemaphore::DeviceHandle::remoteInboundSemaphoreId)
      .def_rw("expected_inbound_semaphore_id",
              &SmDevice2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId);
}
|
||||
@@ -1,83 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import logging
|
||||
from setuptools import Extension, find_packages, setup
|
||||
from setuptools.command.build_ext import build_ext
|
||||
import subprocess
|
||||
|
||||
THIS_DIR = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
|
||||
class CustomExt(Extension):
    """Marker extension that is built by the custom CMake-driven ``build_ext``."""

    def __init__(self, name):
        # An empty source list keeps the stock build_ext logic from trying to
        # compile anything for this extension.
        Extension.__init__(self, name, sources=[])
|
||||
|
||||
|
||||
class custom_build_ext(build_ext):
    """``build_ext`` command that builds ``CustomExt`` extensions with CMake.

    Regular extensions are still handled by the stock setuptools logic.
    """

    def run(self):
        """Build CMake-driven extensions directly and the rest through setuptools.

        The stock ``run()`` is invoked at most once, and only with the regular
        extensions, so CMake is not re-run for every non-CMake extension.
        """
        cmake_exts = [ext for ext in self.extensions if isinstance(ext, CustomExt)]
        regular_exts = [ext for ext in self.extensions if not isinstance(ext, CustomExt)]

        for ext in cmake_exts:
            self.build_extension(ext)

        if regular_exts:
            all_exts = self.extensions
            self.extensions = regular_exts
            try:
                super().run()
            finally:
                self.extensions = all_exts

    def build_extension(self, ext):
        """Configure and build the project with CMake, then copy the results.

        ``libmscclpp.so`` and the extension module itself are copied from the
        CMake build directory into the ``mscclpp`` package directory next to
        the extension's output path. Nothing is built on macOS.

        :param ext: the extension to build; anything that is not a
            ``CustomExt`` is delegated to setuptools.
        :raises subprocess.CalledProcessError: when a CMake step fails (its
            output is logged first).
        """
        if not isinstance(ext, CustomExt):
            super().build_extension(ext)
            return

        if sys.platform == "darwin":
            return

        # these dirs will be created in build_py, so if you don't have
        # any python sources to bundle, the dirs will be missing
        build_temp = os.path.abspath(self.build_temp)
        os.makedirs(build_temp, exist_ok=True)

        try:
            subprocess.check_output(
                ["cmake", "-S", THIS_DIR, "-B", build_temp],
                stderr=subprocess.STDOUT,
            )
            subprocess.check_output(
                ["cmake", "--build", build_temp],
                stderr=subprocess.STDOUT,
            )
        except subprocess.CalledProcessError as e:
            logging.error(e.output.decode())
            raise

        ext_path = self.get_ext_fullpath(ext.name)
        libname = os.path.basename(ext_path)
        target_dir = os.path.join(os.path.dirname(ext_path), "mscclpp")
        # Without this, shutil.copy would create a *file* named "mscclpp"
        # whenever the package directory has not been created yet.
        os.makedirs(target_dir, exist_ok=True)

        for artifact in ("libmscclpp.so", libname):
            shutil.copy(os.path.join(build_temp, artifact), target_dir)
|
||||
|
||||
|
||||
# Package metadata. Python sources live under ./src; the native module
# `_py_mscclpp` is produced by the CMake-driven build_ext command above.
setup(
    name="mscclpp",
    version="0.1.0",
    description="Python bindings for mscclpp",
    package_dir={"": "src"},
    packages=find_packages(where="./src"),
    ext_modules=[CustomExt("_py_mscclpp")],
    cmdclass={"build_ext": custom_build_ext},
    install_requires=["torch", "nanobind"],
)
|
||||
@@ -1,7 +0,0 @@
|
||||
---
|
||||
Language : Cpp
|
||||
BasedOnStyle : google
|
||||
BinPackParameters: false
|
||||
BinPackArguments : false
|
||||
AlignAfterOpenBracket : AlwaysBreak
|
||||
...
|
||||
@@ -1,405 +0,0 @@
|
||||
#include <cuda_runtime.h>
|
||||
#include <mscclpp.h>
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
// printf-style formatting into a std::string; a stand-in for std::format,
// which needs C++20.
//
// `format` is a printf format string and `args` are forwarded to snprintf
// unchanged, so they must be types that are valid as printf varargs.
// Throws std::runtime_error when snprintf reports an encoding error.
template <typename... Args>
std::string string_format(const std::string& format, Args... args) {
// The format string is not a literal; silence the resulting warning.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wformat-security"

  const char* fmt = format.c_str();

  // First pass: measure. snprintf returns the length without the trailing
  // '\0', or a negative value on error.
  const int needed = std::snprintf(nullptr, 0, fmt, args...);
  if (needed < 0) {
    throw std::runtime_error("Error during formatting.");
  }

  // Second pass: format straight into the result. The terminating '\0'
  // lands on the string's own terminator slot and is not part of the result.
  std::string out(static_cast<size_t>(needed), '\0');
  std::snprintf(&out[0], out.size() + 1, fmt, args...);
  return out;

#pragma GCC diagnostic pop
}
|
||||
|
||||
// Maybe return the value, maybe throw an exception.
|
||||
template <typename... Args>
|
||||
void checkResult(
|
||||
mscclppResult_t status, const std::string& format, Args... args) {
|
||||
switch (status) {
|
||||
case mscclppSuccess:
|
||||
return;
|
||||
|
||||
case mscclppUnhandledCudaError:
|
||||
case mscclppSystemError:
|
||||
case mscclppInternalError:
|
||||
case mscclppRemoteError:
|
||||
case mscclppInProgress:
|
||||
case mscclppNumResults:
|
||||
throw std::runtime_error(
|
||||
string_format(format, args...) + " : " +
|
||||
std::string(mscclppGetErrorString(status)));
|
||||
|
||||
case mscclppInvalidArgument:
|
||||
case mscclppInvalidUsage:
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
string_format(format, args...) + " : " +
|
||||
std::string(mscclppGetErrorString(status)));
|
||||
}
|
||||
}
|
||||
|
||||
// RETRY(call, fmt, ...): evaluates `call` repeatedly while it returns
// mscclppInProgress, then hands the final status to checkResult() together
// with the message arguments (so any other failure throws).
//
// The body is wrapped in do { ... } while (0) so that `RETRY(...);` is a
// single statement and stays correct inside an unbraced if/else. The local
// is given an unusual name so it cannot capture an identifier used inside
// the `call` expression.
#define RETRY(C, ...)                                    \
  do {                                                   \
    mscclppResult_t retry_status_;                       \
    do {                                                 \
      retry_status_ = (C);                               \
    } while (retry_status_ == mscclppInProgress);        \
    checkResult(retry_status_, __VA_ARGS__);             \
  } while (0)
|
||||
|
||||
// Maybe return the value, maybe throw an exception.
|
||||
template <typename Val, typename... Args>
|
||||
Val maybe(
|
||||
mscclppResult_t status, Val val, const std::string& format, Args... args) {
|
||||
checkResult(status, format, args...);
|
||||
return val;
|
||||
}
|
||||
|
||||
// Wrapper around connection state.
|
||||
struct _Comm {
|
||||
int _rank;
|
||||
int _world_size;
|
||||
mscclppComm_t _handle;
|
||||
bool _is_open;
|
||||
bool _proxies_running;
|
||||
|
||||
public:
|
||||
_Comm(int rank, int world_size, mscclppComm_t handle)
|
||||
: _rank(rank),
|
||||
_world_size(world_size),
|
||||
_handle(handle),
|
||||
_is_open(true),
|
||||
_proxies_running(false) {}
|
||||
|
||||
~_Comm() { close(); }
|
||||
|
||||
// Close should be safe to call on a closed handle.
|
||||
void close() {
|
||||
if (_is_open) {
|
||||
if (_proxies_running) {
|
||||
mscclppProxyStop(_handle);
|
||||
_proxies_running = false;
|
||||
}
|
||||
checkResult(mscclppCommDestroy(_handle), "Failed to close comm channel");
|
||||
_handle = NULL;
|
||||
_is_open = false;
|
||||
_rank = -1;
|
||||
_world_size = -1;
|
||||
}
|
||||
}
|
||||
|
||||
void check_open() {
|
||||
if (!_is_open) {
|
||||
throw std::invalid_argument("_Comm is not open");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct _P2PHandle {
|
||||
struct mscclppRegisteredMemoryP2P _rmP2P;
|
||||
struct mscclppIbMr _ibmr;
|
||||
|
||||
_P2PHandle() : _rmP2P({0}), _ibmr({0}) {}
|
||||
|
||||
_P2PHandle(const mscclppRegisteredMemoryP2P& p2p) : _ibmr({0}) {
|
||||
_rmP2P = p2p;
|
||||
if (_rmP2P.IbMr != nullptr) {
|
||||
_ibmr = *_rmP2P.IbMr;
|
||||
_rmP2P.IbMr = &_ibmr;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
nb::callable _log_callback;
|
||||
|
||||
void _LogHandler(const char* msg) {
|
||||
if (_log_callback) {
|
||||
nb::gil_scoped_acquire guard;
|
||||
_log_callback(msg);
|
||||
}
|
||||
}
|
||||
|
||||
static const std::string DOC_MscclppUniqueId =
|
||||
"MSCCLPP Unique Id; used by the MPI Interface";
|
||||
|
||||
static const std::string DOC__Comm = "MSCCLPP Communications Handle";
|
||||
|
||||
static const std::string DOC__P2PHandle = "MSCCLPP P2P MR Handle";
|
||||
|
||||
// Module definition for `_py_mscclpp`: log-handler hooks, the TransportType
// enum, MscclppUniqueId, the opaque _P2PHandle and the _Comm wrapper.
NB_MODULE(_py_mscclpp, m) {
  m.doc() = "Python bindings for MSCCLPP: which is not NCCL";

  m.attr("MSCCLPP_UNIQUE_ID_BYTES") = MSCCLPP_UNIQUE_ID_BYTES;

  // Route native log messages to a Python callable / back to the default.
  m.def("_bind_log_handler", [](nb::callable cb) -> void {
    _log_callback = nb::borrow<nb::callable>(cb);
    mscclppSetLogHandler(_LogHandler);
  });
  m.def("_release_log_handler", []() -> void {
    _log_callback.reset();
    mscclppSetLogHandler(mscclppDefaultLogHandler);
  });

  nb::enum_<mscclppTransport_t>(m, "TransportType")
      .value("P2P", mscclppTransport_t::mscclppTransportP2P)
      .value("SHM", mscclppTransport_t::mscclppTransportSHM)
      .value("IB", mscclppTransport_t::mscclppTransportIB);

  nb::class_<mscclppUniqueId>(m, "MscclppUniqueId")
      .def_ro_static("__doc__", &DOC_MscclppUniqueId)
      .def_static(
          "from_context",
          []() -> mscclppUniqueId {
            mscclppUniqueId uniqueId;
            return maybe(mscclppGetUniqueId(&uniqueId), uniqueId, "Failed to get MSCCLP Unique Id.");
          },
          nb::call_guard<nb::gil_scoped_release>())
      .def_static(
          "from_bytes",
          [](nb::bytes source) -> mscclppUniqueId {
            if (source.size() != MSCCLPP_UNIQUE_ID_BYTES) {
              // Each vararg is cast to the exact type its conversion
              // specifier expects (%d -> int, %zu -> size_t).
              throw std::invalid_argument(string_format("Requires exactly %d bytes; found %zu",
                                                        static_cast<int>(MSCCLPP_UNIQUE_ID_BYTES),
                                                        static_cast<size_t>(source.size())));
            }

            mscclppUniqueId uniqueId;
            std::memcpy(uniqueId.internal, source.c_str(), sizeof(uniqueId.internal));
            return uniqueId;
          })
      .def("bytes", [](mscclppUniqueId id) { return nb::bytes(id.internal, sizeof(id.internal)); });

  nb::class_<_P2PHandle>(m, "_P2PHandle").def_ro_static("__doc__", &DOC__P2PHandle);

  nb::class_<_Comm>(m, "_Comm")
      .def_ro_static("__doc__", &DOC__Comm)
      .def_static(
          "init_rank_from_address",
          [](const std::string& address, int rank, int world_size) -> _Comm* {
            mscclppComm_t handle;
            // A std::string must not travel through printf varargs; pass the
            // C string for the %s specifier.
            checkResult(mscclppCommInitRank(&handle, world_size, address.c_str(), rank),
                        "Failed to initialize comms: %s rank=%d world_size=%d", address.c_str(), rank, world_size);
            return new _Comm(rank, world_size, handle);
          },
          nb::rv_policy::take_ownership, nb::call_guard<nb::gil_scoped_release>(), "address"_a, "rank"_a,
          "world_size"_a, "Initialize comms given an IP address, rank, and world_size")
      .def_static(
          "init_rank_from_id",
          [](const mscclppUniqueId& id, int rank, int world_size) -> _Comm* {
            mscclppComm_t handle;
            // The previous format ("%02X%s ...") consumed more arguments
            // than were supplied and read the id buffer with the wrong
            // specifier; the message now only formats the two integers.
            checkResult(mscclppCommInitRankFromId(&handle, world_size, id, rank),
                        "Failed to initialize comms from unique id: rank=%d world_size=%d", rank, world_size);
            return new _Comm(rank, world_size, handle);
          },
          nb::rv_policy::take_ownership, nb::call_guard<nb::gil_scoped_release>(), "id"_a, "rank"_a,
          "world_size"_a, "Initialize comms given u UniqueID, rank, and world_size")
      .def(
          "opened", [](_Comm& comm) -> bool { return comm._is_open; }, "Is this comm object opened?")
      .def(
          "closed", [](_Comm& comm) -> bool { return !comm._is_open; }, "Is this comm object closed?")
      .def_ro("rank", &_Comm::_rank)
      .def_ro("world_size", &_Comm::_world_size)
      .def(
          "register_buffer",
          [](_Comm& comm, uint64_t local_buff, uint64_t buff_size) -> std::vector<_P2PHandle> {
            comm.check_open();
            mscclppRegisteredMemory regMem;
            checkResult(
                mscclppRegisterBuffer(comm._handle, reinterpret_cast<void*>(local_buff), buff_size, &regMem),
                "Registering buffer failed");

            // Snapshot every P2P record of the registration.
            std::vector<_P2PHandle> handles;
            for (const auto& p2p : regMem.p2p) {
              handles.push_back(_P2PHandle(p2p));
            }
            return handles;
          },
          "local_buf"_a, "buff_size"_a, nb::call_guard<nb::gil_scoped_release>(),
          "Register a buffer for P2P transfers.")
      .def(
          "connect",
          [](_Comm& comm, int remote_rank, int tag, uint64_t local_buff, uint64_t buff_size,
             mscclppTransport_t transport_type) -> void {
            comm.check_open();
            RETRY(mscclppConnect(comm._handle, remote_rank, tag, reinterpret_cast<void*>(local_buff), buff_size,
                                 transport_type,
                                 NULL  // ibDev
                                 ),
                  "Connect failed");
          },
          "remote_rank"_a, "tag"_a, "local_buf"_a, "buff_size"_a, "transport_type"_a,
          nb::call_guard<nb::gil_scoped_release>(), "Attach a local buffer to a remote connection.")
      .def(
          "connection_setup",
          [](_Comm& comm) -> void {
            comm.check_open();
            RETRY(mscclppConnectionSetup(comm._handle), "Failed to setup MSCCLPP connection");
          },
          nb::call_guard<nb::gil_scoped_release>(), "Run connection setup for MSCCLPP.")
      .def(
          "launch_proxies",
          [](_Comm& comm) -> void {
            comm.check_open();
            if (comm._proxies_running) {
              throw std::invalid_argument("Proxy Threads Already Running");
            }
            checkResult(mscclppProxyLaunch(comm._handle), "Failed to launch MSCCLPP proxy");
            comm._proxies_running = true;
          },
          nb::call_guard<nb::gil_scoped_release>(), "Start the MSCCLPP proxy.")
      .def(
          "stop_proxies",
          [](_Comm& comm) -> void {
            comm.check_open();
            if (comm._proxies_running) {
              checkResult(mscclppProxyStop(comm._handle), "Failed to stop MSCCLPP proxy");
              comm._proxies_running = false;
            }
          },
          nb::call_guard<nb::gil_scoped_release>(), "Stop the MSCCLPP proxy.")
      .def("close", &_Comm::close, nb::call_guard<nb::gil_scoped_release>())
      .def("__del__", &_Comm::close, nb::call_guard<nb::gil_scoped_release>())
      .def(
          "bootstrap_all_gather_int",
          [](_Comm& comm, int val) -> std::vector<int> {
            // A closed comm has world_size == -1, which must not be used as
            // a vector length.
            comm.check_open();
            std::vector<int> buf(comm._world_size);
            buf[comm._rank] = val;
            checkResult(mscclppBootstrapAllGather(comm._handle, buf.data(), sizeof(int)),
                        "bootstrapAllGather failed.");
            return buf;
          },
          nb::call_guard<nb::gil_scoped_release>(), "val"_a, "all-gather ints over the bootstrap connection.")
      .def(
          "all_gather_bytes",
          [](_Comm& comm, nb::bytes& item) -> std::vector<nb::bytes> {
            comm.check_open();

            // Read everything needed from the Python object while the GIL is
            // still held; `item` stays alive for the duration of the call.
            const char* item_data = item.c_str();
            const size_t item_size = item.size();
            const size_t world_size = static_cast<size_t>(comm._world_size);
            const size_t rank = static_cast<size_t>(comm._rank);

            std::vector<size_t> sizes(world_size);
            std::vector<char> data_buf;
            size_t max_size = 0;
            {
              // Only the blocking network exchange runs without the GIL; no
              // Python object is touched inside this scope.
              nb::gil_scoped_release release;

              // First, all-gather the sizes of all bytes.
              sizes[rank] = item_size;
              checkResult(mscclppBootstrapAllGather(comm._handle, sizes.data(), sizeof(size_t)),
                          "bootstrapAllGather failed.");

              // Next, find the largest message to send.
              max_size = *std::max_element(sizes.begin(), sizes.end());

              // Allocate an all-gather buffer large enough for max * world_size
              // (at least one byte, so the data pointer is never null).
              data_buf.resize(std::max<size_t>(max_size * world_size, 1));

              // Copy the local item into the buffer.
              if (item_size > 0) {
                std::memcpy(data_buf.data() + rank * max_size, item_data, item_size);
              }

              // all-gather the data buffer.
              checkResult(mscclppBootstrapAllGather(comm._handle, data_buf.data(), max_size),
                          "bootstrapAllGather failed.");
            }

            // Build a response vector; creating Python bytes objects requires
            // the GIL, which is held again here.
            std::vector<nb::bytes> ret;
            ret.reserve(world_size);
            for (size_t i = 0; i < world_size; ++i) {
              // Copy out the relevant range of each item.
              ret.push_back(nb::bytes(data_buf.data() + i * max_size, sizes[i]));
            }
            return ret;
          },
          "item"_a,
          "all-gather bytes over the bootstrap connection; sizes do not need "
          "to match.");
}
|
||||
@@ -1,203 +0,0 @@
|
||||
import atexit
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
from typing import Any, Optional, final
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
from . import _py_mscclpp
|
||||
|
||||
__all__ = (
|
||||
"Comm",
|
||||
"MscclppUniqueId",
|
||||
"MSCCLPP_UNIQUE_ID_BYTES",
|
||||
"TransportType",
|
||||
)
|
||||
|
||||
_Comm = _py_mscclpp._Comm
|
||||
_P2PHandle = _py_mscclpp._P2PHandle
|
||||
TransportType = _py_mscclpp.TransportType
|
||||
|
||||
MscclppUniqueId = _py_mscclpp.MscclppUniqueId
|
||||
MSCCLPP_UNIQUE_ID_BYTES = _py_mscclpp.MSCCLPP_UNIQUE_ID_BYTES
|
||||
|
||||
|
||||
def _mscclpp_log_cb(msg: str) -> None:
    """Log callback hook called from inside _py_mscclpp.

    The severity is taken from the first ``MSCCLPP <LEVEL>`` marker found in
    ``msg``. MSCCLPP-specific names are mapped explicitly (``WARN`` to
    WARNING, ``ABORT`` to CRITICAL, ``TRACE`` to DEBUG); any other name is
    resolved through the public ``logging.getLevelName``. Messages without a
    marker, or with an unknown level name, are logged at INFO.

    :param msg: the raw message produced by the native library.
    """
    mscclpp_levels = {
        "DEBUG": logging.DEBUG,
        "TRACE": logging.DEBUG,
        "INFO": logging.INFO,
        "WARN": logging.WARNING,
        "ABORT": logging.CRITICAL,
    }

    # Attempt to parse out the original log level:
    level = logging.INFO
    match = re.search(r"MSCCLPP (\w+)", msg)
    if match:
        name = match.group(1)
        level = mscclpp_levels.get(name)
        if level is None:
            # getLevelName returns the numeric level for a registered name and
            # a "Level <name>" string otherwise.
            resolved = logging.getLevelName(name)
            level = resolved if isinstance(resolved, int) else logging.INFO

    # actually log the event.
    logger.log(level, msg)
|
||||
|
||||
|
||||
# The known log levels used by MSCCLPP.
|
||||
# Set in os.environ['MSCCLPP_DEBUG'] and only parsed on first init.
|
||||
MSCCLPP_LOG_LEVELS: set[str] = {
|
||||
"DEBUG",
|
||||
"INFO",
|
||||
"WARN",
|
||||
"ABORT",
|
||||
"TRACE",
|
||||
}
|
||||
|
||||
|
||||
def _setup_logging(level: str = "INFO"):
    """Install the Python log callback and publish the log level to the C library.

    ``level`` is upper-cased; a value outside ``MSCCLPP_LOG_LEVELS`` falls
    back to ``"INFO"``. The result is written to ``os.environ["MSCCLPP_DEBUG"]``.
    """
    requested = level.upper()
    os.environ["MSCCLPP_DEBUG"] = requested if requested in MSCCLPP_LOG_LEVELS else "INFO"

    _py_mscclpp._bind_log_handler(_mscclpp_log_cb)
    # Unbinding at interpreter exit is needed to prevent a segfault at exit.
    atexit.register(_py_mscclpp._release_log_handler)
|
||||
|
||||
|
||||
_setup_logging()
|
||||
|
||||
|
||||
@final
class Comm:
    """Comm object; represents a mscclpp connection."""

    # Native handle; set to None once close() has run.
    _comm: Optional[_Comm]

    @staticmethod
    def init_rank_from_address(
        address: str,
        rank: int,
        world_size: int,
        *,
        port: Optional[int] = None,
    ) -> "Comm":
        """Initialize a Comm from an address.

        :param address: the address as a string, with optional port.
        :param rank: this Comm's rank.
        :param world_size: the total world size.
        :param port: (optional) port, appended to address.
        :return: a newly initialized Comm.
        """
        if port is not None:
            address = f"{address}:{port}"
        return Comm(
            _comm=_Comm.init_rank_from_address(
                address=address,
                rank=rank,
                world_size=world_size,
            ),
        )

    def __init__(self, *, _comm: _Comm):
        """Construct a Comm object wrapping an internal _Comm handle."""
        self._comm = _comm

    def __del__(self) -> None:
        self.close()

    def close(self) -> None:
        """Close the connection.

        Safe to call repeatedly, and also on an instance whose ``__init__``
        never completed (``__del__`` still runs in that case).
        """
        comm = getattr(self, "_comm", None)
        if comm is not None:
            comm.close()
            self._comm = None

    @property
    def rank(self) -> int:
        """Return the rank of the Comm.

        Assumes the Comm is open.
        """
        return self._comm.rank

    @property
    def world_size(self) -> int:
        """Return the world_size of the Comm.

        Assumes the Comm is open.
        """
        return self._comm.world_size

    def bootstrap_all_gather_int(self, val: int) -> list[int]:
        """AllGather an int value through the bootstrap interface."""
        return self._comm.bootstrap_all_gather_int(val)

    def all_gather_bytes(self, item: bytes) -> list[bytes]:
        """AllGather bytes (of different sizes) through the bootstrap interface.

        :param item: the bytes object for this rank.
        :return: a list of bytes objects; the ret[rank] object will be a new copy.
        """
        return self._comm.all_gather_bytes(item)

    def all_gather_json(self, item: Any) -> list[Any]:
        """AllGather JSON objects through the bootstrap interface.

        :param item: the JSON object for this rank.
        :return: a list of JSON objects; the ret[rank] object will be a new copy.
        """
        encoded = bytes(json.dumps(item), "utf-8")
        return [json.loads(b.decode("utf-8")) for b in self.all_gather_bytes(encoded)]

    def all_gather_pickle(self, item: Any) -> list[Any]:
        """AllGather pickle-able objects through the bootstrap interface.

        NOTE(review): every received payload is passed to ``pickle.loads``;
        this is only safe when all ranks are trusted.

        :param item: the object for this rank.
        :return: a list of de-pickled objects. Note, the ret[rank] item will be a new copy.
        """
        return [pickle.loads(b) for b in self.all_gather_bytes(pickle.dumps(item))]

    def connect(
        self,
        remote_rank: int,
        tag: int,
        data_ptr,
        data_size: int,
        transport: TransportType,
    ) -> None:
        """Attach a local buffer to a connection with ``remote_rank``.

        :param remote_rank: rank of the peer.
        :param tag: connection tag, forwarded unchanged.
        :param data_ptr: address of the local buffer as an integer.
        :param data_size: size of the local buffer in bytes.
        :param transport: a ``TransportType`` value.
        """
        self._comm.connect(
            remote_rank,
            tag,
            data_ptr,
            data_size,
            transport,
        )

    def connection_setup(self) -> None:
        """Run the connection setup of the underlying handle."""
        self._comm.connection_setup()

    def launch_proxies(self) -> None:
        """Start the proxies of the underlying handle."""
        self._comm.launch_proxies()

    def stop_proxies(self) -> None:
        """Stop the proxies of the underlying handle."""
        self._comm.stop_proxies()

    def register_buffer(
        self,
        data_ptr,
        data_size: int,
    ) -> "list[P2PHandle]":
        """Register a buffer and wrap every returned native handle.

        :param data_ptr: address of the buffer as an integer.
        :param data_size: size of the buffer in bytes.
        :return: one ``P2PHandle`` per native handle, each tied to this Comm.
        """
        return [
            P2PHandle(self, h)
            for h in self._comm.register_buffer(
                data_ptr,
                data_size,
            )
        ]
|
||||
|
||||
|
||||
class P2PHandle:
    """Pairs a native ``_P2PHandle`` with the ``Comm`` it was registered on."""

    # The owning Comm and the wrapped native handle, respectively.
    _comm: Comm
    _handle: _P2PHandle

    def __init__(self, comm: Comm, handle: _P2PHandle):
        self._comm, self._handle = comm, handle
|
||||
@@ -1,86 +0,0 @@
|
||||
import concurrent.futures
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import hamcrest
|
||||
|
||||
import mscclpp
|
||||
|
||||
MOD_DIR = os.path.dirname(__file__)
|
||||
TESTS_DIR = os.path.join(MOD_DIR, "tests")
|
||||
|
||||
|
||||
class UniqueIdTest(unittest.TestCase):
    """Tests for ``mscclpp.MscclppUniqueId``."""

    def test_no_constructor(self) -> None:
        """Direct construction must be rejected with a TypeError."""
        construct = hamcrest.calling(mscclpp.MscclppUniqueId).with_args()
        hamcrest.assert_that(construct, hamcrest.raises(TypeError, "no constructor"))

    def test_getUniqueId(self) -> None:
        """An id has the advertised length, round-trips through bytes, and bad sizes fail."""
        myId = mscclpp.MscclppUniqueId.from_context()
        raw = myId.bytes()

        hamcrest.assert_that(raw, hamcrest.has_length(mscclpp.MSCCLPP_UNIQUE_ID_BYTES))

        # from_bytes should work
        copy = mscclpp.MscclppUniqueId.from_bytes(raw)
        hamcrest.assert_that(copy.bytes(), hamcrest.equal_to(myId.bytes()))

        # bad size
        expected_message = f"Requires exactly {mscclpp.MSCCLPP_UNIQUE_ID_BYTES} bytes; found 3"
        hamcrest.assert_that(
            hamcrest.calling(mscclpp.MscclppUniqueId.from_bytes).with_args(b"abc"),
            hamcrest.raises(ValueError, expected_message),
        )
|
||||
|
||||
|
||||
class CommsTest(unittest.TestCase):
    """Multi-process tests that exercise the bootstrap test module."""

    def test_all_gather(self) -> None:
        """Run ``mscclpp.tests.bootstrap_test`` once per rank and collect failures.

        Every rank runs in its own subprocess; the subprocesses are started
        concurrently from a thread pool. The combined output of each failing
        rank is reported in a single AssertionError.
        """
        world_size = 2

        tasks: list[concurrent.futures.Future[None]] = []

        with concurrent.futures.ThreadPoolExecutor(max_workers=world_size) as pool:
            for rank in range(world_size):
                tasks.append(
                    pool.submit(
                        subprocess.check_output,
                        [
                            # Use the interpreter running this test; a bare
                            # "python" may resolve to a different installation
                            # (or to nothing at all) on PATH.
                            sys.executable,
                            "-m",
                            "mscclpp.tests.bootstrap_test",
                            f"--rank={rank}",
                            f"--world_size={world_size}",
                        ],
                        stderr=subprocess.STDOUT,
                    )
                )

        errors = []
        for rank, f in enumerate(tasks):
            try:
                f.result()
            except subprocess.CalledProcessError as e:
                errors.append((rank, e.output))

        if errors:
            parts = [f"[rank {rank}]: " + content.decode("utf-8", errors="ignore") for rank, content in errors]
            raise AssertionError("\n\n".join(parts))
|
||||
@@ -1,145 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
import hamcrest
|
||||
import torch
|
||||
|
||||
import mscclpp
|
||||
|
||||
|
||||
# Minimal picklable payload used by the pickle all-gather test; dataclass
# equality lets the gathered copies be compared with the expected instances.
@dataclass
class Example:
    rank: int  # rank of the process that created the instance
|
||||
|
||||
|
||||
def _test_bootstrap_allgather_int(options: argparse.Namespace, comm: mscclpp.Comm):
    """Each rank contributes ``rank + 42``; the expected result is hard-coded for two ranks."""
    gathered = comm.bootstrap_all_gather_int(options.rank + 42)
    hamcrest.assert_that(gathered, hamcrest.equal_to([42, 43]))
|
||||
|
||||
|
||||
def _test_bootstrap_allgather_bytes(options: argparse.Namespace, comm: mscclpp.Comm):
    """Ranks send payloads of different lengths (``b"abc"`` repeated ``rank + 1`` times).

    The expected list is hard-coded for two ranks.
    """
    payload = b"abc" * (1 + options.rank)
    gathered = comm.all_gather_bytes(payload)
    hamcrest.assert_that(gathered, hamcrest.equal_to([b"abc", b"abcabc"]))
|
||||
|
||||
|
||||
def _test_bootstrap_allgather_json(options: argparse.Namespace, comm: mscclpp.Comm):
    """Check ``comm.all_gather_json`` with a dict payload and then a list payload.

    Both expected results are hard-coded for ranks 0 and 1.
    """
    # First exchange: a dict keyed by "rank".
    dict_result = comm.all_gather_json({"rank": options.rank})
    hamcrest.assert_that(
        dict_result,
        hamcrest.equal_to([{"rank": 0}, {"rank": 1}]),
    )

    # Second exchange: a two-element list whose first item is the rank.
    list_result = comm.all_gather_json([options.rank, 42])
    hamcrest.assert_that(
        list_result,
        hamcrest.equal_to([[0, 42], [1, 42]]),
    )
|
||||
|
||||
|
||||
def _test_bootstrap_allgather_pickle(options: argparse.Namespace, comm: mscclpp.Comm):
    """Check ``comm.all_gather_pickle`` with an ``Example`` instance per rank.

    The expected result is hard-coded for ranks 0 and 1. After the check,
    ``comm.connection_setup()`` is called.
    """
    local_value = Example(rank=options.rank)
    gathered = comm.all_gather_pickle(local_value)
    expected = [Example(rank=0), Example(rank=1)]
    hamcrest.assert_that(gathered, hamcrest.equal_to(expected))

    comm.connection_setup()
|
||||
|
||||
|
||||
def _test_rm(options: argparse.Namespace, comm: mscclpp.Comm):
    """Connect to the other rank, register a CUDA buffer and cycle the proxies.

    The buffer is an int64 tensor with one element per rank; this rank's
    element is set to ``42 + rank`` before the tensor is moved to the GPU.
    The peer is rank 0 for any non-zero rank and rank 1 for rank 0.
    ``register_buffer`` is expected to return ``world_size - 1`` handles.
    """
    my_rank = options.rank

    buf = torch.zeros([options.world_size], dtype=torch.int64)
    buf[my_rank] = 42 + my_rank
    buf = buf.cuda().contiguous()

    tag = 0
    remote_rank = 0 if my_rank else 1

    # Address and byte size of the device buffer handed to the communicator.
    buf_ptr = buf.data_ptr()
    buf_nbytes = buf.element_size() * buf.numel()

    comm.connect(
        remote_rank,
        tag,
        buf_ptr,
        buf_nbytes,
        mscclpp.TransportType.P2P,
    )

    handles = comm.register_buffer(buf_ptr, buf_nbytes)
    expected_handle_count = options.world_size - 1
    hamcrest.assert_that(handles, hamcrest.has_length(expected_handle_count))

    torch.cuda.synchronize()

    comm.connection_setup()

    comm.launch_proxies()
    comm.stop_proxies()
|
||||
|
||||
|
||||
def main():
    """Parse the rank arguments, create a communicator and run every check.

    ``CUDA_VISIBLE_DEVICES`` is set to the rank number before the
    communicator is created. The communicator is closed in a ``finally``
    block once the checks have started, whether or not they pass.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--rank", type=int, required=True)
    parser.add_argument("--world_size", type=int, required=True)
    parser.add_argument("--port", default=50000)
    options = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = str(options.rank)

    # The variable name is part of the printed text (f-string "=" specifier).
    comm_options = {
        "address": f"127.0.0.1:{options.port}",
        "rank": options.rank,
        "world_size": options.world_size,
    }
    print(f"{comm_options=}", flush=True)

    comm = mscclpp.Comm.init_rank_from_address(**comm_options)

    # The communicator must report the rank and world size it was given.
    hamcrest.assert_that(comm.rank, hamcrest.equal_to(options.rank))
    hamcrest.assert_that(comm.world_size, hamcrest.equal_to(options.world_size))

    # Order is preserved: the checks run exactly in this sequence.
    checks = (
        _test_bootstrap_allgather_int,
        _test_bootstrap_allgather_bytes,
        _test_bootstrap_allgather_json,
        _test_bootstrap_allgather_pickle,
        _test_rm,
    )
    try:
        for check in checks:
            check(options, comm)
    finally:
        comm.close()
|
||||
|
||||
|
||||
# Run the checks only when this file is executed as a script or via ``-m``.
if __name__ == "__main__":
    main()
|
||||
@@ -1,8 +0,0 @@
|
||||
#!/bin/bash
# Install the package in editable mode, then run its pytest suite from src/.

# Stop at the first failing command and echo every command as it runs.
set -o errexit
set -o xtrace

pip install --editable .

cd src
pytest -v -s mscclpp
|
||||
23
python/utils_py.cpp
Normal file
23
python/utils_py.cpp
Normal file
@@ -0,0 +1,23 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include <mscclpp/utils.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
// Registers the Timer and ScopedTimer classes and the get_host_name function
// on the given nanobind module.
void register_utils(nb::module_& m) {
  // Timer: constructor takes an optional `timeout` (default -1), plus the
  // elapsed/set/reset/print methods.
  nb::class_<Timer> timerClass(m, "Timer");
  timerClass.def(nb::init<int>(), nb::arg("timeout") = -1);
  timerClass.def("elapsed", &Timer::elapsed);
  timerClass.def("set", &Timer::set, nb::arg("timeout"));
  timerClass.def("reset", &Timer::reset);
  timerClass.def("print", &Timer::print, nb::arg("name"));

  // ScopedTimer: bound as a subclass of Timer, constructed from a name.
  nb::class_<ScopedTimer, Timer> scopedTimerClass(m, "ScopedTimer");
  scopedTimerClass.def(nb::init<std::string>(), nb::arg("name"));

  // Free function exposed under a snake_case name.
  m.def("get_host_name", &getHostName, nb::arg("maxlen"), nb::arg("delim"));
}
|
||||
@@ -2,5 +2,5 @@
|
||||
# Licensed under the MIT license.
|
||||
|
||||
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cc)
|
||||
target_sources(mscclpp PRIVATE ${SOURCES})
|
||||
target_include_directories(mscclpp PRIVATE include)
|
||||
target_sources(mscclpp_obj PRIVATE ${SOURCES})
|
||||
target_include_directories(mscclpp_obj PRIVATE include)
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include <sys/resource.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <mscclpp/config.hpp>
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/errors.hpp>
|
||||
#include <sstream>
|
||||
@@ -12,7 +13,6 @@
|
||||
#include <vector>
|
||||
|
||||
#include "api.h"
|
||||
#include "config.hpp"
|
||||
#include "debug.h"
|
||||
#include "socket.h"
|
||||
#include "utils_internal.hpp"
|
||||
@@ -36,13 +36,13 @@ struct ExtInfo {
|
||||
SocketAddress extAddressListen;
|
||||
};
|
||||
|
||||
MSCCLPP_API_CPP void BaseBootstrap::send(const std::vector<char>& data, int peer, int tag) {
|
||||
MSCCLPP_API_CPP void Bootstrap::send(const std::vector<char>& data, int peer, int tag) {
|
||||
size_t size = data.size();
|
||||
send((void*)&size, sizeof(size_t), peer, tag);
|
||||
send((void*)data.data(), data.size(), peer, tag + 1);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void BaseBootstrap::recv(std::vector<char>& data, int peer, int tag) {
|
||||
MSCCLPP_API_CPP void Bootstrap::recv(std::vector<char>& data, int peer, int tag) {
|
||||
size_t size;
|
||||
recv((void*)&size, sizeof(size_t), peer, tag);
|
||||
data.resize(size);
|
||||
@@ -55,12 +55,12 @@ struct UniqueIdInternal {
|
||||
};
|
||||
static_assert(sizeof(UniqueIdInternal) <= sizeof(UniqueId), "UniqueIdInternal is too large to fit into UniqueId");
|
||||
|
||||
class Bootstrap::Impl {
|
||||
class TcpBootstrap::Impl {
|
||||
public:
|
||||
Impl(int rank, int nRanks);
|
||||
~Impl();
|
||||
void initialize(const UniqueId uniqueId);
|
||||
void initialize(std::string ifIpPortTrio);
|
||||
void initialize(const UniqueId& uniqueId);
|
||||
void initialize(const std::string& ifIpPortTrio);
|
||||
void establishConnections();
|
||||
UniqueId createUniqueId();
|
||||
UniqueId getUniqueId() const;
|
||||
@@ -106,7 +106,7 @@ class Bootstrap::Impl {
|
||||
void netInit(std::string ipPortPair, std::string interface);
|
||||
};
|
||||
|
||||
Bootstrap::Impl::Impl(int rank, int nRanks)
|
||||
TcpBootstrap::Impl::Impl(int rank, int nRanks)
|
||||
: rank_(rank),
|
||||
nRanks_(nRanks),
|
||||
netInitialized(false),
|
||||
@@ -115,13 +115,13 @@ Bootstrap::Impl::Impl(int rank, int nRanks)
|
||||
abortFlagStorage_(new uint32_t(0)),
|
||||
abortFlag_(abortFlagStorage_.get()) {}
|
||||
|
||||
UniqueId Bootstrap::Impl::getUniqueId() const {
|
||||
UniqueId TcpBootstrap::Impl::getUniqueId() const {
|
||||
UniqueId ret;
|
||||
std::memcpy(&ret, &uniqueId_, sizeof(uniqueId_));
|
||||
return ret;
|
||||
}
|
||||
|
||||
UniqueId Bootstrap::Impl::createUniqueId() {
|
||||
UniqueId TcpBootstrap::Impl::createUniqueId() {
|
||||
netInit("", "");
|
||||
getRandomData(&uniqueId_.magic, sizeof(uniqueId_.magic));
|
||||
std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(SocketAddress));
|
||||
@@ -129,11 +129,11 @@ UniqueId Bootstrap::Impl::createUniqueId() {
|
||||
return getUniqueId();
|
||||
}
|
||||
|
||||
int Bootstrap::Impl::getRank() { return rank_; }
|
||||
int TcpBootstrap::Impl::getRank() { return rank_; }
|
||||
|
||||
int Bootstrap::Impl::getNranks() { return nRanks_; }
|
||||
int TcpBootstrap::Impl::getNranks() { return nRanks_; }
|
||||
|
||||
void Bootstrap::Impl::initialize(const UniqueId uniqueId) {
|
||||
void TcpBootstrap::Impl::initialize(const UniqueId& uniqueId) {
|
||||
netInit("", "");
|
||||
|
||||
std::memcpy(&uniqueId_, &uniqueId, sizeof(uniqueId_));
|
||||
@@ -141,7 +141,7 @@ void Bootstrap::Impl::initialize(const UniqueId uniqueId) {
|
||||
establishConnections();
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::initialize(std::string ifIpPortTrio) {
|
||||
void TcpBootstrap::Impl::initialize(const std::string& ifIpPortTrio) {
|
||||
// first check if it is a trio
|
||||
int nColons = 0;
|
||||
for (auto c : ifIpPortTrio) {
|
||||
@@ -170,7 +170,7 @@ void Bootstrap::Impl::initialize(std::string ifIpPortTrio) {
|
||||
establishConnections();
|
||||
}
|
||||
|
||||
Bootstrap::Impl::~Impl() {
|
||||
TcpBootstrap::Impl::~Impl() {
|
||||
if (abortFlag_) {
|
||||
*abortFlag_ = 1;
|
||||
}
|
||||
@@ -179,8 +179,8 @@ Bootstrap::Impl::~Impl() {
|
||||
}
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::getRemoteAddresses(Socket* listenSock, std::vector<SocketAddress>& rankAddresses,
|
||||
std::vector<SocketAddress>& rankAddressesRoot, int& rank) {
|
||||
void TcpBootstrap::Impl::getRemoteAddresses(Socket* listenSock, std::vector<SocketAddress>& rankAddresses,
|
||||
std::vector<SocketAddress>& rankAddressesRoot, int& rank) {
|
||||
ExtInfo info;
|
||||
SocketAddress zero;
|
||||
std::memset(&zero, 0, sizeof(SocketAddress));
|
||||
@@ -209,15 +209,15 @@ void Bootstrap::Impl::getRemoteAddresses(Socket* listenSock, std::vector<SocketA
|
||||
rank = info.rank;
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::sendHandleToPeer(int peer, const std::vector<SocketAddress>& rankAddresses,
|
||||
const std::vector<SocketAddress>& rankAddressesRoot) {
|
||||
void TcpBootstrap::Impl::sendHandleToPeer(int peer, const std::vector<SocketAddress>& rankAddresses,
|
||||
const std::vector<SocketAddress>& rankAddressesRoot) {
|
||||
int next = (peer + 1) % nRanks_;
|
||||
Socket sock(&rankAddressesRoot[peer], uniqueId_.magic, SocketTypeBootstrap, abortFlag_);
|
||||
sock.connect();
|
||||
netSend(&sock, &rankAddresses[next], sizeof(SocketAddress));
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::bootstrapCreateRoot() {
|
||||
void TcpBootstrap::Impl::bootstrapCreateRoot() {
|
||||
listenSockRoot_ = std::make_unique<Socket>(&uniqueId_.addr, uniqueId_.magic, SocketTypeBootstrap, abortFlag_, 0);
|
||||
listenSockRoot_->listen();
|
||||
uniqueId_.addr = listenSockRoot_->getAddr();
|
||||
@@ -232,7 +232,7 @@ void Bootstrap::Impl::bootstrapCreateRoot() {
|
||||
});
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::bootstrapRoot() {
|
||||
void TcpBootstrap::Impl::bootstrapRoot() {
|
||||
int numCollected = 0;
|
||||
std::vector<SocketAddress> rankAddresses(nRanks_, SocketAddress());
|
||||
// for initial rank <-> root information exchange
|
||||
@@ -266,7 +266,7 @@ void Bootstrap::Impl::bootstrapRoot() {
|
||||
TRACE(MSCCLPP_INIT, "DONE");
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::netInit(std::string ipPortPair, std::string interface) {
|
||||
void TcpBootstrap::Impl::netInit(std::string ipPortPair, std::string interface) {
|
||||
if (netInitialized) return;
|
||||
if (!ipPortPair.empty()) {
|
||||
if (interface != "") {
|
||||
@@ -285,30 +285,30 @@ void Bootstrap::Impl::netInit(std::string ipPortPair, std::string interface) {
|
||||
} else {
|
||||
int ret = FindInterfaces(netIfName_, &netIfAddr_, MAX_IF_NAME_SIZE, 1);
|
||||
if (ret <= 0) {
|
||||
throw Error("Bootstrap : no socket interface found", ErrorCode::InternalError);
|
||||
throw Error("TcpBootstrap : no socket interface found", ErrorCode::InternalError);
|
||||
}
|
||||
}
|
||||
|
||||
char line[SOCKET_NAME_MAXLEN + MAX_IF_NAME_SIZE + 2];
|
||||
std::sprintf(line, " %s:", netIfName_);
|
||||
SocketToString(&netIfAddr_, line + strlen(line));
|
||||
INFO(MSCCLPP_INIT, "Bootstrap : Using%s", line);
|
||||
INFO(MSCCLPP_INIT, "TcpBootstrap : Using%s", line);
|
||||
netInitialized = true;
|
||||
}
|
||||
|
||||
#define TIMEOUT(__exp) \
|
||||
do { \
|
||||
try { \
|
||||
__exp; \
|
||||
} catch (const Error& e) { \
|
||||
if (e.getErrorCode() == ErrorCode::Timeout) { \
|
||||
throw Error("Bootstrap connection timeout", ErrorCode::Timeout); \
|
||||
} \
|
||||
throw; \
|
||||
} \
|
||||
#define TIMEOUT(__exp) \
|
||||
do { \
|
||||
try { \
|
||||
__exp; \
|
||||
} catch (const Error& e) { \
|
||||
if (e.getErrorCode() == ErrorCode::Timeout) { \
|
||||
throw Error("TcpBootstrap connection timeout", ErrorCode::Timeout); \
|
||||
} \
|
||||
throw; \
|
||||
} \
|
||||
} while (0);
|
||||
|
||||
void Bootstrap::Impl::establishConnections() {
|
||||
void TcpBootstrap::Impl::establishConnections() {
|
||||
const int64_t connectionTimeoutUs = (int64_t)Config::getInstance()->getBootstrapConnectionTimeoutConfig() * 1000000;
|
||||
Timer timer;
|
||||
SocketAddress nextAddr;
|
||||
@@ -318,7 +318,7 @@ void Bootstrap::Impl::establishConnections() {
|
||||
|
||||
auto getLeftTime = [&]() {
|
||||
int64_t timeout = connectionTimeoutUs - timer.elapsed();
|
||||
if (timeout <= 0) throw Error("Bootstrap connection timeout", ErrorCode::Timeout);
|
||||
if (timeout <= 0) throw Error("TcpBootstrap connection timeout", ErrorCode::Timeout);
|
||||
return timeout;
|
||||
};
|
||||
|
||||
@@ -377,7 +377,7 @@ void Bootstrap::Impl::establishConnections() {
|
||||
TRACE(MSCCLPP_INIT, "rank %d nranks %d - DONE", rank_, nRanks_);
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::allGather(void* allData, int size) {
|
||||
void TcpBootstrap::Impl::allGather(void* allData, int size) {
|
||||
char* data = static_cast<char*>(allData);
|
||||
int rank = rank_;
|
||||
int nRanks = nRanks_;
|
||||
@@ -401,7 +401,7 @@ void Bootstrap::Impl::allGather(void* allData, int size) {
|
||||
TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - DONE", rank, nRanks, size);
|
||||
}
|
||||
|
||||
std::shared_ptr<Socket> Bootstrap::Impl::getPeerSendSocket(int peer, int tag) {
|
||||
std::shared_ptr<Socket> TcpBootstrap::Impl::getPeerSendSocket(int peer, int tag) {
|
||||
auto it = peerSendSockets_.find(std::make_pair(peer, tag));
|
||||
if (it != peerSendSockets_.end()) {
|
||||
return it->second;
|
||||
@@ -414,7 +414,7 @@ std::shared_ptr<Socket> Bootstrap::Impl::getPeerSendSocket(int peer, int tag) {
|
||||
return sock;
|
||||
}
|
||||
|
||||
std::shared_ptr<Socket> Bootstrap::Impl::getPeerRecvSocket(int peer, int tag) {
|
||||
std::shared_ptr<Socket> TcpBootstrap::Impl::getPeerRecvSocket(int peer, int tag) {
|
||||
auto it = peerRecvSockets_.find(std::make_pair(peer, tag));
|
||||
if (it != peerRecvSockets_.end()) {
|
||||
return it->second;
|
||||
@@ -432,12 +432,12 @@ std::shared_ptr<Socket> Bootstrap::Impl::getPeerRecvSocket(int peer, int tag) {
|
||||
}
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::netSend(Socket* sock, const void* data, int size) {
|
||||
void TcpBootstrap::Impl::netSend(Socket* sock, const void* data, int size) {
|
||||
sock->send(&size, sizeof(int));
|
||||
sock->send(const_cast<void*>(data), size);
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::netRecv(Socket* sock, void* data, int size) {
|
||||
void TcpBootstrap::Impl::netRecv(Socket* sock, void* data, int size) {
|
||||
int recvSize;
|
||||
sock->recv(&recvSize, sizeof(int));
|
||||
if (recvSize > size) {
|
||||
@@ -448,19 +448,19 @@ void Bootstrap::Impl::netRecv(Socket* sock, void* data, int size) {
|
||||
sock->recv(data, std::min(recvSize, size));
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::send(void* data, int size, int peer, int tag) {
|
||||
void TcpBootstrap::Impl::send(void* data, int size, int peer, int tag) {
|
||||
auto sock = getPeerSendSocket(peer, tag);
|
||||
netSend(sock.get(), data, size);
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::recv(void* data, int size, int peer, int tag) {
|
||||
void TcpBootstrap::Impl::recv(void* data, int size, int peer, int tag) {
|
||||
auto sock = getPeerRecvSocket(peer, tag);
|
||||
netRecv(sock.get(), data, size);
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::barrier() { allGather(barrierArr_.data(), sizeof(int)); }
|
||||
void TcpBootstrap::Impl::barrier() { allGather(barrierArr_.data(), sizeof(int)); }
|
||||
|
||||
void Bootstrap::Impl::close() {
|
||||
void TcpBootstrap::Impl::close() {
|
||||
listenSockRoot_.reset(nullptr);
|
||||
listenSock_.reset(nullptr);
|
||||
ringRecvSocket_.reset(nullptr);
|
||||
@@ -469,28 +469,32 @@ void Bootstrap::Impl::close() {
|
||||
peerRecvSockets_.clear();
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Bootstrap::Bootstrap(int rank, int nRanks) { pimpl_ = std::make_unique<Impl>(rank, nRanks); }
|
||||
MSCCLPP_API_CPP TcpBootstrap::TcpBootstrap(int rank, int nRanks) { pimpl_ = std::make_unique<Impl>(rank, nRanks); }
|
||||
|
||||
MSCCLPP_API_CPP UniqueId Bootstrap::createUniqueId() { return pimpl_->createUniqueId(); }
|
||||
MSCCLPP_API_CPP UniqueId TcpBootstrap::createUniqueId() { return pimpl_->createUniqueId(); }
|
||||
|
||||
MSCCLPP_API_CPP UniqueId Bootstrap::getUniqueId() const { return pimpl_->getUniqueId(); }
|
||||
MSCCLPP_API_CPP UniqueId TcpBootstrap::getUniqueId() const { return pimpl_->getUniqueId(); }
|
||||
|
||||
MSCCLPP_API_CPP int Bootstrap::getRank() { return pimpl_->getRank(); }
|
||||
MSCCLPP_API_CPP int TcpBootstrap::getRank() { return pimpl_->getRank(); }
|
||||
|
||||
MSCCLPP_API_CPP int Bootstrap::getNranks() { return pimpl_->getNranks(); }
|
||||
MSCCLPP_API_CPP int TcpBootstrap::getNranks() { return pimpl_->getNranks(); }
|
||||
|
||||
MSCCLPP_API_CPP void Bootstrap::send(void* data, int size, int peer, int tag) { pimpl_->send(data, size, peer, tag); }
|
||||
MSCCLPP_API_CPP void TcpBootstrap::send(void* data, int size, int peer, int tag) {
|
||||
pimpl_->send(data, size, peer, tag);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Bootstrap::recv(void* data, int size, int peer, int tag) { pimpl_->recv(data, size, peer, tag); }
|
||||
MSCCLPP_API_CPP void TcpBootstrap::recv(void* data, int size, int peer, int tag) {
|
||||
pimpl_->recv(data, size, peer, tag);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Bootstrap::allGather(void* allData, int size) { pimpl_->allGather(allData, size); }
|
||||
MSCCLPP_API_CPP void TcpBootstrap::allGather(void* allData, int size) { pimpl_->allGather(allData, size); }
|
||||
|
||||
MSCCLPP_API_CPP void Bootstrap::initialize(UniqueId uniqueId) { pimpl_->initialize(uniqueId); }
|
||||
MSCCLPP_API_CPP void TcpBootstrap::initialize(UniqueId uniqueId) { pimpl_->initialize(uniqueId); }
|
||||
|
||||
MSCCLPP_API_CPP void Bootstrap::initialize(std::string ipPortPair) { pimpl_->initialize(ipPortPair); }
|
||||
MSCCLPP_API_CPP void TcpBootstrap::initialize(const std::string& ipPortPair) { pimpl_->initialize(ipPortPair); }
|
||||
|
||||
MSCCLPP_API_CPP void Bootstrap::barrier() { pimpl_->barrier(); }
|
||||
MSCCLPP_API_CPP void TcpBootstrap::barrier() { pimpl_->barrier(); }
|
||||
|
||||
MSCCLPP_API_CPP Bootstrap::~Bootstrap() { pimpl_->close(); }
|
||||
MSCCLPP_API_CPP TcpBootstrap::~TcpBootstrap() { pimpl_->close(); }
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -13,11 +13,11 @@
|
||||
#include <string.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <mscclpp/config.hpp>
|
||||
#include <mscclpp/errors.hpp>
|
||||
#include <mscclpp/utils.hpp>
|
||||
#include <sstream>
|
||||
|
||||
#include "config.hpp"
|
||||
#include "debug.h"
|
||||
#include "utils_internal.hpp"
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
Communicator::Impl::Impl(std::shared_ptr<BaseBootstrap> bootstrap) : bootstrap_(bootstrap) {
|
||||
Communicator::Impl::Impl(std::shared_ptr<Bootstrap> bootstrap) : bootstrap_(bootstrap) {
|
||||
rankToHash_.resize(bootstrap->getNranks());
|
||||
auto hostHash = getHostHash();
|
||||
INFO(MSCCLPP_INIT, "Host hash: %lx", hostHash);
|
||||
@@ -47,10 +47,10 @@ cudaStream_t Communicator::Impl::getIpcStream() { return ipcStream_; }
|
||||
|
||||
MSCCLPP_API_CPP Communicator::~Communicator() = default;
|
||||
|
||||
MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr<BaseBootstrap> bootstrap)
|
||||
MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr<Bootstrap> bootstrap)
|
||||
: pimpl(std::make_unique<Impl>(bootstrap)) {}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_ptr<BaseBootstrap> Communicator::bootstrapper() { return pimpl->bootstrap_; }
|
||||
MSCCLPP_API_CPP std::shared_ptr<Bootstrap> Communicator::bootstrap() { return pimpl->bootstrap_; }
|
||||
|
||||
MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportFlags transports) {
|
||||
return RegisteredMemory(
|
||||
@@ -61,7 +61,7 @@ struct MemorySender : public Setuppable {
|
||||
MemorySender(RegisteredMemory memory, int remoteRank, int tag)
|
||||
: memory_(memory), remoteRank_(remoteRank), tag_(tag) {}
|
||||
|
||||
void beginSetup(std::shared_ptr<BaseBootstrap> bootstrap) override {
|
||||
void beginSetup(std::shared_ptr<Bootstrap> bootstrap) override {
|
||||
bootstrap->send(memory_.serialize(), remoteRank_, tag_);
|
||||
}
|
||||
|
||||
@@ -77,7 +77,7 @@ MSCCLPP_API_CPP void Communicator::sendMemoryOnSetup(RegisteredMemory memory, in
|
||||
struct MemoryReceiver : public Setuppable {
|
||||
MemoryReceiver(int remoteRank, int tag) : remoteRank_(remoteRank), tag_(tag) {}
|
||||
|
||||
void endSetup(std::shared_ptr<BaseBootstrap> bootstrap) override {
|
||||
void endSetup(std::shared_ptr<Bootstrap> bootstrap) override {
|
||||
std::vector<char> data;
|
||||
bootstrap->recv(data, remoteRank_, tag_);
|
||||
memoryPromise_.set_value(RegisteredMemory::deserialize(data));
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "config.hpp"
|
||||
#include <mscclpp/config.hpp>
|
||||
|
||||
namespace mscclpp {
|
||||
Config Config::instance_;
|
||||
|
||||
@@ -171,7 +171,7 @@ void IBConnection::flush() {
|
||||
// npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT);
|
||||
}
|
||||
|
||||
void IBConnection::beginSetup(std::shared_ptr<BaseBootstrap> bootstrap) {
|
||||
void IBConnection::beginSetup(std::shared_ptr<Bootstrap> bootstrap) {
|
||||
std::vector<char> ibQpTransport;
|
||||
std::copy_n(reinterpret_cast<char*>(&qp->getInfo()), sizeof(qp->getInfo()), std::back_inserter(ibQpTransport));
|
||||
std::copy_n(reinterpret_cast<char*>(&transport_), sizeof(transport_), std::back_inserter(ibQpTransport));
|
||||
@@ -179,7 +179,7 @@ void IBConnection::beginSetup(std::shared_ptr<BaseBootstrap> bootstrap) {
|
||||
bootstrap->send(ibQpTransport.data(), ibQpTransport.size(), remoteRank(), tag());
|
||||
}
|
||||
|
||||
void IBConnection::endSetup(std::shared_ptr<BaseBootstrap> bootstrap) {
|
||||
void IBConnection::endSetup(std::shared_ptr<Bootstrap> bootstrap) {
|
||||
std::vector<char> ibQpTransport(sizeof(IbQpInfo) + sizeof(Transport));
|
||||
bootstrap->recv(ibQpTransport.data(), ibQpTransport.size(), remoteRank(), tag());
|
||||
|
||||
|
||||
@@ -82,9 +82,9 @@ const TransportFlags AllIBTransports = Transport::IB0 | Transport::IB1 | Transpo
|
||||
|
||||
const TransportFlags AllTransports = AllIBTransports | Transport::CudaIpc;
|
||||
|
||||
void Setuppable::beginSetup(std::shared_ptr<BaseBootstrap>) {}
|
||||
void Setuppable::beginSetup(std::shared_ptr<Bootstrap>) {}
|
||||
|
||||
void Setuppable::endSetup(std::shared_ptr<BaseBootstrap>) {}
|
||||
void Setuppable::endSetup(std::shared_ptr<Bootstrap>) {}
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
|
||||
@@ -22,10 +22,10 @@ struct Communicator::Impl {
|
||||
std::vector<std::shared_ptr<Setuppable>> toSetup_;
|
||||
std::unordered_map<Transport, std::unique_ptr<IbCtx>> ibContexts_;
|
||||
cudaStream_t ipcStream_;
|
||||
std::shared_ptr<BaseBootstrap> bootstrap_;
|
||||
std::shared_ptr<Bootstrap> bootstrap_;
|
||||
std::vector<uint64_t> rankToHash_;
|
||||
|
||||
Impl(std::shared_ptr<BaseBootstrap> bootstrap);
|
||||
Impl(std::shared_ptr<Bootstrap> bootstrap);
|
||||
|
||||
~Impl();
|
||||
|
||||
|
||||
@@ -71,9 +71,9 @@ class IBConnection : public ConnectionBase {
|
||||
|
||||
void flush() override;
|
||||
|
||||
void beginSetup(std::shared_ptr<BaseBootstrap> bootstrap) override;
|
||||
void beginSetup(std::shared_ptr<Bootstrap> bootstrap) override;
|
||||
|
||||
void endSetup(std::shared_ptr<BaseBootstrap> bootstrap) override;
|
||||
void endSetup(std::shared_ptr<Bootstrap> bootstrap) override;
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -69,12 +69,12 @@ MSCCLPP_API_CPP SmDevice2DeviceSemaphore::SmDevice2DeviceSemaphore(Communicator&
|
||||
remoteInboundSemaphoreIdsRegMem_ =
|
||||
setupInboundSemaphoreId(communicator, connection.get(), localInboundSemaphore_.get());
|
||||
INFO(MSCCLPP_INIT, "Creating a direct semaphore for CudaIPC transport from %d to %d",
|
||||
communicator.bootstrapper()->getRank(), connection->remoteRank());
|
||||
communicator.bootstrap()->getRank(), connection->remoteRank());
|
||||
isRemoteInboundSemaphoreIdSet_ = true;
|
||||
} else if (AllIBTransports.has(connection->transport())) {
|
||||
// We don't need to really with any of the IB transports, since the values will be local
|
||||
INFO(MSCCLPP_INIT, "Creating a direct semaphore for IB transport from %d to %d",
|
||||
communicator.bootstrapper()->getRank(), connection->remoteRank());
|
||||
communicator.bootstrap()->getRank(), connection->remoteRank());
|
||||
isRemoteInboundSemaphoreIdSet_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -391,9 +391,9 @@ int main(int argc, const char* argv[]) {
|
||||
|
||||
try {
|
||||
if (rank == 0) printf("Initializing MSCCL++\n");
|
||||
auto bootstrapper = std::make_shared<mscclpp::Bootstrap>(rank, world_size);
|
||||
bootstrapper->initialize(ip_port);
|
||||
mscclpp::Communicator comm(bootstrapper);
|
||||
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, world_size);
|
||||
bootstrap->initialize(ip_port);
|
||||
mscclpp::Communicator comm(bootstrap);
|
||||
mscclpp::ProxyService channelService(comm);
|
||||
|
||||
if (rank == 0) printf("Initializing data for allgather test\n");
|
||||
@@ -422,19 +422,19 @@ int main(int argc, const char* argv[]) {
|
||||
}
|
||||
int tmp[16];
|
||||
// A simple barrier
|
||||
bootstrapper->allGather(tmp, sizeof(int));
|
||||
bootstrap->allGather(tmp, sizeof(int));
|
||||
if (rank == 0) printf("Successfully checked the correctness\n");
|
||||
|
||||
// Perf test
|
||||
int iterwithoutcudagraph = 10;
|
||||
if (rank == 0) printf("Running %d iterations of the kernel without CUDA graph\n", iterwithoutcudagraph);
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
bootstrapper->allGather(tmp, sizeof(int));
|
||||
bootstrap->allGather(tmp, sizeof(int));
|
||||
for (int i = 0; i < iterwithoutcudagraph; ++i) {
|
||||
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
|
||||
}
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
bootstrapper->allGather(tmp, sizeof(int));
|
||||
bootstrap->allGather(tmp, sizeof(int));
|
||||
|
||||
// cudaGraph Capture
|
||||
int cudagraphiter = 10;
|
||||
@@ -462,7 +462,7 @@ int main(int argc, const char* argv[]) {
|
||||
if (rank == 0)
|
||||
printf("Running %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphlaunch,
|
||||
cudagraphiter);
|
||||
bootstrapper->allGather(tmp, sizeof(int));
|
||||
bootstrap->allGather(tmp, sizeof(int));
|
||||
double t0, t1, ms, time_in_us;
|
||||
t0 = getTime();
|
||||
for (int i = 0; i < cudagraphlaunch; ++i) {
|
||||
@@ -475,7 +475,7 @@ int main(int argc, const char* argv[]) {
|
||||
time_in_us = ms * 1000. / (float)cudagraphlaunch / (float)cudagraphiter;
|
||||
printf("Rank %d report: size %lu time: %f us/iter algBW %f GBps\n", rank, dataSize, time_in_us,
|
||||
(double)(dataSize) / 1e9 / (time_in_us / 1e6));
|
||||
bootstrapper->allGather(tmp, sizeof(int));
|
||||
bootstrap->allGather(tmp, sizeof(int));
|
||||
|
||||
if (rank == 0) printf("Stopping MSCCL++ proxy threads\n");
|
||||
channelService.stopProxy();
|
||||
|
||||
@@ -236,7 +236,7 @@ int main(int argc, char* argv[]) {
|
||||
CUCHECK(cudaSetDevice(cudaNum));
|
||||
|
||||
if (rank == 0) printf("Initializing MSCCL++\n");
|
||||
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(rank, world_size);
|
||||
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, world_size);
|
||||
mscclpp::UniqueId uniqueId;
|
||||
if (rank == 0) uniqueId = bootstrap->createUniqueId();
|
||||
MPI_Bcast(&uniqueId, sizeof(uniqueId), MPI_BYTE, 0, MPI_COMM_WORLD);
|
||||
|
||||
@@ -3,10 +3,11 @@
|
||||
|
||||
#include <mpi.h>
|
||||
|
||||
#include "config.hpp"
|
||||
#include <mscclpp/config.hpp>
|
||||
|
||||
#include "mp_unit_tests.hpp"
|
||||
|
||||
void BootstrapTest::bootstrapTestAllGather(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap) {
|
||||
void BootstrapTest::bootstrapTestAllGather(std::shared_ptr<mscclpp::Bootstrap> bootstrap) {
|
||||
std::vector<int> tmp(bootstrap->getNranks(), 0);
|
||||
tmp[bootstrap->getRank()] = bootstrap->getRank() + 1;
|
||||
bootstrap->allGather(tmp.data(), sizeof(int));
|
||||
@@ -15,9 +16,9 @@ void BootstrapTest::bootstrapTestAllGather(std::shared_ptr<mscclpp::BaseBootstra
|
||||
}
|
||||
}
|
||||
|
||||
void BootstrapTest::bootstrapTestBarrier(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap) { bootstrap->barrier(); }
|
||||
void BootstrapTest::bootstrapTestBarrier(std::shared_ptr<mscclpp::Bootstrap> bootstrap) { bootstrap->barrier(); }
|
||||
|
||||
void BootstrapTest::bootstrapTestSendRecv(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap) {
|
||||
void BootstrapTest::bootstrapTestSendRecv(std::shared_ptr<mscclpp::Bootstrap> bootstrap) {
|
||||
for (int i = 0; i < bootstrap->getNranks(); i++) {
|
||||
if (bootstrap->getRank() == i) continue;
|
||||
int msg1 = (bootstrap->getRank() + 1) * 3;
|
||||
@@ -43,14 +44,14 @@ void BootstrapTest::bootstrapTestSendRecv(std::shared_ptr<mscclpp::BaseBootstrap
|
||||
}
|
||||
}
|
||||
|
||||
void BootstrapTest::bootstrapTestAll(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap) {
|
||||
void BootstrapTest::bootstrapTestAll(std::shared_ptr<mscclpp::Bootstrap> bootstrap) {
|
||||
bootstrapTestAllGather(bootstrap);
|
||||
bootstrapTestBarrier(bootstrap);
|
||||
bootstrapTestSendRecv(bootstrap);
|
||||
}
|
||||
|
||||
TEST_F(BootstrapTest, WithId) {
|
||||
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(gEnv->rank, gEnv->worldSize);
|
||||
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(gEnv->rank, gEnv->worldSize);
|
||||
mscclpp::UniqueId id;
|
||||
if (bootstrap->getRank() == 0) id = bootstrap->createUniqueId();
|
||||
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
|
||||
@@ -59,14 +60,14 @@ TEST_F(BootstrapTest, WithId) {
|
||||
}
|
||||
|
||||
TEST_F(BootstrapTest, WithIpPortPair) {
|
||||
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(gEnv->rank, gEnv->worldSize);
|
||||
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(gEnv->rank, gEnv->worldSize);
|
||||
bootstrap->initialize(gEnv->args["ip_port"]);
|
||||
bootstrapTestAll(bootstrap);
|
||||
}
|
||||
|
||||
TEST_F(BootstrapTest, ResumeWithId) {
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(gEnv->rank, gEnv->worldSize);
|
||||
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(gEnv->rank, gEnv->worldSize);
|
||||
mscclpp::UniqueId id;
|
||||
if (bootstrap->getRank() == 0) id = bootstrap->createUniqueId();
|
||||
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
|
||||
@@ -76,13 +77,13 @@ TEST_F(BootstrapTest, ResumeWithId) {
|
||||
|
||||
TEST_F(BootstrapTest, ResumeWithIpPortPair) {
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(gEnv->rank, gEnv->worldSize);
|
||||
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(gEnv->rank, gEnv->worldSize);
|
||||
bootstrap->initialize(gEnv->args["ip_port"]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(BootstrapTest, ExitBeforeConnect) {
|
||||
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(gEnv->rank, gEnv->worldSize);
|
||||
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(gEnv->rank, gEnv->worldSize);
|
||||
bootstrap->createUniqueId();
|
||||
}
|
||||
|
||||
@@ -94,7 +95,7 @@ TEST_F(BootstrapTest, TimeoutWithId) {
|
||||
mscclpp::Timer timer;
|
||||
|
||||
// All ranks initialize a bootstrap with their own id (will hang)
|
||||
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(gEnv->rank, gEnv->worldSize);
|
||||
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(gEnv->rank, gEnv->worldSize);
|
||||
mscclpp::UniqueId id = bootstrap->createUniqueId();
|
||||
|
||||
try {
|
||||
@@ -108,9 +109,9 @@ TEST_F(BootstrapTest, TimeoutWithId) {
|
||||
ASSERT_LT(timer.elapsed(), 1100000);
|
||||
}
|
||||
|
||||
class MPIBootstrap : public mscclpp::BaseBootstrap {
|
||||
class MPIBootstrap : public mscclpp::Bootstrap {
|
||||
public:
|
||||
MPIBootstrap() : BaseBootstrap() {}
|
||||
MPIBootstrap() : Bootstrap() {}
|
||||
int getRank() override {
|
||||
int rank;
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||||
|
||||
@@ -19,10 +19,10 @@ void CommunicatorTestBase::SetUp() {
|
||||
ibTransport = ibIdToTransport(rankToLocalRank(gEnv->rank));
|
||||
MSCCLPP_CUDATHROW(cudaSetDevice(rankToLocalRank(gEnv->rank)));
|
||||
|
||||
std::shared_ptr<mscclpp::Bootstrap> bootstrap;
|
||||
std::shared_ptr<mscclpp::TcpBootstrap> bootstrap;
|
||||
mscclpp::UniqueId id;
|
||||
if (gEnv->rank < numRanksToUse) {
|
||||
bootstrap = std::make_shared<mscclpp::Bootstrap>(gEnv->rank, numRanksToUse);
|
||||
bootstrap = std::make_shared<mscclpp::TcpBootstrap>(gEnv->rank, numRanksToUse);
|
||||
if (gEnv->rank == 0) id = bootstrap->createUniqueId();
|
||||
}
|
||||
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
|
||||
@@ -63,14 +63,14 @@ void CommunicatorTestBase::registerMemoryPairs(void* buff, size_t buffSize, mscc
|
||||
localMemory = communicator->registerMemory(buff, buffSize, transport);
|
||||
std::unordered_map<int, mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> futureRemoteMemories;
|
||||
for (int remoteRank : remoteRanks) {
|
||||
if (remoteRank != communicator->bootstrapper()->getRank()) {
|
||||
if (remoteRank != communicator->bootstrap()->getRank()) {
|
||||
communicator->sendMemoryOnSetup(localMemory, remoteRank, tag);
|
||||
futureRemoteMemories[remoteRank] = communicator->recvMemoryOnSetup(remoteRank, tag);
|
||||
}
|
||||
}
|
||||
communicator->setup();
|
||||
for (int remoteRank : remoteRanks) {
|
||||
if (remoteRank != communicator->bootstrapper()->getRank()) {
|
||||
if (remoteRank != communicator->bootstrap()->getRank()) {
|
||||
remoteMemories[remoteRank] = futureRemoteMemories[remoteRank].get();
|
||||
}
|
||||
}
|
||||
@@ -166,10 +166,10 @@ TEST_F(CommunicatorTest, BasicWrite) {
|
||||
if (gEnv->rank >= numRanksToUse) return;
|
||||
|
||||
deviceBufferInit();
|
||||
communicator->bootstrapper()->barrier();
|
||||
communicator->bootstrap()->barrier();
|
||||
|
||||
writeToRemote(deviceBufferSize / sizeof(int) / gEnv->worldSize);
|
||||
communicator->bootstrapper()->barrier();
|
||||
communicator->bootstrap()->barrier();
|
||||
|
||||
// polling until it becomes ready
|
||||
bool ready = false;
|
||||
@@ -181,7 +181,7 @@ TEST_F(CommunicatorTest, BasicWrite) {
|
||||
FAIL() << "Polling is stuck.";
|
||||
}
|
||||
} while (!ready);
|
||||
communicator->bootstrapper()->barrier();
|
||||
communicator->bootstrap()->barrier();
|
||||
}
|
||||
|
||||
__global__ void kernelWaitSemaphores(mscclpp::Host2DeviceSemaphore::DeviceHandle* deviceSemaphores, int rank,
|
||||
@@ -201,10 +201,10 @@ TEST_F(CommunicatorTest, WriteWithDeviceSemaphores) {
|
||||
semaphores.insert({entry.first, std::make_shared<mscclpp::Host2DeviceSemaphore>(*communicator.get(), conn)});
|
||||
}
|
||||
communicator->setup();
|
||||
communicator->bootstrapper()->barrier();
|
||||
communicator->bootstrap()->barrier();
|
||||
|
||||
deviceBufferInit();
|
||||
communicator->bootstrapper()->barrier();
|
||||
communicator->bootstrap()->barrier();
|
||||
|
||||
auto deviceSemaphoreHandles = mscclpp::allocSharedCuda<mscclpp::Host2DeviceSemaphore::DeviceHandle>(gEnv->worldSize);
|
||||
for (int i = 0; i < gEnv->worldSize; i++) {
|
||||
@@ -214,7 +214,7 @@ TEST_F(CommunicatorTest, WriteWithDeviceSemaphores) {
|
||||
1, cudaMemcpyHostToDevice);
|
||||
}
|
||||
}
|
||||
communicator->bootstrapper()->barrier();
|
||||
communicator->bootstrap()->barrier();
|
||||
|
||||
writeToRemote(deviceBufferSize / sizeof(int) / gEnv->worldSize);
|
||||
|
||||
@@ -228,7 +228,7 @@ TEST_F(CommunicatorTest, WriteWithDeviceSemaphores) {
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
ASSERT_TRUE(testWriteCorrectness());
|
||||
communicator->bootstrapper()->barrier();
|
||||
communicator->bootstrap()->barrier();
|
||||
}
|
||||
|
||||
TEST_F(CommunicatorTest, WriteWithHostSemaphores) {
|
||||
@@ -242,10 +242,10 @@ TEST_F(CommunicatorTest, WriteWithHostSemaphores) {
|
||||
semaphores.insert({entry.first, std::make_shared<mscclpp::Host2HostSemaphore>(*communicator.get(), conn)});
|
||||
}
|
||||
communicator->setup();
|
||||
communicator->bootstrapper()->barrier();
|
||||
communicator->bootstrap()->barrier();
|
||||
|
||||
deviceBufferInit();
|
||||
communicator->bootstrapper()->barrier();
|
||||
communicator->bootstrap()->barrier();
|
||||
|
||||
writeToRemote(deviceBufferSize / sizeof(int) / gEnv->worldSize);
|
||||
|
||||
@@ -274,5 +274,5 @@ TEST_F(CommunicatorTest, WriteWithHostSemaphores) {
|
||||
}
|
||||
|
||||
ASSERT_TRUE(testWriteCorrectness());
|
||||
communicator->bootstrapper()->barrier();
|
||||
communicator->bootstrap()->barrier();
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ void IbPeerToPeerTest::SetUp() {
|
||||
|
||||
if (gEnv->rank < 2) {
|
||||
// This test needs only two ranks
|
||||
bootstrap = std::make_shared<mscclpp::Bootstrap>(gEnv->rank, 2);
|
||||
bootstrap = std::make_shared<mscclpp::TcpBootstrap>(gEnv->rank, 2);
|
||||
if (bootstrap->getRank() == 0) id = bootstrap->createUniqueId();
|
||||
}
|
||||
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
|
||||
|
||||
@@ -42,13 +42,13 @@ class MultiProcessTest : public ::testing::Test {
|
||||
|
||||
class BootstrapTest : public MultiProcessTest {
|
||||
protected:
|
||||
void bootstrapTestAllGather(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap);
|
||||
void bootstrapTestAllGather(std::shared_ptr<mscclpp::Bootstrap> bootstrap);
|
||||
|
||||
void bootstrapTestBarrier(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap);
|
||||
void bootstrapTestBarrier(std::shared_ptr<mscclpp::Bootstrap> bootstrap);
|
||||
|
||||
void bootstrapTestSendRecv(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap);
|
||||
void bootstrapTestSendRecv(std::shared_ptr<mscclpp::Bootstrap> bootstrap);
|
||||
|
||||
void bootstrapTestAll(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap);
|
||||
void bootstrapTestAll(std::shared_ptr<mscclpp::Bootstrap> bootstrap);
|
||||
|
||||
// Each test case should finish within 30 seconds.
|
||||
mscclpp::Timer bootstrapTestTimer{30};
|
||||
@@ -76,7 +76,7 @@ class IbPeerToPeerTest : public IbTestBase {
|
||||
void stageSendWithImm(uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled,
|
||||
unsigned int immData);
|
||||
|
||||
std::shared_ptr<mscclpp::Bootstrap> bootstrap;
|
||||
std::shared_ptr<mscclpp::TcpBootstrap> bootstrap;
|
||||
std::shared_ptr<mscclpp::IbCtx> ibCtx;
|
||||
mscclpp::IbQp* qp;
|
||||
const mscclpp::IbMr* mr;
|
||||
|
||||
@@ -17,8 +17,8 @@ void ProxyChannelOneToOneTest::TearDown() { CommunicatorTestBase::TearDown(); }
|
||||
void ProxyChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::SimpleProxyChannel>& proxyChannels,
|
||||
bool useIbOnly, void* sendBuff, size_t sendBuffBytes,
|
||||
void* recvBuff, size_t recvBuffBytes) {
|
||||
const int rank = communicator->bootstrapper()->getRank();
|
||||
const int worldSize = communicator->bootstrapper()->getNranks();
|
||||
const int rank = communicator->bootstrap()->getRank();
|
||||
const int worldSize = communicator->bootstrap()->getNranks();
|
||||
const bool isInPlace = (recvBuff == nullptr);
|
||||
mscclpp::TransportFlags transport = (useIbOnly) ? ibTransport : (mscclpp::Transport::CudaIpc | ibTransport);
|
||||
|
||||
@@ -274,7 +274,7 @@ void ProxyChannelOneToOneTest::testPacketPingPong(bool useIbOnly) {
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
|
||||
communicator->bootstrapper()->barrier();
|
||||
communicator->bootstrap()->barrier();
|
||||
|
||||
channelService->stopProxy();
|
||||
}
|
||||
@@ -312,7 +312,7 @@ void ProxyChannelOneToOneTest::testPacketPingPongPerf(bool useIbOnly) {
|
||||
<<<1, 1024>>>(buff.get(), putPacketBuffer.get(), getPacketBuffer.get(), gEnv->rank, 2, nTries, nullptr);
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
communicator->bootstrapper()->barrier();
|
||||
communicator->bootstrap()->barrier();
|
||||
|
||||
// Measure latency
|
||||
mscclpp::Timer timer;
|
||||
@@ -320,7 +320,7 @@ void ProxyChannelOneToOneTest::testPacketPingPongPerf(bool useIbOnly) {
|
||||
<<<1, 1024>>>(buff.get(), putPacketBuffer.get(), getPacketBuffer.get(), gEnv->rank, 2, nTries, nullptr);
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
communicator->bootstrapper()->barrier();
|
||||
communicator->bootstrap()->barrier();
|
||||
|
||||
if (gEnv->rank == 0) {
|
||||
std::cout << testName << ": " << std::setprecision(4) << (float)timer.elapsed() / (float)nTries << " us/iter\n";
|
||||
|
||||
@@ -17,8 +17,8 @@ void SmChannelOneToOneTest::TearDown() { CommunicatorTestBase::TearDown(); }
|
||||
|
||||
void SmChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::SmChannel>& smChannels, void* inputBuff,
|
||||
size_t inputBuffBytes, void* outputBuff, size_t outputBuffBytes) {
|
||||
const int rank = communicator->bootstrapper()->getRank();
|
||||
const int worldSize = communicator->bootstrapper()->getNranks();
|
||||
const int rank = communicator->bootstrap()->getRank();
|
||||
const int worldSize = communicator->bootstrap()->getNranks();
|
||||
const bool isInPlace = (outputBuff == nullptr);
|
||||
mscclpp::TransportFlags transport = mscclpp::Transport::CudaIpc | ibTransport;
|
||||
|
||||
|
||||
@@ -225,7 +225,7 @@ double BaseTestEngine::benchTime() {
|
||||
return deltaSec;
|
||||
}
|
||||
|
||||
void BaseTestEngine::barrier() { this->comm_->bootstrapper()->barrier(); }
|
||||
void BaseTestEngine::barrier() { this->comm_->bootstrap()->barrier(); }
|
||||
|
||||
void BaseTestEngine::runTest() {
|
||||
// warm-up for large size
|
||||
@@ -326,7 +326,7 @@ void BaseTestEngine::runTest() {
|
||||
}
|
||||
|
||||
void BaseTestEngine::bootstrap() {
|
||||
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(args_.rank, args_.totalRanks);
|
||||
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(args_.rank, args_.totalRanks);
|
||||
mscclpp::UniqueId id;
|
||||
if (bootstrap->getRank() == 0) id = bootstrap->createUniqueId();
|
||||
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
|
||||
|
||||
@@ -9,25 +9,25 @@
|
||||
class LocalCommunicatorTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
bootstrap = std::make_shared<mscclpp::Bootstrap>(0, 1);
|
||||
bootstrap = std::make_shared<mscclpp::TcpBootstrap>(0, 1);
|
||||
comm = std::make_shared<mscclpp::Communicator>(bootstrap);
|
||||
}
|
||||
|
||||
std::shared_ptr<mscclpp::Bootstrap> bootstrap;
|
||||
std::shared_ptr<mscclpp::TcpBootstrap> bootstrap;
|
||||
std::shared_ptr<mscclpp::Communicator> comm;
|
||||
};
|
||||
|
||||
class MockSetuppable : public mscclpp::Setuppable {
|
||||
public:
|
||||
MOCK_METHOD(void, beginSetup, (std::shared_ptr<mscclpp::BaseBootstrap> bootstrap), (override));
|
||||
MOCK_METHOD(void, endSetup, (std::shared_ptr<mscclpp::BaseBootstrap> bootstrap), (override));
|
||||
MOCK_METHOD(void, beginSetup, (std::shared_ptr<mscclpp::Bootstrap> bootstrap), (override));
|
||||
MOCK_METHOD(void, endSetup, (std::shared_ptr<mscclpp::Bootstrap> bootstrap), (override));
|
||||
};
|
||||
|
||||
TEST_F(LocalCommunicatorTest, OnSetup) {
|
||||
auto mockSetuppable = std::make_shared<MockSetuppable>();
|
||||
comm->onSetup(mockSetuppable);
|
||||
EXPECT_CALL(*mockSetuppable, beginSetup(std::dynamic_pointer_cast<mscclpp::BaseBootstrap>(bootstrap)));
|
||||
EXPECT_CALL(*mockSetuppable, endSetup(std::dynamic_pointer_cast<mscclpp::BaseBootstrap>(bootstrap)));
|
||||
EXPECT_CALL(*mockSetuppable, beginSetup(std::dynamic_pointer_cast<mscclpp::Bootstrap>(bootstrap)));
|
||||
EXPECT_CALL(*mockSetuppable, endSetup(std::dynamic_pointer_cast<mscclpp::Bootstrap>(bootstrap)));
|
||||
comm->setup();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user