Change device handle interfaces & others (#142)

* Changed device handle interfaces
* Changed proxy service interfaces
* Move device code into separate files
* Fixed FIFO polling issues
* Add configuration arguments in several interface functions

---------

Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
Co-authored-by: Binyang Li <binyli@microsoft.com>
Co-authored-by: root <root@a100-saemal0.qxveptpukjsuthqvv514inp03c.gx.internal.cloudapp.net>
This commit is contained in:
Saeed Maleki
2023-08-16 05:00:56 -07:00
committed by GitHub
parent 4865b2017b
commit 8d1b984bed
59 changed files with 1271 additions and 1036 deletions

View File

@@ -17,8 +17,8 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wall,-Wextra")
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake)
# clang-format targets
include(${PROJECT_SOURCE_DIR}/cmake/AddClangFormatTargets.cmake)
# Format targets
include(${PROJECT_SOURCE_DIR}/cmake/AddFormatTargets.cmake)
# Options
option(ENABLE_TRACE "Enable tracing" OFF)

View File

@@ -67,7 +67,7 @@ mscclpp::Communicator comm(bootstrap);
// Setup connections here using `comm`
...
// Construct the default proxy
mscclpp::ProxyService proxyService(comm);
mscclpp::ProxyService proxyService;
// Start the proxy
proxyService.startProxy();
// Run the user application, i.e., launch GPU kernels here
@@ -80,7 +80,7 @@ While the default implementation already enables any kinds of communication, MSC
```cpp
// Proxy FIFO is obtained from mscclpp::Proxy on the host and copied to the device.
__device__ mscclpp::DeviceProxyFifo fifo;
__device__ mscclpp::FifoDeviceHandle fifo;
__global__ void gpuKernel() {
...
// Only one thread is needed for the followings

View File

@@ -1,18 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Add targets to run clang-format
find_program(CLANG_FORMAT clang-format)
if(CLANG_FORMAT)
message(STATUS "Found clang-format: ${CLANG_FORMAT}")
set(FIND_DIRS ${PROJECT_SOURCE_DIR}/src ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/python ${PROJECT_SOURCE_DIR}/test)
add_custom_target(check-format ALL
COMMAND ${CLANG_FORMAT} -style=file --dry-run `find ${FIND_DIRS} -type f -name *.h -o -name *.hpp -o -name *.c -o -name *.cc -o -name *.cpp -o -name *.cu`
)
add_custom_target(format
COMMAND ${CLANG_FORMAT} -style=file -i `find ${FIND_DIRS} -type f -name *.h -o -name *.hpp -o -name *.c -o -name *.cc -o -name *.cpp -o -name *.cu`
)
else()
message(STATUS "clang-format not found.")
endif()

View File

@@ -0,0 +1,38 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# Add targets to run clang-format (C/C++/CUDA sources) and black (Python).
# `check-format` and `format` are umbrella targets; the per-tool targets
# below are attached to them only when the corresponding formatter is found.
add_custom_target(check-format)
add_custom_target(format)

find_program(CLANG_FORMAT clang-format)
if(CLANG_FORMAT)
  message(STATUS "Found clang-format: ${CLANG_FORMAT}")
  set(FIND_DIRS ${PROJECT_SOURCE_DIR}/src ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/python ${PROJECT_SOURCE_DIR}/test)
  # --dry-run reports style violations without modifying files.
  add_custom_target(check-format-cpp ALL
    COMMAND ${CLANG_FORMAT} -style=file --dry-run `find ${FIND_DIRS} -type f -name *.h -o -name *.hpp -o -name *.c -o -name *.cc -o -name *.cpp -o -name *.cu`
  )
  add_dependencies(check-format check-format-cpp)
  # -i rewrites the files in place.
  add_custom_target(format-cpp
    COMMAND ${CLANG_FORMAT} -style=file -i `find ${FIND_DIRS} -type f -name *.h -o -name *.hpp -o -name *.c -o -name *.cc -o -name *.cpp -o -name *.cu`
  )
  add_dependencies(format format-cpp)
else()
  message(STATUS "clang-format not found.")
endif()

find_program(BLACK black)
if(BLACK)
  message(STATUS "Found black: ${BLACK}")
  add_custom_target(check-format-py
    COMMAND ${BLACK} --config ${PROJECT_SOURCE_DIR}/pyproject.toml --check ${PROJECT_SOURCE_DIR}/python ${PROJECT_SOURCE_DIR}/test
  )
  add_dependencies(check-format check-format-py)
  add_custom_target(format-py
    COMMAND ${BLACK} --config ${PROJECT_SOURCE_DIR}/pyproject.toml ${PROJECT_SOURCE_DIR}/python ${PROJECT_SOURCE_DIR}/test
  )
  add_dependencies(format format-py)
else()
  # Bug fix: was `message(STATUS, "black not found.")` — the comma after the
  # STATUS mode keyword made CMake print a literal ", " in the output.
  message(STATUS "black not found.")
endif()

View File

@@ -1,27 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#ifndef MSCCLPP_CONFIG_H_
#define MSCCLPP_CONFIG_H_
namespace mscclpp {
/// Process-wide singleton holding MSCCL++ configuration values.
/// Copying is deleted and the constructor is private, so the only way to
/// obtain a Config is through getInstance().
class Config {
public:
/// Bootstrap connection timeout in seconds (defaults to 30).
int bootstrapConnectionTimeout = 30;
/// Returns the single global Config instance (backed by instance_).
static Config* getInstance();
/// Returns the bootstrap connection timeout in seconds.
int getBootstrapConnectionTimeoutConfig();
/// Sets the bootstrap connection timeout.
/// @param timeout New timeout value in seconds.
void setBootstrapConnectionTimeoutConfig(int timeout);
private:
// Private default constructor and deleted copy operations enforce the
// singleton pattern.
Config() = default;
Config(const Config&) = delete;
Config& operator=(const Config&) = delete;
// The single static instance handed out by getInstance().
static Config instance_;
};
} // namespace mscclpp
#endif // end include guard

View File

@@ -61,11 +61,13 @@ class TcpBootstrap : public Bootstrap {
/// Initialize the @ref TcpBootstrap with a given unique ID.
/// @param uniqueId The unique ID to initialize the @ref TcpBootstrap with.
void initialize(UniqueId uniqueId);
/// @param timeoutSec The connection timeout in seconds.
void initialize(UniqueId uniqueId, int64_t timeoutSec = 30);
/// Initialize the @ref TcpBootstrap with a string formatted as "ip:port" or "interface:ip:port".
/// @param ifIpPortTrio The string formatted as "ip:port" or "interface:ip:port".
void initialize(const std::string& ifIpPortTrio);
/// @param timeoutSec The connection timeout in seconds.
void initialize(const std::string& ifIpPortTrio, int64_t timeoutSec = 30);
/// Return the rank of the process.
int getRank() override;
@@ -384,7 +386,7 @@ class Connection {
virtual void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) = 0;
/// Flush any pending writes to the remote process.
virtual void flush() = 0;
virtual void flush(int64_t timeoutUsec = 3e7) = 0;
/// Get the rank of the remote process.
///
@@ -533,8 +535,14 @@ class Communicator {
/// @param remoteRank The rank of the remote process.
/// @param tag The tag of the connection for identifying it.
/// @param transport The type of transport to be used.
/// @param ibMaxCqSize The maximum number of completion queue entries for IB. Unused if transport is not IB.
/// @param ibMaxCqPollNum The maximum number of completion queue entries to poll for IB. Unused if transport is not
/// IB.
/// @param ibMaxSendWr The maximum number of outstanding send work requests for IB. Unused if transport is not IB.
/// @param ibMaxWrPerSend The maximum number of work requests per send for IB. Unused if transport is not IB.
/// @return std::shared_ptr<Connection> A shared pointer to the connection.
std::shared_ptr<Connection> connectOnSetup(int remoteRank, int tag, Transport transport);
std::shared_ptr<Connection> connectOnSetup(int remoteRank, int tag, Transport transport, int ibMaxCqSize = 1024,
int ibMaxCqPollNum = 1, int ibMaxSendWr = 8192, int ibMaxWrPerSend = 64);
/// Add a custom Setuppable object to a list of objects to be setup later, when @ref setup() is called.
///

View File

@@ -7,90 +7,25 @@
#include <cstdint>
#include <functional>
#include <memory>
#include <mscclpp/fifo_device.hpp>
#include <mscclpp/poll.hpp>
#define MSCCLPP_PROXY_FIFO_SIZE 128
namespace mscclpp {
/// A struct representing a pair of 64-bit unsigned integers used as a trigger for the proxy.
///
/// This struct is used as a work element in the concurrent FIFO where multiple device threads can push
/// ProxyTrigger elements and a single host proxy thread consumes these work elements.
///
struct alignas(16) ProxyTrigger {
uint64_t fst, snd;
};
/// A concurrent FIFO where multiple device threads can push work elements and a single host proxy thread consumes them.
///
/// The FIFO has a head pointer allocated on the device which starts at 0 and goes up to 2^64-1, which is almost
/// infinity. There are two copies of the tail, one on the device, @ref DeviceProxyFifo::tailReplica, and another on the
/// host, namely, hostTail. The host always has the "true" tail and occasionally pushes it to the copy on the device.
/// Therefore, most of the time, the device has a stale version. The invariants are: tailReplica <= hostTail <= head.
/// The @ref push() function increments head, hostTail is updated in @ref HostProxyFifo::pop(), and it occasionally
/// flushes it to tailReplica via @ref HostProxyFifo::flushTail().
///
/// Duplicating the tail is a good idea because the FIFO is large enough, and we do not need frequent updates for the
/// tail as there is usually enough space for device threads to push their work into.
///
struct DeviceProxyFifo {
#ifdef __CUDACC__
/// Push a trigger to the FIFO.
///
/// @param trigger The trigger to push.
/// @return The new head of the FIFO.
__forceinline__ __device__ uint64_t push(ProxyTrigger trigger) {
uint64_t curFifoHead = atomicAdd((unsigned long long int*)this->head, 1);
// Only one of two conditions need to be met to proceed. Either the tail has advanced enough or where we need to
// write to is 0. However, the first condition is faster to check since the tail is flushed periodically anyways but
// for the second condition we need to read CPU memory.
// As volatile access is slow, we first check using the bare pointer and then use the volatile pointer if the
// condition is not met.
if (curFifoHead >= MSCCLPP_PROXY_FIFO_SIZE + *(this->tailReplica)) {
OR_POLL_MAYBE_JAILBREAK(curFifoHead >= MSCCLPP_PROXY_FIFO_SIZE + *((volatile uint64_t*)this->tailReplica),
*(volatile uint64_t*)&this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0,
1000000);
}
ProxyTrigger* triggerPtr = (ProxyTrigger*)&(this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE]);
asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr), "l"(trigger.fst), "l"(trigger.snd));
return curFifoHead;
}
/// Wait until there is a place in the FIFO to push a trigger.
///
/// @param curFifoHead The current head of the FIFO.
__forceinline__ __device__ void sync(uint64_t curFifoHead) {
// Same as push but in this case checking the first condition is probably faster since for tail to be pushed we need
// to wait for cudaMemcpy to be done.
OR_POLL_MAYBE_JAILBREAK(*(volatile uint64_t*)&(this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE]) != 0,
*(volatile uint64_t*)(this->tailReplica) <= curFifoHead, 1000000);
}
#endif // __CUDACC__
/// The FIFO buffer that is allocated on the host via `cudaHostAlloc()`.
ProxyTrigger* triggers;
/// Replica of the FIFO tail that is allocated on device.
uint64_t* tailReplica;
/// The FIFO head. Allocated on the device and only accessed by the device.
uint64_t* head;
};
/// A class representing a host proxy FIFO that can consume work elements pushed by device threads.
class HostProxyFifo {
class Fifo {
public:
/// Constructs a new @ref HostProxyFifo object.
HostProxyFifo();
/// Constructs a new @ref Fifo object.
/// @param size The number of entries in the FIFO.
Fifo(int size = 128);
/// Destroys the @ref HostProxyFifo object.
~HostProxyFifo();
/// Destroys the @ref Fifo object.
~Fifo();
/// Polls the FIFO for a trigger.
///
/// @param trigger A pointer to the trigger to be filled.
void poll(ProxyTrigger* trigger);
/// Returns @ref ProxyTrigger which is the trigger at the head of fifo.
ProxyTrigger poll();
/// Pops a trigger from the FIFO.
void pop();
@@ -100,10 +35,14 @@ class HostProxyFifo {
/// @param sync If true, waits for the flush to complete before returning.
void flushTail(bool sync = false);
/// Returns a @ref DeviceProxyFifo object representing the device FIFO.
/// Return the FIFO size.
/// @return The FIFO size.
int size() const;
/// Returns a @ref FifoDeviceHandle object representing the device FIFO.
///
/// @return A @ref DeviceProxyFifo object representing the device FIFO.
DeviceProxyFifo deviceFifo();
/// @return A @ref FifoDeviceHandle object representing the device FIFO.
FifoDeviceHandle deviceHandle();
private:
struct Impl;

View File

@@ -0,0 +1,83 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#ifndef MSCCLPP_FIFO_DEVICE_HPP_
#define MSCCLPP_FIFO_DEVICE_HPP_
#include "poll.hpp"
namespace mscclpp {
/// A struct representing a pair of 64-bit unsigned integers used as a trigger for the proxy.
///
/// This struct is used as a work element in the concurrent FIFO where multiple device threads can push
/// ProxyTrigger elements and a single host proxy thread consumes these work elements.
///
/// Do not use the most significant bit of @ref snd as it is reserved for memory consistency purposes
/// (FifoDeviceHandle::push() flips that bit so a pushed trigger is never all-zero while in flight).
// alignas(16) keeps the pair 16-byte aligned so it can be written with a single
// 16-byte vectorized store (st.volatile.global.v2.u64) in FifoDeviceHandle::push().
struct alignas(16) ProxyTrigger {
uint64_t fst, snd;
};
/// A concurrent FIFO where multiple device threads can push work elements and a single host proxy thread consumes them.
///
/// The FIFO has a head pointer allocated on the device which starts at 0 and goes up to 2^64-1, which is almost
/// infinity. There are two copies of the tail, one on the device, @ref FifoDeviceHandle::tailReplica, and another on
/// the host, namely, hostTail. The host always has the "true" tail and occasionally pushes it to the copy on the
/// device. Therefore, most of the time, the device has a stale version. The invariants are: tailReplica <= hostTail <=
/// head. The @ref push() function increments head, hostTail is updated in @ref Fifo::pop(), and it occasionally flushes
/// it to tailReplica via @ref Fifo::flushTail().
///
/// Duplicating the tail is a good idea because the FIFO is large enough, and we do not need frequent updates for the
/// tail as there is usually enough space for device threads to push their work into.
///
struct FifoDeviceHandle {
#ifdef __CUDACC__
/// Push a trigger to the FIFO.
///
/// Atomically claims a slot via atomicAdd on the head, waits (if needed) for the slot
/// to become free, then writes the trigger with a single 16-byte volatile store.
///
/// @param trigger The trigger to push.
/// @return The new head of the FIFO.
__forceinline__ __device__ uint64_t push(ProxyTrigger trigger) {
uint64_t curFifoHead = atomicAdd((unsigned long long int*)this->head, 1);
// Set the most significant bit so the stored trigger is never all-zero and can be safely
// polled against 0; the host side clears this bit back before interpreting the trigger.
trigger.snd ^= ((uint64_t)1 << (uint64_t)63);
// Only one of two conditions need to be met to proceed. Either the tail has advanced enough or where we need to
// write to is 0. However, the first condition is faster to check since the tail is flushed periodically anyways but
// for the second condition we need to read CPU memory.
// As volatile access is slow, we first check using the bare pointer and then use the volatile pointer if the
// condition is not met.
if (curFifoHead >= size + *(this->tailReplica)) {
OR_POLL_MAYBE_JAILBREAK(curFifoHead >= size + *((volatile uint64_t*)this->tailReplica),
*(volatile uint64_t*)&this->triggers[curFifoHead % size] != 0, 1000000);
}
ProxyTrigger* triggerPtr = (ProxyTrigger*)&(this->triggers[curFifoHead % size]);
asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr), "l"(trigger.fst), "l"(trigger.snd));
return curFifoHead;
}
/// Wait until there is a place in the FIFO to push a trigger.
///
/// @param curFifoHead The current head of the FIFO.
__forceinline__ __device__ void sync(uint64_t curFifoHead) {
// Same as push but in this case checking the first condition is probably faster since for tail to be pushed we need
// to wait for cudaMemcpy to be done.
OR_POLL_MAYBE_JAILBREAK(*(volatile uint64_t*)&(this->triggers[curFifoHead % size]) != 0,
*(volatile uint64_t*)(this->tailReplica) <= curFifoHead, 1000000);
}
#endif // __CUDACC__
/// The FIFO buffer that is allocated on the host via `cudaHostAlloc()`.
ProxyTrigger* triggers;
/// Replica of the FIFO tail that is allocated on device.
uint64_t* tailReplica;
/// The FIFO head. Allocated on the device and only accessed by the device.
uint64_t* head;
/// The FIFO size (number of entries; used as the modulus for slot indexing).
int size;
};
} // namespace mscclpp
#endif // MSCCLPP_FIFO_DEVICE_HPP_

View File

@@ -4,6 +4,8 @@
#ifndef MSCCLPP_PACKET_HPP_
#define MSCCLPP_PACKET_HPP_
#include "poll.hpp"
namespace mscclpp {
/// LL (low latency) protocol packet.
@@ -42,17 +44,24 @@ union LLPacket {
"r"((uint32_t)(val >> 32)), "r"(flag));
}
/// Helper of @ref read().
/// @param flag The flag to read.
/// @param data The 8-byte data read.
/// @return True if the flag is not equal to the given flag.
__forceinline__ __device__ bool readOnce(uint32_t flag, uint2& data) {
uint32_t flag1, flag2;
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
: "=r"(data.x), "=r"(flag1), "=r"(data.y), "=r"(flag2)
: "l"(v));
return (flag1 != flag) || (flag2 != flag);
}
/// Read 8 bytes of data from the packet.
/// @param flag The flag to read.
/// @return The 8-byte data read.
__forceinline__ __device__ uint2 read(uint32_t flag) {
uint2 data;
uint32_t flag1, flag2;
do {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
: "=r"(data.x), "=r"(flag1), "=r"(data.y), "=r"(flag2)
: "l"(v));
} while ((flag1 != flag) || (flag2 != flag));
POLL_MAYBE_JAILBREAK(readOnce(flag, data), 100000000);
return data;
}
@@ -80,6 +89,7 @@ __forceinline__ __device__ void putPackets(void* dst, uint64_t dstOffset, void*
__forceinline__ __device__ void getPackets(void* dst, uint64_t dstOffset, void* src, uint64_t srcOffset,
uint64_t dstBytes, uint32_t threadId, uint32_t numThreads, uint32_t flag) {
// Offsets should be aligned to 8 bytes & size should be a multiple of 8 bytes
// TODO(saemal): this is not matching sm_channel get method.
LLPacket* srcBase = (LLPacket*)((char*)src + srcOffset);
uint2* dstBase = (uint2*)((char*)dst + dstOffset);
size_t nElem = dstBytes / sizeof(uint2);

View File

@@ -6,14 +6,8 @@
#ifdef __CUDACC__
#ifndef NDEBUG
// TODO(chhwang): https://github.com/microsoft/mscclpp/issues/99
#define POLL_PRINT_ON_STUCK(__cond)
// #include <stdio.h>
// #define POLL_PRINT_ON_STUCK(__cond) do { printf("mscclpp: spin is stuck. condition: " #__cond "\n"); } while (0);
#else // NDEBUG
#define POLL_PRINT_ON_STUCK(__cond)
#endif // NDEBUG
extern __device__ void __assert_fail(const char *__assertion, const char *__file, unsigned int __line,
const char *__function) __THROW;
// If a spin is stuck, escape from it and set status to 1.
#define POLL_MAYBE_JAILBREAK_ESCAPE(__cond, __max_spin_cnt, __status) \
@@ -22,7 +16,6 @@
__status = 0; \
while (__cond) { \
if (__spin_cnt++ == __max_spin_cnt) { \
POLL_PRINT_ON_STUCK(__cond); \
__status = 1; \
break; \
} \
@@ -30,31 +23,31 @@
} while (0);
// If a spin is stuck, print a warning and keep spinning.
#define POLL_MAYBE_JAILBREAK(__cond, __max_spin_cnt) \
do { \
uint64_t __spin_cnt = 0; \
while (__cond) { \
if (__spin_cnt++ == __max_spin_cnt) { \
POLL_PRINT_ON_STUCK(__cond); \
} \
} \
#define POLL_MAYBE_JAILBREAK(__cond, __max_spin_cnt) \
do { \
uint64_t __spin_cnt = 0; \
while (__cond) { \
if (__spin_cnt++ == __max_spin_cnt) { \
__assert_fail(#__cond, __FILE__, __LINE__, __PRETTY_FUNCTION__); \
} \
} \
} while (0);
// Same as POLL_MAYBE_JAILBREAK except that __cond1 is checked before __cond2
// this is specially useful when __cond1 is faster to check
#define OR_POLL_MAYBE_JAILBREAK(__cond1, __cond2, __max_spin_cnt) \
do { \
uint64_t __spin_cnt = 0; \
while (true) { \
if (!(__cond1)) { \
break; \
} else if (!(__cond2)) { \
break; \
} \
if (__spin_cnt++ == __max_spin_cnt) { \
POLL_PRINT_ON_STUCK(__cond); \
} \
} \
#define OR_POLL_MAYBE_JAILBREAK(__cond1, __cond2, __max_spin_cnt) \
do { \
uint64_t __spin_cnt = 0; \
while (true) { \
if (!(__cond1)) { \
break; \
} else if (!(__cond2)) { \
break; \
} \
if (__spin_cnt++ == __max_spin_cnt) { \
__assert_fail(#__cond1 #__cond2, __FILE__, __LINE__, __PRETTY_FUNCTION__); \
} \
} \
} while (0);
#endif // __CUDACC__

View File

@@ -28,7 +28,7 @@ class Proxy {
void start();
void stop();
HostProxyFifo& fifo();
Fifo& fifo();
private:
struct Impl;

View File

@@ -5,18 +5,12 @@
#define MSCCLPP_PROXY_CHANNEL_HPP_
#include <mscclpp/core.hpp>
#include <mscclpp/fifo.hpp>
#include <mscclpp/proxy.hpp>
#include <mscclpp/proxy_channel_device.hpp>
#include <mscclpp/semaphore.hpp>
namespace mscclpp {
using SemaphoreId = uint32_t;
/// Numeric ID of @ref RegisteredMemory. @ref ProxyService has an internal array indexed by these handles mapping to the
/// actual.
using MemoryId = uint32_t;
struct ProxyChannel;
/// Base class for proxy services. Proxy services are used to proxy data between devices.
@@ -32,13 +26,17 @@ class BaseProxyService {
class ProxyService : public BaseProxyService {
public:
/// Constructor.
/// @param communicator The communicator to use.
ProxyService(Communicator& communicator);
ProxyService();
/// Add a semaphore to the proxy service.
/// Build and add a semaphore to the proxy service.
/// @param connection The connection associated with the semaphore.
/// @return The ID of the semaphore.
SemaphoreId addSemaphore(std::shared_ptr<Connection> connection);
SemaphoreId buildAndAddSemaphore(Communicator& communicator, std::shared_ptr<Connection> connection);
/// Add a semaphore to the proxy service.
/// @param semaphore The semaphore to be added
/// @return The ID of the semaphore.
SemaphoreId addSemaphore(std::shared_ptr<Host2DeviceSemaphore> semaphore);
/// Register a memory region with the proxy service.
/// @param memory The memory region to register.
@@ -53,7 +51,7 @@ class ProxyService : public BaseProxyService {
/// Get a proxy channel by semaphore ID.
/// @param id The ID of the semaphore.
/// @return The proxy channel.
ProxyChannel deviceChannel(SemaphoreId id);
ProxyChannel proxyChannel(SemaphoreId id);
/// Start the proxy service.
void startProxy();
@@ -62,7 +60,6 @@ class ProxyService : public BaseProxyService {
void stopProxy();
private:
Communicator& communicator_;
std::vector<std::shared_ptr<Host2DeviceSemaphore>> semaphores_;
std::vector<RegisteredMemory> memories_;
Proxy proxy_;
@@ -73,170 +70,44 @@ class ProxyService : public BaseProxyService {
ProxyHandlerResult handleTrigger(ProxyTrigger triggerRaw);
};
using TriggerType = uint64_t;
const TriggerType TriggerData = 0x1; // Trigger a data transfer.
const TriggerType TriggerFlag = 0x2; // Trigger a signaling.
const TriggerType TriggerSync = 0x4; // Trigger a flush.
#define MSCCLPP_BITS_SIZE 32
#define MSCCLPP_BITS_OFFSET 32
#define MSCCLPP_BITS_REGMEM_HANDLE 8
#define MSCCLPP_BITS_TYPE 3
#define MSCCLPP_BITS_CONNID 10
/// Basic structure of each work element in the FIFO.
union ChannelTrigger {
ProxyTrigger value;
// The summation of number of bits must be 128 or less.
struct {
// First 64 bits: value[0]
uint64_t size : MSCCLPP_BITS_SIZE;
uint64_t srcOffset : MSCCLPP_BITS_OFFSET;
uint64_t : (64 - MSCCLPP_BITS_SIZE - MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment
// Second 64 bits: value[1]
uint64_t dstOffset : MSCCLPP_BITS_OFFSET;
uint64_t srcMemoryId : MSCCLPP_BITS_REGMEM_HANDLE;
uint64_t dstMemoryId : MSCCLPP_BITS_REGMEM_HANDLE;
uint64_t type : MSCCLPP_BITS_TYPE;
uint64_t chanId : MSCCLPP_BITS_CONNID;
uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_REGMEM_HANDLE - MSCCLPP_BITS_REGMEM_HANDLE -
MSCCLPP_BITS_TYPE); // ensure 64-bit alignment
} fields;
#ifdef __CUDACC__
/// Default constructor.
__device__ ChannelTrigger() {}
/// Copy constructor.
__device__ ChannelTrigger(ProxyTrigger value) : value(value) {}
/// Constructor.
/// @param type The type of the trigger.
/// @param dst The destination memory region.
/// @param dstOffset The offset into the destination memory region.
/// @param src The source memory region.
/// @param srcOffset The offset into the source memory region.
/// @param bytes The bytes of the transfer.
/// @param semaphoreId The ID of the semaphore.
__device__ ChannelTrigger(TriggerType type, MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
uint64_t bytes, int semaphoreId) {
value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + bytes);
value.snd = ((((((((semaphoreId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_REGMEM_HANDLE) + dst)
<< MSCCLPP_BITS_REGMEM_HANDLE) +
src)
<< MSCCLPP_BITS_OFFSET) +
dstOffset);
}
#endif // __CUDACC__
};
/// Proxy channel.
struct ProxyChannel {
// Use DeviceHandle<ProxyChannel> in device code.
typedef ProxyChannel DeviceHandle;
ProxyChannel() = default;
ProxyChannel(SemaphoreId semaphoreId, Host2DeviceSemaphore::DeviceHandle semaphore, DeviceProxyFifo fifo);
ProxyChannel(const ProxyChannel& other) = default;
ProxyChannel& operator=(ProxyChannel& other) = default;
#ifdef __CUDACC__
/// Push a @ref TriggerData to the FIFO.
/// @param dst The destination memory region.
/// @param dstOffset The offset into the destination memory region.
/// @param src The source memory region.
/// @param srcOffset The offset into the source memory region.
/// @param size The size of the transfer.
__forceinline__ __device__ void put(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
uint64_t size) {
fifo_.push(ChannelTrigger(TriggerData, dst, dstOffset, src, srcOffset, size, semaphoreId_).value);
}
/// Push a @ref TriggerData to the FIFO.
/// @param dst The destination memory region.
/// @param src The source memory region.
/// @param offset The common offset into the destination and source memory regions.
/// @param size The size of the transfer.
__forceinline__ __device__ void put(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) {
put(dst, offset, src, offset, size);
}
/// Push a @ref TriggerFlag to the FIFO.
__forceinline__ __device__ void signal() {
fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, semaphoreId_).value);
}
/// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO.
/// @param dst The destination memory region.
/// @param dstOffset The offset into the destination memory region.
/// @param src The source memory region.
/// @param srcOffset The offset into the source memory region.
/// @param size The size of the transfer.
__forceinline__ __device__ void putWithSignal(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
uint64_t size) {
fifo_.push(ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, size, semaphoreId_).value);
}
/// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO.
/// @param dst The destination memory region.
/// @param src The source memory region.
/// @param offset The common offset into the destination and source memory regions.
/// @param size The size of the transfer.
__forceinline__ __device__ void putWithSignal(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) {
putWithSignal(dst, offset, src, offset, size);
}
/// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO.
/// @param dst The destination memory region.
/// @param dstOffset The offset into the destination memory region.
/// @param src The source memory region.
/// @param srcOffset The offset into the source memory region.
/// @param size The size of the transfer.
__forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, uint64_t dstOffset, MemoryId src,
uint64_t srcOffset, uint64_t size) {
uint64_t curFifoHead = fifo_.push(
ChannelTrigger(TriggerData | TriggerFlag | TriggerSync, dst, dstOffset, src, srcOffset, size, semaphoreId_)
.value);
fifo_.sync(curFifoHead);
}
/// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO.
/// @param dst The destination memory region.
/// @param src The source memory region.
/// @param offset The common offset into the destination and source memory regions.
/// @param size The size of the transfer.
__forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) {
putWithSignalAndFlush(dst, offset, src, offset, size);
}
/// Push a @ref TriggerSync to the FIFO.
__forceinline__ __device__ void flush() {
uint64_t curFifoHead = fifo_.push(ChannelTrigger(TriggerSync, 0, 0, 0, 0, 1, semaphoreId_).value);
fifo_.sync(curFifoHead);
}
/// Wait for the proxy channel to be signaled.
__forceinline__ __device__ void wait() { semaphore_.wait(); }
#endif // __CUDACC__
private:
SemaphoreId semaphoreId_;
Host2DeviceSemaphore::DeviceHandle semaphore_;
// this is a concurrent fifo which is multiple threads from the device
// can produce for and the sole proxy thread consumes it.
DeviceProxyFifo fifo_;
FifoDeviceHandle fifo_;
public:
ProxyChannel() = default;
ProxyChannel(SemaphoreId semaphoreId, Host2DeviceSemaphore::DeviceHandle semaphore, FifoDeviceHandle fifo);
ProxyChannel(const ProxyChannel& other) = default;
ProxyChannel& operator=(ProxyChannel& other) = default;
/// Device-side handle for @ref ProxyChannel.
using DeviceHandle = ProxyChannelDeviceHandle;
/// Returns the device-side handle.
///
/// User should make sure the ProxyChannel is not released when using the returned handle.
///
DeviceHandle deviceHandle() const;
};
/// Simple proxy channel with a single destination and source memory region.
struct SimpleProxyChannel {
// Use DeviceHandle<SimpleProxyChannel> in device code.
typedef SimpleProxyChannel DeviceHandle;
private:
ProxyChannel proxyChan_;
MemoryId dst_;
MemoryId src_;
public:
/// Default constructor.
SimpleProxyChannel() = default;
@@ -256,69 +127,16 @@ struct SimpleProxyChannel {
/// Assignment operator.
SimpleProxyChannel& operator=(SimpleProxyChannel& other) = default;
#ifdef __CUDACC__
/// Push a @ref TriggerData to the FIFO.
/// @param dstOffset The offset into the destination memory region.
/// @param srcOffset The offset into the source memory region.
/// @param size The size of the transfer.
__forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
proxyChan_.put(dst_, dstOffset, src_, srcOffset, size);
}
/// Device-side handle for @ref SimpleProxyChannel.
using DeviceHandle = SimpleProxyChannelDeviceHandle;
/// Push a @ref TriggerData to the FIFO.
/// @param offset The common offset into the destination and source memory regions.
/// @param size The size of the transfer.
__forceinline__ __device__ void put(uint64_t offset, uint64_t size) { put(offset, offset, size); }
/// Push a @ref TriggerFlag to the FIFO.
__forceinline__ __device__ void signal() { proxyChan_.signal(); }
/// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO.
/// @param dstOffset The offset into the destination memory region.
/// @param srcOffset The offset into the source memory region.
/// @param size The size of the transfer.
__forceinline__ __device__ void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
proxyChan_.putWithSignal(dst_, dstOffset, src_, srcOffset, size);
}
/// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO.
/// @param offset The common offset into the destination and source memory regions.
/// @param size The size of the transfer.
__forceinline__ __device__ void putWithSignal(uint64_t offset, uint64_t size) { putWithSignal(offset, offset, size); }
/// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO.
/// @param dstOffset The offset into the destination memory region.
/// @param srcOffset The offset into the source memory region.
/// @param size The size of the transfer.
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
proxyChan_.putWithSignalAndFlush(dst_, dstOffset, src_, srcOffset, size);
}
/// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO.
/// @param offset The common offset into the destination and source memory regions.
/// @param size The size of the transfer.
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t offset, uint64_t size) {
putWithSignalAndFlush(offset, offset, size);
}
/// Push a @ref TriggerSync to the FIFO.
__forceinline__ __device__ void flush() { proxyChan_.flush(); }
/// Wait for the proxy channel to be signaled.
__forceinline__ __device__ void wait() { proxyChan_.wait(); }
#endif // __CUDACC__
ProxyChannel proxyChan_;
MemoryId dst_;
MemoryId src_;
/// Returns the device-side handle.
///
/// User should make sure the SimpleProxyChannel is not released when using the returned handle.
///
DeviceHandle deviceHandle() const;
};
template <>
DeviceHandle<ProxyChannel> deviceHandle(ProxyChannel&& proxyChannel);
template <>
DeviceHandle<SimpleProxyChannel> deviceHandle(SimpleProxyChannel&& simpleProxyChannel);
} // namespace mscclpp
#endif // MSCCLPP_PROXY_CHANNEL_HPP_

View File

@@ -0,0 +1,227 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#ifndef MSCCLPP_PROXY_CHANNEL_DEVICE_HPP_
#define MSCCLPP_PROXY_CHANNEL_DEVICE_HPP_
#include "fifo_device.hpp"
#include "semaphore_device.hpp"
namespace mscclpp {
using SemaphoreId = uint32_t;
/// Numeric ID of @ref RegisteredMemory. @ref ProxyService has an internal array indexed by these IDs mapping to the
/// actual memory regions.
using MemoryId = uint32_t;
using TriggerType = uint64_t;
const TriggerType TriggerData = 0x1; // Trigger a data transfer.
const TriggerType TriggerFlag = 0x2; // Trigger a signaling.
const TriggerType TriggerSync = 0x4; // Trigger a flush.
#define MSCCLPP_BITS_SIZE 32
#define MSCCLPP_BITS_OFFSET 32
#define MSCCLPP_BITS_REGMEM_HANDLE 8
#define MSCCLPP_BITS_TYPE 3
#define MSCCLPP_BITS_CONNID 10
#define MSCCLPP_BITS_FIFO_RESERVED 1
/// Basic structure of each work element in the FIFO.
///
/// Overlays a raw 128-bit @ref ProxyTrigger (`value`) with a bit-field view (`fields`)
/// so the device can pack a request and the host-side proxy can unpack it.
union ChannelTrigger {
  ProxyTrigger value;
  // The summation of number of bits must be 128 or less.
  struct {
    // First 64 bits: value[0]
    uint64_t size : MSCCLPP_BITS_SIZE;
    uint64_t srcOffset : MSCCLPP_BITS_OFFSET;
    uint64_t : (64 - MSCCLPP_BITS_SIZE - MSCCLPP_BITS_OFFSET);  // ensure 64-bit alignment
    // Second 64 bits: value[1]
    uint64_t dstOffset : MSCCLPP_BITS_OFFSET;
    uint64_t srcMemoryId : MSCCLPP_BITS_REGMEM_HANDLE;
    uint64_t dstMemoryId : MSCCLPP_BITS_REGMEM_HANDLE;
    uint64_t type : MSCCLPP_BITS_TYPE;
    uint64_t chanId : MSCCLPP_BITS_CONNID;
    uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_REGMEM_HANDLE - MSCCLPP_BITS_REGMEM_HANDLE - MSCCLPP_BITS_TYPE -
                MSCCLPP_BITS_CONNID - MSCCLPP_BITS_FIFO_RESERVED);  // ensure 64-bit alignment
    uint64_t reserved : MSCCLPP_BITS_FIFO_RESERVED;
  } fields;
#ifdef __CUDACC__
  /// Default constructor.
  __forceinline__ __device__ ChannelTrigger() {}
  /// Copy constructor.
  __forceinline__ __device__ ChannelTrigger(ProxyTrigger value) : value(value) {}
  /// Constructor.
  ///
  /// Packs the arguments into `value` with explicit shifts (instead of writing `fields`)
  /// in a layout matching the bit-field view above.
  /// @param type The type of the trigger.
  /// @param dst The destination memory region.
  /// @param dstOffset The offset into the destination memory region.
  /// @param src The source memory region.
  /// @param srcOffset The offset into the source memory region.
  /// @param bytes The bytes of the transfer.
  /// @param semaphoreId The ID of the semaphore.
  __forceinline__ __device__ ChannelTrigger(TriggerType type, MemoryId dst, uint64_t dstOffset, MemoryId src,
                                            uint64_t srcOffset, uint64_t bytes, int semaphoreId) {
    // value.fst: [srcOffset | size]; value.snd (low to high): dstOffset | src | dst | type | semaphoreId.
    value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + bytes);
    value.snd = ((((((((semaphoreId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_REGMEM_HANDLE) + dst)
                    << MSCCLPP_BITS_REGMEM_HANDLE) +
                   src)
                  << MSCCLPP_BITS_OFFSET) +
                 dstOffset);
  }
#endif  // __CUDACC__
};
/// Device-side handle of a proxy channel. Work requests are encoded as
/// @ref ChannelTrigger elements and pushed into `fifo_`; completion signals
/// arrive through `semaphore_`.
struct ProxyChannelDeviceHandle {
  // Identifies the semaphore (and thus the connection) this channel signals on.
  SemaphoreId semaphoreId_;
  // Semaphore used by wait() to receive signals from the host side.
  Host2DeviceSemaphoreDeviceHandle semaphore_;
  // A concurrent FIFO: multiple threads on the device can produce work elements
  // into it, while the sole proxy thread consumes them.
  FifoDeviceHandle fifo_;
#ifdef __CUDACC__
  /// Push a @ref TriggerData to the FIFO.
  /// @param dst The destination memory region.
  /// @param dstOffset The offset into the destination memory region.
  /// @param src The source memory region.
  /// @param srcOffset The offset into the source memory region.
  /// @param size The size of the transfer.
  __forceinline__ __device__ void put(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
                                      uint64_t size) {
    fifo_.push(ChannelTrigger(TriggerData, dst, dstOffset, src, srcOffset, size, semaphoreId_).value);
  }
  /// Push a @ref TriggerData to the FIFO.
  /// @param dst The destination memory region.
  /// @param src The source memory region.
  /// @param offset The common offset into the destination and source memory regions.
  /// @param size The size of the transfer.
  __forceinline__ __device__ void put(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) {
    put(dst, offset, src, offset, size);
  }
  /// Push a @ref TriggerFlag to the FIFO.
  /// The size argument of 1 is a placeholder: a flag-only trigger transfers no data.
  __forceinline__ __device__ void signal() {
    fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, semaphoreId_).value);
  }
  /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO.
  /// @param dst The destination memory region.
  /// @param dstOffset The offset into the destination memory region.
  /// @param src The source memory region.
  /// @param srcOffset The offset into the source memory region.
  /// @param size The size of the transfer.
  __forceinline__ __device__ void putWithSignal(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
                                                uint64_t size) {
    fifo_.push(ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, size, semaphoreId_).value);
  }
  /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO.
  /// @param dst The destination memory region.
  /// @param src The source memory region.
  /// @param offset The common offset into the destination and source memory regions.
  /// @param size The size of the transfer.
  __forceinline__ __device__ void putWithSignal(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) {
    putWithSignal(dst, offset, src, offset, size);
  }
  /// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO.
  /// @param dst The destination memory region.
  /// @param dstOffset The offset into the destination memory region.
  /// @param src The source memory region.
  /// @param srcOffset The offset into the source memory region.
  /// @param size The size of the transfer.
  __forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, uint64_t dstOffset, MemoryId src,
                                                        uint64_t srcOffset, uint64_t size) {
    uint64_t curFifoHead = fifo_.push(
        ChannelTrigger(TriggerData | TriggerFlag | TriggerSync, dst, dstOffset, src, srcOffset, size, semaphoreId_)
            .value);
    // Block until the proxy has consumed the trigger at this FIFO position (flush semantics).
    fifo_.sync(curFifoHead);
  }
  /// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO.
  /// @param dst The destination memory region.
  /// @param src The source memory region.
  /// @param offset The common offset into the destination and source memory regions.
  /// @param size The size of the transfer.
  __forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) {
    putWithSignalAndFlush(dst, offset, src, offset, size);
  }
  /// Push a @ref TriggerSync to the FIFO and wait until the proxy has consumed it.
  __forceinline__ __device__ void flush() {
    uint64_t curFifoHead = fifo_.push(ChannelTrigger(TriggerSync, 0, 0, 0, 0, 1, semaphoreId_).value);
    fifo_.sync(curFifoHead);
  }
  /// Wait for the proxy channel to be signaled.
  __forceinline__ __device__ void wait() { semaphore_.wait(); }
#endif  // __CUDACC__
};
/// Device-side handle of a @ref SimpleProxyChannel.
///
/// Bundles a @ref ProxyChannelDeviceHandle with a fixed pair of memory IDs
/// (`dst_`, `src_`), so device code only needs to supply byte offsets and sizes.
struct SimpleProxyChannelDeviceHandle {
  ProxyChannelDeviceHandle proxyChan_;
  MemoryId dst_;
  MemoryId src_;
#ifdef __CUDACC__
  /// Push a @ref TriggerData to the FIFO.
  /// @param dstOffset The offset into the destination memory region.
  /// @param srcOffset The offset into the source memory region.
  /// @param size The size of the transfer.
  __forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
    proxyChan_.put(dst_, dstOffset, src_, srcOffset, size);
  }

  /// Push a @ref TriggerData to the FIFO, with one offset shared by source and destination.
  /// @param offset The common offset into the destination and source memory regions.
  /// @param size The size of the transfer.
  __forceinline__ __device__ void put(uint64_t offset, uint64_t size) {
    put(offset, offset, size);
  }

  /// Push a @ref TriggerFlag to the FIFO.
  __forceinline__ __device__ void signal() {
    proxyChan_.signal();
  }

  /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO.
  /// @param dstOffset The offset into the destination memory region.
  /// @param srcOffset The offset into the source memory region.
  /// @param size The size of the transfer.
  __forceinline__ __device__ void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
    proxyChan_.putWithSignal(dst_, dstOffset, src_, srcOffset, size);
  }

  /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO,
  /// with one offset shared by source and destination.
  /// @param offset The common offset into the destination and source memory regions.
  /// @param size The size of the transfer.
  __forceinline__ __device__ void putWithSignal(uint64_t offset, uint64_t size) {
    putWithSignal(offset, offset, size);
  }

  /// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO.
  /// @param dstOffset The offset into the destination memory region.
  /// @param srcOffset The offset into the source memory region.
  /// @param size The size of the transfer.
  __forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
    proxyChan_.putWithSignalAndFlush(dst_, dstOffset, src_, srcOffset, size);
  }

  /// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO,
  /// with one offset shared by source and destination.
  /// @param offset The common offset into the destination and source memory regions.
  /// @param size The size of the transfer.
  __forceinline__ __device__ void putWithSignalAndFlush(uint64_t offset, uint64_t size) {
    putWithSignalAndFlush(offset, offset, size);
  }

  /// Push a @ref TriggerSync to the FIFO.
  __forceinline__ __device__ void flush() {
    proxyChan_.flush();
  }

  /// Wait for the proxy channel to be signaled.
  __forceinline__ __device__ void wait() {
    proxyChan_.wait();
  }
#endif  // __CUDACC__
};
} // namespace mscclpp
#endif // MSCCLPP_PROXY_CHANNEL_DEVICE_HPP_

View File

@@ -8,6 +8,7 @@
#include <mscclpp/core.hpp>
#include <mscclpp/cuda_utils.hpp>
#include <mscclpp/poll.hpp>
#include <mscclpp/semaphore_device.hpp>
namespace mscclpp {
@@ -81,18 +82,7 @@ class Host2DeviceSemaphore : public BaseSemaphore<CudaDeleter, std::default_dele
void signal();
  /// Device-side handle for @ref Host2DeviceSemaphore.
  struct DeviceHandle {
#ifdef __CUDACC__
    /// Wait for the host to signal.
    __forceinline__ __device__ void wait() {
      // Bump the locally expected count, then spin until the inbound counter
      // catches up (spin bounded by the poll macro's iteration limit).
      (*expectedInboundSemaphoreId) += 1;
      POLL_MAYBE_JAILBREAK(*(volatile uint64_t*)(inboundSemaphoreId) < (*expectedInboundSemaphoreId), 1000000);
    }
#endif  // __CUDACC__
    // Inbound counter updated by the host side (see @ref signal()).
    uint64_t* inboundSemaphoreId;
    // Device-local count of signals consumed so far.
    uint64_t* expectedInboundSemaphoreId;
  };
using DeviceHandle = Host2DeviceSemaphoreDeviceHandle;
/// Returns the device-side handle.
DeviceHandle deviceHandle();
@@ -133,50 +123,7 @@ class SmDevice2DeviceSemaphore : public BaseSemaphore<CudaDeleter, CudaDeleter>
SmDevice2DeviceSemaphore() = default;
  /// Device-side handle for @ref SmDevice2DeviceSemaphore.
  struct DeviceHandle {
#ifdef __CUDACC__
    /// Wait for the remote device to signal.
    __forceinline__ __device__ void wait() {
      // Bump the locally expected count, then spin until the inbound counter
      // catches up (spin bounded by the poll macro's iteration limit).
      (*expectedInboundSemaphoreId) += 1;
      POLL_MAYBE_JAILBREAK(*inboundSemaphoreId < (*expectedInboundSemaphoreId), 1000000);
    }
    /// Signal the remote device.
    ///
    /// This function guarantees that all the memory operation before this function is completed before the remote
    /// semaphore is signaled.
    ///
    __forceinline__ __device__ void signal() {
      // This fence ensures that preceding writes are visible on the peer GPU before the incremented
      // `outboundSemaphoreId` is visible.
      __threadfence_system();
      semaphoreIncrement();
      *remoteInboundSemaphoreId = semaphoreGetLocal();
    }
    /// Signal the remote device for copied packets.
    ///
    /// Unlike @ref signal(), this function provides no guarantee on the completion of memory operations. This is
    /// intended to be used with @ref putPackets() and @ref getPackets() that use flags inside packets to indicate the
    /// completion of copies.
    ///
    __forceinline__ __device__ void signalPacket() {
      semaphoreIncrement();
      *remoteInboundSemaphoreId = semaphoreGetLocal();
    }
    /// Increase the counter of the local semaphore.
    __forceinline__ __device__ void semaphoreIncrement() { *outboundSemaphoreId += 1; }
    /// Get the value of the local semaphore.
    __forceinline__ __device__ uint64_t semaphoreGetLocal() const { return *outboundSemaphoreId; }
#endif  // __CUDACC__
    // Inbound counter written by the peer device.
    volatile uint64_t* inboundSemaphoreId;
    // Local count of signals sent; mirrored into the peer's inbound counter.
    uint64_t* outboundSemaphoreId;
    // The peer's inbound counter, mapped into this device's address space.
    volatile uint64_t* remoteInboundSemaphoreId;
    // Local count of signals consumed by wait().
    uint64_t* expectedInboundSemaphoreId;
  };
using DeviceHandle = SmDevice2DeviceSemaphoreDeviceHandle;
/// Returns the device-side handle.
DeviceHandle deviceHandle() const;

View File

@@ -0,0 +1,73 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#ifndef MSCCLPP_SEMAPHORE_DEVICE_HPP_
#define MSCCLPP_SEMAPHORE_DEVICE_HPP_
#include "poll.hpp"
namespace mscclpp {
/// Device-side handle for @ref Host2DeviceSemaphore.
struct Host2DeviceSemaphoreDeviceHandle {
#ifdef __CUDACC__
  /// Wait for the host to signal.
  __forceinline__ __device__ void wait() {
    // Bump the locally expected count, then spin until the inbound counter
    // catches up (spin bounded by the poll macro's iteration limit).
    (*expectedInboundSemaphoreId) += 1;
    POLL_MAYBE_JAILBREAK(*(volatile uint64_t*)(inboundSemaphoreId) < (*expectedInboundSemaphoreId), 100000000);
  }
#endif  // __CUDACC__
  // Inbound counter updated by the host side.
  uint64_t* inboundSemaphoreId;
  // Device-local count of signals consumed so far.
  uint64_t* expectedInboundSemaphoreId;
};
/// Device-side handle for @ref SmDevice2DeviceSemaphore.
struct SmDevice2DeviceSemaphoreDeviceHandle {
#ifdef __CUDACC__
  /// Wait for the remote device to signal.
  __forceinline__ __device__ void wait() {
    // Bump the locally expected count, then spin until the inbound counter
    // catches up (spin bounded by the poll macro's iteration limit).
    (*expectedInboundSemaphoreId) += 1;
    POLL_MAYBE_JAILBREAK(*inboundSemaphoreId < (*expectedInboundSemaphoreId), 100000000);
  }
  /// Signal the remote device.
  ///
  /// This function guarantees that all the memory operation before this function is completed before the remote
  /// semaphore is signaled.
  ///
  __forceinline__ __device__ void signal() {
    // This fence ensures that preceding writes are visible on the peer GPU before the incremented
    // `outboundSemaphoreId` is visible.
    __threadfence_system();
    semaphoreIncrement();
    *remoteInboundSemaphoreId = semaphoreGetLocal();
  }
  /// Signal the remote device for copied packets.
  ///
  /// Unlike @ref signal(), this function provides no guarantee on the completion of memory operations. This is
  /// intended to be used with @ref putPackets() and @ref getPackets() that use flags inside packets to indicate the
  /// completion of copies.
  ///
  __forceinline__ __device__ void signalPacket() {
    semaphoreIncrement();
    *remoteInboundSemaphoreId = semaphoreGetLocal();
  }
  /// Increase the counter of the local semaphore.
  __forceinline__ __device__ void semaphoreIncrement() { *outboundSemaphoreId += 1; }
  /// Get the value of the local semaphore.
  __forceinline__ __device__ uint64_t semaphoreGetLocal() const { return *outboundSemaphoreId; }
#endif  // __CUDACC__
  // Inbound counter written by the peer device.
  volatile uint64_t* inboundSemaphoreId;
  // Local count of signals sent; mirrored into the peer's inbound counter.
  uint64_t* outboundSemaphoreId;
  // The peer's inbound counter, mapped into this device's address space.
  volatile uint64_t* remoteInboundSemaphoreId;
  // Local count of signals consumed by wait().
  uint64_t* expectedInboundSemaphoreId;
};
} // namespace mscclpp
#endif // MSCCLPP_SEMAPHORE_DEVICE_HPP_

View File

@@ -5,8 +5,8 @@
#define MSCCLPP_SM_CHANNEL_HPP_
#include <mscclpp/core.hpp>
#include <mscclpp/packet.hpp>
#include <mscclpp/semaphore.hpp>
#include <mscclpp/sm_channel_device.hpp>
#include <type_traits>
namespace mscclpp {
@@ -31,305 +31,8 @@ struct SmChannel {
SmChannel(std::shared_ptr<SmDevice2DeviceSemaphore> semaphore, RegisteredMemory dst, void* src,
void* getPacketBuffer = nullptr);
  /// Device-side handle for @ref SmChannel. Copied to the GPU and used by kernels to
  /// access the peer's memory directly from SMs.
  struct DeviceHandle {
    SmDevice2DeviceSemaphore::DeviceHandle semaphore_;
    void* src_;              // Local source buffer.
    void* dst_;              // Remote (peer) destination buffer.
    void* getPacketBuffer_;  // Local buffer that getPackets() reads packets from.

   private:
#ifdef __CUDACC__
    /// Helper for aligned data type access.
    /// @tparam T The data type.
    template <typename T>
    struct Element {
      static constexpr bool is4B = (sizeof(T) == 4);
      static constexpr bool is8B = (sizeof(T) == 8);
      static constexpr bool is4Bx2 =
          (std::is_same<T, int2>::value || std::is_same<T, uint2>::value || std::is_same<T, float2>::value);
      static constexpr bool is4Bx4 =
          (std::is_same<T, int4>::value || std::is_same<T, uint4>::value || std::is_same<T, float4>::value);
      static constexpr bool is8Bx2 =
          (std::is_same<T, longlong2>::value || std::is_same<T, ulonglong2>::value || std::is_same<T, double2>::value);
      // Note: we do not support long2 and ulong2 as their size may differ on different platforms.
      static constexpr bool isValid = (is4B || is8B || is4Bx2 || is4Bx4 || is8Bx2);

      /// Load an element from DRAM.
      ///
      /// This is a wrapper of ld.volatile.global.* PTX instruction. Address alignment is not this function's
      /// responsibility.
      ///
      /// @param v The value to be loaded.
      /// @param p The address of the value to be loaded.
      ///
      static __forceinline__ __device__ void load(T& v, const T* p) {
        if constexpr (is4B) {
          asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(v) : "l"(p) : "memory");
        } else if constexpr (is8B) {
          asm volatile("ld.volatile.global.u64 %0, [%1];" : "=l"(v) : "l"(p) : "memory");
        } else if constexpr (is4Bx2) {
          asm volatile("ld.volatile.global.v2.u32 {%0,%1}, [%2];" : "=r"(v.x), "=r"(v.y) : "l"(p) : "memory");
        } else if constexpr (is4Bx4) {
          // Note: registers are bound in (w, x, y, z) order; store() uses the same
          // permutation, so values round-trip consistently.
          asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
                       : "=r"(v.w), "=r"(v.x), "=r"(v.y), "=r"(v.z)
                       : "l"(p)
                       : "memory");
        } else if constexpr (is8Bx2) {
          asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" : "=l"(v.x), "=l"(v.y) : "l"(p) : "memory");
        }
        // Only fails when instantiated with an unsupported T.
        static_assert(isValid, "Unsupported type T");
      }

      /// Write an element on DRAM.
      ///
      /// This is a wrapper of st.volatile.global.* PTX instruction. Address alignment is not this function's
      /// responsibility.
      ///
      /// @param p The address of the value to be written.
      /// @param v The value to be written.
      ///
      static __forceinline__ __device__ void store(T* p, const T& v) {
        if constexpr (is4B) {
          asm volatile("st.volatile.global.u32 [%0], %1;" : : "l"(p), "r"(v) : "memory");
        } else if constexpr (is8B) {
          asm volatile("st.volatile.global.u64 [%0], %1;" : : "l"(p), "l"(v) : "memory");
        } else if constexpr (is4Bx2) {
          asm volatile("st.volatile.global.v2.u32 [%0], {%1,%2};" : : "l"(p), "r"(v.x), "r"(v.y) : "memory");
        } else if constexpr (is4Bx4) {
          asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
                       :
                       : "l"(p), "r"(v.w), "r"(v.x), "r"(v.y), "r"(v.z)
                       : "memory");
        } else if constexpr (is8Bx2) {
          asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" : : "l"(p), "l"(v.x), "l"(v.y) : "memory");
        }
        // Only fails when instantiated with an unsupported T.
        static_assert(isValid, "Unsupported type T");
      }

      /// Copy aligned elements from the source memory to the destination memory.
      ///
      /// This function is intended to be collectively called by multiple threads. Each thread copies a part of
      /// elements.
      ///
      /// @param dst The destination address.
      /// @param src The source address.
      /// @param numElems The number of elements to be copied.
      /// @param threadId The index of the current thread among all threads running this function. This is different
      /// from the `threadIdx` in CUDA.
      /// @param numThreads The total number of threads that run this function.
      ///
      static __forceinline__ __device__ void copy(T* dst, T* src, uint64_t numElems, uint32_t threadId,
                                                  uint32_t numThreads) {
        T reg;
        for (size_t i = threadId; i < numElems; i += numThreads) {
          // Load to register first.
          load(reg, src + i);
          store(dst + i, reg);
        }
      }
    };
#endif  // __CUDACC__

   public:
#ifdef __CUDACC__
    /// Load a value from the remote memory.
    /// @tparam T The type of the value to be loaded.
    /// @param index The index of the value to be loaded. The offset in bytes is calculated as index * sizeof(T).
    /// @return The value loaded.
    template <typename T>
    __forceinline__ __device__ T read(uint64_t index) {
      T v;
      Element<T>::load(v, (T*)dst_ + index);
      return v;
    }

    /// Write a value to the remote memory.
    /// @tparam T The type of the value to be written.
    /// @param index The index of the value to be written. The offset in bytes is calculated as index * sizeof(T).
    /// @param v The value to be written.
    template <typename T>
    __forceinline__ __device__ void write(uint64_t index, const T& v) {
      Element<T>::store((T*)dst_ + index, v);
    }

    /// Copy aligned data from the source memory to the destination memory.
    ///
    /// This function is a wrapper of Element<T>::copy(). Unlike Element<T>::copy(), this function can copy remainder
    /// bytes when @p CopyRemainder is true. Still, the copying bytes must be a multiple of 4.
    ///
    /// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16.
    /// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
    /// Alignment.
    /// @param dst The destination address. Should be aligned to @p Alignment in the same way as @p src.
    /// @param src The source address. Should be aligned to @p Alignment in the same way as @p dst.
    /// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment.
    /// @param threadId The index of the current thread among all threads running this function. This is different from
    /// the `threadIdx` in CUDA.
    /// @param numThreads The total number of threads that run this function.
    ///
    template <int Alignment = 4, bool CopyRemainder = true>
    __forceinline__ __device__ void copy(void* dst, void* src, uint64_t bytes, uint32_t threadId, uint32_t numThreads) {
      static_assert(Alignment == 4 || Alignment == 8 || Alignment % 16 == 0, "Unsupported alignment");
      using Type =
          typename std::conditional<Alignment == 4, int,
                                    typename std::conditional<Alignment == 8, long long, longlong2>::type>::type;
      int* dstInt = reinterpret_cast<int*>(dst);
      int* srcInt = reinterpret_cast<int*>(src);
      const uintptr_t dstPtr = reinterpret_cast<uintptr_t>(dst);
      const uintptr_t srcPtr = reinterpret_cast<uintptr_t>(src);
      const uint64_t numInt = bytes / sizeof(int);
      // Round the pointers up to the next Type-aligned address.
      Type* dstElem = reinterpret_cast<Type*>((dstPtr + sizeof(Type) - 1) / sizeof(Type) * sizeof(Type));
      Type* srcElem = reinterpret_cast<Type*>((srcPtr + sizeof(Type) - 1) / sizeof(Type) * sizeof(Type));
      uint64_t nFirstInt = (reinterpret_cast<uintptr_t>(dstElem) - dstPtr) / sizeof(int);
      if (CopyRemainder) {
        // Copy the remainder integers at the beginning.
        Element<int>::copy(dstInt, srcInt, nFirstInt, threadId, numThreads);
      }
      // Copy elements.
      constexpr uint64_t nIntPerElem = sizeof(Type) / sizeof(int);
      uint64_t nElem = (numInt - nFirstInt) / nIntPerElem;
      Element<Type>::copy(dstElem, srcElem, nElem, threadId, numThreads);
      if (CopyRemainder && nIntPerElem > 1) {
        // Copy the remainder integers at the end.
        uint64_t nLastInt = (numInt - nFirstInt) % nIntPerElem;
        Element<int>::copy(dstInt + nFirstInt + nElem * nIntPerElem, srcInt + nFirstInt + nElem * nIntPerElem, nLastInt,
                           threadId, numThreads);
      }
    }

    /// Copy data from the local memory to the remote memory.
    ///
    /// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
    ///
    /// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16.
    /// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
    /// Alignment.
    /// @param dstOffset The offset in bytes of the remote address. Should be a multiple of @p Alignment.
    /// @param srcOffset The offset in bytes of the local address. Should be a multiple of @p Alignment.
    /// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment.
    /// @param threadId The index of the current thread among all threads running this function. This is different from
    /// the `threadIdx` in CUDA.
    /// @param numThreads The total number of threads that run this function.
    ///
    template <int Alignment = 16, bool CopyRemainder = true>
    __forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t bytes, uint32_t threadId,
                                        uint32_t numThreads) {
      copy<Alignment, CopyRemainder>((char*)dst_ + dstOffset, (char*)src_ + srcOffset, bytes, threadId, numThreads);
    }

    /// Copy data from the remote memory to the local memory.
    ///
    /// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
    ///
    /// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16.
    /// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
    /// Alignment.
    /// @param dstOffset The offset in bytes of the remote address. Should be a multiple of @p Alignment.
    /// @param srcOffset The offset in bytes of the local address. Should be a multiple of @p Alignment.
    /// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment.
    /// @param threadId The index of the current thread among all threads running this function. This is different from
    /// the `threadIdx` in CUDA.
    /// @param numThreads The total number of threads that run this function.
    ///
    template <int Alignment = 16, bool CopyRemainder = true>
    __forceinline__ __device__ void get(uint64_t dstOffset, uint64_t srcOffset, uint64_t bytes, uint32_t threadId,
                                        uint32_t numThreads) {
      // Note that `dst` and `src` are swapped for `get()`.
      copy<Alignment, CopyRemainder>((char*)src_ + srcOffset, (char*)dst_ + dstOffset, bytes, threadId, numThreads);
    }

    /// Copy data from the local memory to the remote memory.
    ///
    /// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
    ///
    /// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16.
    /// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
    /// Alignment.
    /// @param offset The offset in bytes of the local and remote addresses. Should be a multiple of @p Alignment.
    /// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment.
    /// @param threadId The index of the current thread among all threads running this function. This is different from
    /// the `threadIdx` in CUDA.
    /// @param numThreads The total number of threads that run this function.
    ///
    template <int Alignment = 16, bool CopyRemainder = true>
    __forceinline__ __device__ void put(uint64_t offset, uint64_t size, uint32_t threadId, uint32_t numThreads) {
      put<Alignment, CopyRemainder>(offset, offset, size, threadId, numThreads);
    }

    /// Copy data from the remote memory to the local memory.
    ///
    /// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
    ///
    /// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16.
    /// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
    /// Alignment.
    /// @param offset The offset in bytes of the local and remote addresses. Should be a multiple of @p Alignment.
    /// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment.
    /// @param threadId The index of the current thread among all threads running this function. This is different from
    /// the `threadIdx` in CUDA.
    /// @param numThreads The total number of threads that run this function.
    ///
    template <int Alignment = 16, bool CopyRemainder = true>
    __forceinline__ __device__ void get(uint64_t offset, uint64_t size, uint32_t threadId, uint32_t numThreads) {
      get<Alignment, CopyRemainder>(offset, offset, size, threadId, numThreads);
    }

    /// Construct @ref LLPacket from the data in the local memory and write it on the remote memory.
    ///
    /// This function is intended to be collectively called by multiple threads. Each thread copies a part of packets.
    ///
    /// @param dstOffset The offset in bytes of the remote address.
    /// @param srcOffset The offset in bytes of the local address.
    /// @param bytes Bytes of the data to be copied.
    /// @param threadId The index of the current thread among all threads running this function. This is different from
    /// the `threadIdx` in CUDA.
    /// @param numThreads The total number of threads that run this function.
    ///
    __forceinline__ __device__ void putPackets(uint64_t dstOffset, uint64_t srcOffset, uint64_t bytes,
                                               uint32_t threadId, uint32_t numThreads, uint32_t flag) {
      mscclpp::putPackets(dst_, dstOffset, src_, srcOffset, bytes, threadId, numThreads, flag);
    }

    /// Retrieve data from @ref LLPacket in the local packet buffer and write it on the local memory.
    ///
    /// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
    ///
    /// @param dstOffset The offset in bytes of the local memory.
    /// @param srcOffset The offset in bytes of the local packet buffer.
    /// @param bytes Bytes of the data to be copied.
    /// @param threadId The index of the current thread among all threads running this function. This is different from
    /// the `threadIdx` in CUDA.
    /// @param numThreads The total number of threads that run this function.
    ///
    __forceinline__ __device__ void getPackets(uint64_t dstOffset, uint64_t srcOffset, uint64_t bytes,
                                               uint32_t threadId, uint32_t numThreads, uint32_t flag) {
      // Destination is the local `src_` buffer; source is the packet buffer.
      mscclpp::getPackets(src_, dstOffset, getPacketBuffer_, srcOffset, bytes, threadId, numThreads, flag);
    }

    /// Signal the remote semaphore.
    ///
    /// This function guarantees that all the memory operation before this function is completed before the remote
    /// semaphore is signaled.
    ///
    __forceinline__ __device__ void signal() { semaphore_.signal(); }

    /// Signal the remote semaphore for copied packets.
    ///
    /// Unlike @ref signal(), this function provides no guarantee on the completion of memory operations. This is
    /// intended to be used with @ref putPackets() and @ref getPackets() that use flags inside packets to indicate the
    /// completion of copies.
    ///
    __forceinline__ __device__ void signalPacket() { semaphore_.signalPacket(); }

    /// Increase the counter of the local semaphore.
    __forceinline__ __device__ void semaphoreIncrement() { semaphore_.semaphoreIncrement(); }

    /// Read the counter of the local semaphore.
    __forceinline__ __device__ uint64_t semaphoreGetLocal() const { return semaphore_.semaphoreGetLocal(); }

    /// Wait for the remote semaphore to send a signal.
    __forceinline__ __device__ void wait() { semaphore_.wait(); }
#endif  // __CUDACC__
  };
/// Device-side handle for @ref SmChannel.
using DeviceHandle = SmChannelDeviceHandle;
/// Returns the device-side handle.
///

View File

@@ -0,0 +1,336 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#ifndef MSCCLPP_SM_CHANNEL_DEVICE_HPP_
#define MSCCLPP_SM_CHANNEL_DEVICE_HPP_
#include "packet.hpp"
#include "poll.hpp"
#include "semaphore_device.hpp"
namespace mscclpp {
#ifdef __CUDACC__
namespace Element {
/// Load an element from DRAM.
///
/// This is a warpper of ld.volatile.global.* PTX instruction. Address alignment is not this function's
/// responsibility.
///
/// @param v The value to be loaded.
/// @param p The address of the value to be loaded.
///
template <typename T>
__forceinline__ __device__ void load(T& v, const T* p) {
// We should only use the specialized functions.
__assert_fail("Unsupported type", __FILE__, __LINE__, __PRETTY_FUNCTION__);
}
/// Write an element on DRAM.
///
/// This is a wrapper of st.volatile.global.* PTX instruction. Address alignment is not this function's
/// responsibility.
///
/// @param p The address of the value to be written.
/// @param v The value to be written.
///
template <typename T>
__forceinline__ __device__ void store(T* p, const T& v) {
// We should only use the specialized functions.
__assert_fail("Unsupported type", __FILE__, __LINE__, __PRETTY_FUNCTION__);
}
/// Copy aligned elements from the source memory to the destination memory.
///
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of
/// elements.
///
/// @param dst The destination address.
/// @param src The source address.
/// @param numElems The number of elements to be copied.
/// @param threadId The index of the current thread among all threads running this function. This is different
/// from the `threadIdx` in CUDA.
/// @param numThreads The total number of threads that run this function.
///
template <typename T>
__forceinline__ __device__ void copy(T* dst, T* src, uint64_t numElems, uint32_t threadId, uint32_t numThreads) {
T reg;
for (size_t i = threadId; i < numElems; i += numThreads) {
// Load to register first.
load(reg, src + i);
store(dst + i, reg);
}
}
template <>
__forceinline__ __device__ void load<long long>(long long& v, const long long* p) {
asm volatile("ld.volatile.global.u64 %0, [%1];" : "=l"(v) : "l"(p) : "memory");
}
template <>
__forceinline__ __device__ void store<long long>(long long* p, const long long& v) {
asm volatile("st.volatile.global.u64 [%0], %1;" : : "l"(p), "l"(v) : "memory");
}
template <>
__forceinline__ __device__ void load<int>(int& v, const int* p) {
asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(v) : "l"(p) : "memory");
}
/// Specialization of store() for int: a single 32-bit volatile global store.
template <>
__forceinline__ __device__ void store<int>(int* p, const int& v) {
  asm volatile("st.volatile.global.u32 [%0], %1;" : : "l"(p), "r"(v) : "memory");
}
/// Specialization of load() for longlong2: a single 128-bit (v2.u64) volatile global load into two 64-bit
/// registers. The address must be 16-byte aligned (alignment is the caller's responsibility).
template <>
__forceinline__ __device__ void load<longlong2>(longlong2& v, const longlong2* p) {
  asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" : "=l"(v.x), "=l"(v.y) : "l"(p) : "memory");
}
/// Specialization of store() for longlong2: a single 128-bit (v2.u64) volatile global store from two 64-bit
/// registers. The address must be 16-byte aligned (alignment is the caller's responsibility).
template <>
__forceinline__ __device__ void store<longlong2>(longlong2* p, const longlong2& v) {
  asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" : : "l"(p), "l"(v.x), "l"(v.y) : "memory");
}
/// Specialization of load() for int4: a single 128-bit (v4.u32) volatile global load.
/// NOTE(review): the output registers are bound in w,x,y,z order, so v.w receives the FIRST 32-bit word in
/// memory. This is symmetric with the int4 store() specialization (so copy() round-trips correctly), but the
/// per-component mapping of read<int4>() is swizzled relative to memory order -- confirm this is intentional
/// before relying on component order.
template <>
__forceinline__ __device__ void load<int4>(int4& v, const int4* p) {
  asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
               : "=r"(v.w), "=r"(v.x), "=r"(v.y), "=r"(v.z)
               : "l"(p)
               : "memory");
}
/// Specialization of store() for int4: a single 128-bit (v4.u32) volatile global store.
/// NOTE(review): the input registers are bound in w,x,y,z order, so v.w is written to the FIRST 32-bit word in
/// memory. This is symmetric with the int4 load() specialization (so copy() round-trips correctly), but the
/// per-component mapping of write<int4>() is swizzled relative to memory order -- confirm this is intentional
/// before relying on component order.
template <>
__forceinline__ __device__ void store<int4>(int4* p, const int4& v) {
  asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
               :
               : "l"(p), "r"(v.w), "r"(v.x), "r"(v.y), "r"(v.z)
               : "memory");
}
} // namespace Element
#endif // __CUDACC__
/// Channel for accessing peer memory directly from SM.
struct SmChannelDeviceHandle {
  SmDevice2DeviceSemaphoreDeviceHandle semaphore_;
  void* src_;              // Local memory: source of put()/putPackets(), destination of get()/getPackets().
  void* dst_;              // Remote memory: destination of put()/putPackets(), source of get().
  void* getPacketBuffer_;  // Local packet buffer that getPackets() reads LLPackets from.
#ifdef __CUDACC__
  /// Load a value from the remote memory.
  /// @tparam T The type of the value to be loaded.
  /// @param index The index of the value to be loaded. The offset in bytes is calculated as index * sizeof(T).
  /// @return The value loaded.
  template <typename T>
  __forceinline__ __device__ T read(uint64_t index) {
    T v;
    Element::load<T>(v, (T*)dst_ + index);
    return v;
  }
  /// Write a value to the remote memory.
  /// @tparam T The type of the value to be written.
  /// @param index The index of the value to be written. The offset in bytes is calculated as index * sizeof(T).
  /// @param v The value to be written.
  template <typename T>
  __forceinline__ __device__ void write(uint64_t index, const T& v) {
    Element::store<T>((T*)dst_ + index, v);
  }
  /// Helper of copy(). Copies the unaligned leading/trailing remainders as `int`s (only when @p CopyRemainder
  /// is true) and the sizeof(T)-aligned middle region as `T`s.
  ///
  /// Assumes @p dst and @p src are misaligned from a sizeof(T) boundary by the same number of bytes, so the
  /// single remainder count computed from @p dst applies to both pointers.
  ///
  /// @tparam T The element type used for the aligned bulk of the copy (int, long long, or longlong2).
  /// @tparam CopyRemainder Whether to copy the remainder integers before and after the aligned region.
  template <typename T, bool CopyRemainder = true>
  __forceinline__ __device__ void copy_helper(void* dst, void* src, uint64_t bytes, uint32_t threadId,
                                              uint32_t numThreads) {
    int* dstInt = reinterpret_cast<int*>(dst);
    int* srcInt = reinterpret_cast<int*>(src);
    const uintptr_t dstPtr = reinterpret_cast<uintptr_t>(dst);
    const uintptr_t srcPtr = reinterpret_cast<uintptr_t>(src);
    const uint64_t numInt = bytes / sizeof(int);
    // Round both pointers up to the next sizeof(T) boundary.
    T* dstElem = reinterpret_cast<T*>((dstPtr + sizeof(T) - 1) / sizeof(T) * sizeof(T));
    T* srcElem = reinterpret_cast<T*>((srcPtr + sizeof(T) - 1) / sizeof(T) * sizeof(T));
    // Number of leading integers before the first sizeof(T)-aligned destination element.
    uint64_t nFirstInt = (reinterpret_cast<uintptr_t>(dstElem) - dstPtr) / sizeof(int);
    if (CopyRemainder) {
      // Copy the remainder integers at the beginning.
      Element::copy<int>(dstInt, srcInt, nFirstInt, threadId, numThreads);
    }
    // Copy elements.
    constexpr uint64_t nIntPerElem = sizeof(T) / sizeof(int);
    uint64_t nElem = (numInt - nFirstInt) / nIntPerElem;
    Element::copy<T>(dstElem, srcElem, nElem, threadId, numThreads);
    if (CopyRemainder && nIntPerElem > 1) {
      // Copy the remainder integers at the end.
      uint64_t nLastInt = (numInt - nFirstInt) % nIntPerElem;
      Element::copy<int>(dstInt + nFirstInt + nElem * nIntPerElem, srcInt + nFirstInt + nElem * nIntPerElem, nLastInt,
                         threadId, numThreads);
    }
  }
  /// Copy aligned data from the source memory to the destination memory.
  ///
  /// This function is a wrapper of Element::copy(). Unlike Element::copy(), this function can copy remainder
  /// bytes when @p CopyRemainder is true.
  ///
  /// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or 16 (enforced
  /// by the static_assert below).
  /// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
  /// Alignment.
  /// @param dst The destination address. Should be aligned to @p Alignment in the same way as @p src.
  /// @param src The source address. Should be aligned to @p Alignment in the same way as @p dst.
  /// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment.
  /// @param threadId The index of the current thread among all threads running this function. This is different from
  /// the `threadIdx` in CUDA.
  /// @param numThreads The total number of threads that run this function.
  ///
  template <int Alignment = 16, bool CopyRemainder = true>
  __forceinline__ __device__ void copy(void* dst, void* src, uint64_t bytes, uint32_t threadId, uint32_t numThreads) {
    if (Alignment == 4) {
      copy_helper<int, CopyRemainder>(dst, src, bytes, threadId, numThreads);
    } else if (Alignment == 8) {
      copy_helper<long long, CopyRemainder>(dst, src, bytes, threadId, numThreads);
    } else if (Alignment == 16) {
      copy_helper<longlong2, CopyRemainder>(dst, src, bytes, threadId, numThreads);
    } else {
      static_assert(Alignment == 4 || Alignment == 8 || Alignment == 16, "Unsupported alignment");
    }
  }
  /// Copy data from the local memory to the remote memory.
  ///
  /// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
  ///
  /// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or 16.
  /// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
  /// Alignment.
  /// @param dstOffset The offset in bytes of the remote address. Should be a multiple of @p Alignment.
  /// @param srcOffset The offset in bytes of the local address. Should be a multiple of @p Alignment.
  /// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment.
  /// @param threadId The index of the current thread among all threads running this function. This is different from
  /// the `threadIdx` in CUDA.
  /// @param numThreads The total number of threads that run this function.
  ///
  template <int Alignment = 16, bool CopyRemainder = true>
  __forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t bytes, uint32_t threadId,
                                      uint32_t numThreads) {
    copy<Alignment, CopyRemainder>((char*)dst_ + dstOffset, (char*)src_ + srcOffset, bytes, threadId, numThreads);
  }
  /// Copy data from the remote memory to the local memory.
  ///
  /// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
  ///
  /// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or 16.
  /// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
  /// Alignment.
  /// @param dstOffset The offset in bytes of the remote address. Should be a multiple of @p Alignment.
  /// @param srcOffset The offset in bytes of the local address. Should be a multiple of @p Alignment.
  /// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment.
  /// @param threadId The index of the current thread among all threads running this function. This is different from
  /// the `threadIdx` in CUDA.
  /// @param numThreads The total number of threads that run this function.
  ///
  template <int Alignment = 16, bool CopyRemainder = true>
  __forceinline__ __device__ void get(uint64_t dstOffset, uint64_t srcOffset, uint64_t bytes, uint32_t threadId,
                                      uint32_t numThreads) {
    // Note that `dst` and `src` are swapped for `get()`.
    copy<Alignment, CopyRemainder>((char*)src_ + srcOffset, (char*)dst_ + dstOffset, bytes, threadId, numThreads);
  }
  /// Copy data from the local memory to the remote memory, using the same offset on both sides.
  ///
  /// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
  ///
  /// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or 16.
  /// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
  /// Alignment.
  /// @param offset The offset in bytes of the local and remote addresses. Should be a multiple of @p Alignment.
  /// @param size Bytes of the data to be copied. Should be a multiple of @p Alignment.
  /// @param threadId The index of the current thread among all threads running this function. This is different from
  /// the `threadIdx` in CUDA.
  /// @param numThreads The total number of threads that run this function.
  ///
  template <int Alignment = 16, bool CopyRemainder = true>
  __forceinline__ __device__ void put(uint64_t offset, uint64_t size, uint32_t threadId, uint32_t numThreads) {
    put<Alignment, CopyRemainder>(offset, offset, size, threadId, numThreads);
  }
  /// Copy data from the remote memory to the local memory, using the same offset on both sides.
  ///
  /// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
  ///
  /// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or 16.
  /// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
  /// Alignment.
  /// @param offset The offset in bytes of the local and remote addresses. Should be a multiple of @p Alignment.
  /// @param size Bytes of the data to be copied. Should be a multiple of @p Alignment.
  /// @param threadId The index of the current thread among all threads running this function. This is different from
  /// the `threadIdx` in CUDA.
  /// @param numThreads The total number of threads that run this function.
  ///
  template <int Alignment = 16, bool CopyRemainder = true>
  __forceinline__ __device__ void get(uint64_t offset, uint64_t size, uint32_t threadId, uint32_t numThreads) {
    get<Alignment, CopyRemainder>(offset, offset, size, threadId, numThreads);
  }
  /// Construct @ref LLPacket from the data in the local memory and write it on the remote memory.
  ///
  /// This function is intended to be collectively called by multiple threads. Each thread copies a part of packets.
  ///
  /// @param dstOffset The offset in bytes of the remote address.
  /// @param srcOffset The offset in bytes of the local address.
  /// @param bytes Bytes of the data to be copied.
  /// @param threadId The index of the current thread among all threads running this function. This is different from
  /// the `threadIdx` in CUDA.
  /// @param numThreads The total number of threads that run this function.
  /// @param flag The flag value embedded in each written packet; the receiver uses in-packet flags to detect
  /// completion (see @ref signalPacket()).
  ///
  __forceinline__ __device__ void putPackets(uint64_t dstOffset, uint64_t srcOffset, uint64_t bytes, uint32_t threadId,
                                             uint32_t numThreads, uint32_t flag) {
    mscclpp::putPackets(dst_, dstOffset, src_, srcOffset, bytes, threadId, numThreads, flag);
  }
  /// Retrieve data from @ref LLPacket in the local packet buffer and write it on the local memory.
  ///
  /// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
  ///
  /// @param dstOffset The offset in bytes of the local memory.
  /// @param srcOffset The offset in bytes of the local packet buffer.
  /// @param bytes Bytes of the data to be copied.
  /// @param threadId The index of the current thread among all threads running this function. This is different from
  /// the `threadIdx` in CUDA.
  /// @param numThreads The total number of threads that run this function.
  /// @param flag The flag value that marks a packet as valid/arrived.
  ///
  __forceinline__ __device__ void getPackets(uint64_t dstOffset, uint64_t srcOffset, uint64_t bytes, uint32_t threadId,
                                             uint32_t numThreads, uint32_t flag) {
    mscclpp::getPackets(src_, dstOffset, getPacketBuffer_, srcOffset, bytes, threadId, numThreads, flag);
  }
  /// Signal the remote semaphore.
  ///
  /// This function guarantees that all the memory operation before this function is completed before the remote
  /// semaphore is signaled.
  ///
  __forceinline__ __device__ void signal() { semaphore_.signal(); }
  /// Signal the remote semaphore for copied packets.
  ///
  /// Unlike @ref signal(), this function provides no guarantee on the completion of memory operations. This is
  /// intended to be used with @ref putPackets() and @ref getPackets() that use flags inside packets to indicate the
  /// completion of copies.
  ///
  __forceinline__ __device__ void signalPacket() { semaphore_.signalPacket(); }
  /// Increase the counter of the local semaphore.
  __forceinline__ __device__ void semaphoreIncrement() { semaphore_.semaphoreIncrement(); }
  /// Read the counter of the local semaphore.
  __forceinline__ __device__ uint64_t semaphoreGetLocal() const { return semaphore_.semaphoreGetLocal(); }
  /// Wait for the remote semaphore to send a signal.
  __forceinline__ __device__ void wait() { semaphore_.wait(); }
#endif  // __CUDACC__
};
} // namespace mscclpp
#endif // MSCCLPP_SM_CHANNEL_DEVICE_HPP_

View File

@@ -17,6 +17,7 @@ struct Timer {
~Timer();
/// Returns the elapsed time in milliseconds.
int64_t elapsed() const;
void set(int timeout);

View File

@@ -1,16 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <nanobind/nanobind.h>
#include <mscclpp/config.hpp>
namespace nb = nanobind;
using namespace mscclpp;
void register_config(nb::module_& m) {
nb::class_<Config>(m, "Config")
.def_static("get_instance", &Config::getInstance, nb::rv_policy::reference)
.def("get_bootstrap_connection_timeout_config", &Config::getBootstrapConnectionTimeoutConfig)
.def("set_bootstrap_connection_timeout_config", &Config::setBootstrapConnectionTimeoutConfig);
}

View File

@@ -17,8 +17,8 @@ extern void register_proxy_channel(nb::module_& m);
extern void register_sm_channel(nb::module_& m);
extern void register_fifo(nb::module_& m);
extern void register_semaphore(nb::module_& m);
extern void register_config(nb::module_& m);
extern void register_utils(nb::module_& m);
extern void register_numa(nb::module_& m);
template <typename T>
void def_nonblocking_future(nb::handle& m, const std::string& typestr) {
@@ -62,9 +62,10 @@ void register_core(nb::module_& m) {
nb::arg("nRanks"))
.def("create_unique_id", &TcpBootstrap::createUniqueId)
.def("get_unique_id", &TcpBootstrap::getUniqueId)
.def("initialize", (void (TcpBootstrap::*)(UniqueId)) & TcpBootstrap::initialize, nb::arg("uniqueId"))
.def("initialize", (void (TcpBootstrap::*)(const std::string&)) & TcpBootstrap::initialize,
nb::arg("ifIpPortTrio"));
.def("initialize", (void (TcpBootstrap::*)(UniqueId, int64_t)) & TcpBootstrap::initialize, nb::arg("uniqueId"),
nb::arg("timeoutSec") = 30)
.def("initialize", (void (TcpBootstrap::*)(const std::string&, int64_t)) & TcpBootstrap::initialize,
nb::arg("ifIpPortTrio"), nb::arg("timeoutSec") = 30);
nb::enum_<Transport>(m, "Transport")
.value("Unknown", Transport::Unknown)
@@ -118,7 +119,7 @@ void register_core(nb::module_& m) {
self->updateAndSync(dst, dstOffset, (uint64_t*)src, newValue);
},
nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("newValue"))
.def("flush", &Connection::flush)
.def("flush", &Connection::flush, nb::arg("timeoutUsec") = (int64_t)3e7)
.def("remote_rank", &Connection::remoteRank)
.def("tag", &Connection::tag)
.def("transport", &Connection::transport)
@@ -139,7 +140,8 @@ void register_core(nb::module_& m) {
nb::arg("tag"))
.def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
.def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
nb::arg("transport"))
nb::arg("transport"), nb::arg("ibMaxCqSize") = 1024, nb::arg("ibMaxCqPollNum") = 1,
nb::arg("ibMaxSendWr") = 8192, nb::arg("ibMaxWrPerSend") = 64)
.def("setup", &Communicator::setup);
}
@@ -149,7 +151,7 @@ NB_MODULE(_mscclpp, m) {
register_sm_channel(m);
register_fifo(m);
register_semaphore(m);
register_config(m);
register_utils(m);
register_core(m);
register_numa(m);
}

View File

@@ -1,13 +1,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import mscclpp
import argparse
import multiprocessing as mp
import logging
import torch
import multiprocessing as mp
import sys
import mscclpp
import torch
IB_TRANSPORTS = [
mscclpp.Transport.IB0,
mscclpp.Transport.IB1,
@@ -19,15 +20,19 @@ IB_TRANSPORTS = [
mscclpp.Transport.IB7,
]
# Use to hold the sm channels so they don't get garbage collected
sm_channels = []
def setup_connections(comm, rank, world_size, element_size, proxy_service):
simple_proxy_channels = []
sm_semaphores = []
connections = []
remote_memories = []
memory = torch.zeros(element_size, dtype=torch.int32)
memory = memory.to("cuda")
transport_flag = IB_TRANSPORTS[rank] or mscclpp.Transport.CudaIpc
transport_flag = mscclpp.TransportFlags(IB_TRANSPORTS[rank]) | mscclpp.Transport.CudaIpc
ptr = memory.data_ptr()
size = memory.numel() * memory.element_size()
reg_mem = comm.register_memory(ptr, size, transport_flag)
@@ -42,15 +47,26 @@ def setup_connections(comm, rank, world_size, element_size, proxy_service):
remote_memories.append(remote_mem)
comm.setup()
# Create simple proxy channels
for i, conn in enumerate(connections):
proxy_channel = mscclpp.SimpleProxyChannel(
proxy_service.device_channel(proxy_service.add_semaphore(conn)),
proxy_service.proxy_channel(proxy_service.build_and_add_semaphore(conn)),
proxy_service.add_memory(remote_memories[i].get()),
proxy_service.add_memory(reg_mem),
)
simple_proxy_channels.append(mscclpp.device_handle(proxy_channel))
comm.setup()
return simple_proxy_channels
# Create sm channels
for i, conn in enumerate(connections):
sm_chan = mscclpp.SmDevice2DeviceSemaphore(comm, conn)
sm_semaphores.append(sm_chan)
comm.setup()
for i, conn in enumerate(sm_semaphores):
sm_chan = mscclpp.SmChannel(sm_semaphores[i], remote_memories[i].get(), ptr)
sm_channels.append(sm_chan)
return simple_proxy_channels, [mscclpp.device_handle(sm_chan) for sm_chan in sm_channels]
def run(rank, args):
@@ -60,7 +76,7 @@ def run(rank, args):
boot = mscclpp.TcpBootstrap.create(rank, world_size)
boot.initialize(args.if_ip_port_trio)
comm = mscclpp.Communicator(boot)
proxy_service = mscclpp.ProxyService(comm)
proxy_service = mscclpp.ProxyService()
logging.info("Rank: %d, setting up connections", rank)
setup_connections(comm, rank, world_size, args.num_elements, proxy_service)

View File

@@ -1,12 +0,0 @@
import mscclpp
def main():
config = mscclpp.Config.get_instance()
config.set_bootstrap_connection_timeout_config(15)
timeout = config.get_bootstrap_connection_timeout_config()
assert timeout == 15
if __name__ == "__main__":
main()

View File

@@ -1,10 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import mscclpp
import argparse
import time
import mscclpp
def main(args):
if args.root:

View File

@@ -1,7 +1,10 @@
import mscclpp
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import time
import mscclpp
def main():
timer = mscclpp.Timer()

View File

@@ -11,15 +11,20 @@ using namespace mscclpp;
void register_fifo(nb::module_& m) {
nb::class_<ProxyTrigger>(m, "ProxyTrigger").def_rw("fst", &ProxyTrigger::fst).def_rw("snd", &ProxyTrigger::snd);
nb::class_<DeviceProxyFifo>(m, "DeviceProxyFifo")
.def_rw("triggers", &DeviceProxyFifo::triggers)
.def_rw("tail_replica", &DeviceProxyFifo::tailReplica)
.def_rw("head", &DeviceProxyFifo::head);
nb::class_<FifoDeviceHandle>(m, "FifoDeviceHandle")
.def_rw("triggers", &FifoDeviceHandle::triggers)
.def_rw("tail_replica", &FifoDeviceHandle::tailReplica)
.def_rw("head", &FifoDeviceHandle::head)
.def_rw("size", &FifoDeviceHandle::size)
.def_prop_ro("raw", [](const FifoDeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
nb::class_<HostProxyFifo>(m, "HostProxyFifo")
.def(nb::init<>())
.def("poll", &HostProxyFifo::poll, nb::arg("trigger"))
.def("pop", &HostProxyFifo::pop)
.def("flush_tail", &HostProxyFifo::flushTail, nb::arg("sync") = false)
.def("device_fifo", &HostProxyFifo::deviceFifo);
nb::class_<Fifo>(m, "Fifo")
.def(nb::init<int>(), nb::arg("size") = 128)
.def("poll", &Fifo::poll)
.def("pop", &Fifo::pop)
.def("flush_tail", &Fifo::flushTail, nb::arg("sync") = false)
.def("size", &Fifo::size)
.def("device_handle", &Fifo::deviceHandle);
}

View File

@@ -1,10 +1,31 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from ._mscclpp import *
import os as _os
from ._mscclpp import (
Communicator,
Connection,
Fifo,
Host2DeviceSemaphore,
Host2HostSemaphore,
numa,
ProxyService,
RegisteredMemory,
SimpleProxyChannel,
SmChannel,
SmDevice2DeviceSemaphore,
TcpBootstrap,
Transport,
TransportFlags,
)
def get_include():
"""Return the directory that contains the MSCCL++ headers."""
return _os.path.join(_os.path.dirname(__file__), "include")
def get_lib():
"""Return the directory that contains the MSCCL++ headers."""
return _os.path.join(_os.path.dirname(__file__), "lib")

13
python/numa_py.cpp Normal file
View File

@@ -0,0 +1,13 @@
#include <nanobind/nanobind.h>
namespace nb = nanobind;
namespace mscclpp {
int getDeviceNumaNode(int cudaDev);
void numaBind(int node);
}; // namespace mscclpp
void register_numa(nb::module_ &m) {
nb::module_ sub_m = m.def_submodule("numa", "numa functions");
sub_m.def("get_device_numa_node", &mscclpp::getDeviceNumaNode);
sub_m.def("numa_bind", &mscclpp::numaBind);
}

View File

@@ -16,22 +16,40 @@ void register_proxy_channel(nb::module_& m) {
.def("stop_proxy", &BaseProxyService::stopProxy);
nb::class_<ProxyService, BaseProxyService>(m, "ProxyService")
.def(nb::init<Communicator&>(), nb::arg("comm"))
.def(nb::init<>())
.def("start_proxy", &ProxyService::startProxy)
.def("stop_proxy", &ProxyService::stopProxy)
.def("add_semaphore", &ProxyService::addSemaphore, nb::arg("connection"))
.def("build_and_add_semaphore", &ProxyService::buildAndAddSemaphore, nb::arg("comm"), nb::arg("connection"))
.def("add_semaphore", &ProxyService::addSemaphore, nb::arg("semaphore"))
.def("add_memory", &ProxyService::addMemory, nb::arg("memory"))
.def("semaphore", &ProxyService::semaphore, nb::arg("id"))
.def("device_channel", &ProxyService::deviceChannel, nb::arg("id"));
.def("proxy_channel", &ProxyService::proxyChannel, nb::arg("id"));
nb::class_<ProxyChannel>(m, "ProxyChannel")
.def(nb::init<SemaphoreId, Host2DeviceSemaphore::DeviceHandle, DeviceProxyFifo>(), nb::arg("semaphoreId"),
nb::arg("semaphore"), nb::arg("fifo"));
.def(nb::init<SemaphoreId, Host2DeviceSemaphore::DeviceHandle, FifoDeviceHandle>(), nb::arg("semaphoreId"),
nb::arg("semaphore"), nb::arg("fifo"))
.def("device_handle", &ProxyChannel::deviceHandle);
nb::class_<ProxyChannel::DeviceHandle>(m, "ProxyChannelDeviceHandle")
.def(nb::init<>())
.def_rw("semaphoreId_", &ProxyChannel::DeviceHandle::semaphoreId_)
.def_rw("semaphore_", &ProxyChannel::DeviceHandle::semaphore_)
.def_rw("fifo_", &ProxyChannel::DeviceHandle::fifo_)
.def_prop_ro("raw", [](const ProxyChannel::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
nb::class_<SimpleProxyChannel>(m, "SimpleProxyChannel")
.def(nb::init<ProxyChannel, MemoryId, MemoryId>(), nb::arg("proxyChan"), nb::arg("dst"), nb::arg("src"))
.def(nb::init<SimpleProxyChannel>(), nb::arg("proxyChan"));
.def(nb::init<SimpleProxyChannel>(), nb::arg("proxyChan"))
.def("device_handle", &SimpleProxyChannel::deviceHandle);
m.def("device_handle", &deviceHandle<ProxyChannel>, nb::arg("proxyChannel"));
m.def("device_handle", &deviceHandle<SimpleProxyChannel>, nb::arg("simpleProxyChannel"));
nb::class_<SimpleProxyChannel::DeviceHandle>(m, "SimpleProxyChannelDeviceHandle")
.def(nb::init<>())
.def_rw("proxyChan_", &SimpleProxyChannel::DeviceHandle::proxyChan_)
.def_rw("src_", &SimpleProxyChannel::DeviceHandle::src_)
.def_rw("dst_", &SimpleProxyChannel::DeviceHandle::dst_)
.def_prop_ro("raw", [](const SimpleProxyChannel::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
};

View File

@@ -20,7 +20,10 @@ void register_semaphore(nb::module_& m) {
nb::class_<Host2DeviceSemaphore::DeviceHandle>(host2DeviceSemaphore, "DeviceHandle")
.def(nb::init<>())
.def_rw("inbound_semaphore_id", &Host2DeviceSemaphore::DeviceHandle::inboundSemaphoreId)
.def_rw("expected_inbound_semaphore_id", &Host2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId);
.def_rw("expected_inbound_semaphore_id", &Host2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId)
.def_prop_ro("raw", [](const Host2DeviceSemaphore::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
nb::class_<Host2HostSemaphore>(m, "Host2HostSemaphore")
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
@@ -38,5 +41,8 @@ void register_semaphore(nb::module_& m) {
.def_rw("inboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::inboundSemaphoreId)
.def_rw("outboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::outboundSemaphoreId)
.def_rw("remoteInboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::remoteInboundSemaphoreId)
.def_rw("expectedInboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId);
.def_rw("expectedInboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId)
.def_prop_ro("raw", [](const SmDevice2DeviceSemaphore::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
}

View File

@@ -13,11 +13,23 @@ using namespace mscclpp;
void register_sm_channel(nb::module_& m) {
nb::class_<SmChannel> smChannel(m, "SmChannel");
smChannel
.def(nb::init<std::shared_ptr<SmDevice2DeviceSemaphore>, RegisteredMemory, void*, void*>(), nb::arg("semaphore"),
nb::arg("dst"), nb::arg("src"), nb::arg("getPacketBuffer"))
.def("__init__",
[](SmChannel* smChannel, std::shared_ptr<SmDevice2DeviceSemaphore> semaphore, RegisteredMemory dst,
uintptr_t src) { new (smChannel) SmChannel(semaphore, dst, (void*)src); })
.def("__init__",
[](SmChannel* smChannel, std::shared_ptr<SmDevice2DeviceSemaphore> semaphore, RegisteredMemory dst,
uintptr_t src, uintptr_t get_packet_buffer) {
new (smChannel) SmChannel(semaphore, dst, (void*)src, (void*)get_packet_buffer);
})
.def("device_handle", &SmChannel::deviceHandle);
nb::class_<SmChannel::DeviceHandle>(smChannel, "DeviceHandle");
m.def("device_handle", &deviceHandle<SmChannel>, nb::arg("smChannel"));
nb::class_<SmChannel::DeviceHandle>(m, "SmChannelDeviceHandle")
.def(nb::init<>())
.def_rw("semaphore_", &SmChannel::DeviceHandle::semaphore_)
.def_rw("src_", &SmChannel::DeviceHandle::src_)
.def_rw("dst_", &SmChannel::DeviceHandle::dst_)
.def_rw("getPacketBuffer_", &SmChannel::DeviceHandle::getPacketBuffer_)
.def_prop_ro("raw", [](const SmChannel::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
};

View File

@@ -4,7 +4,6 @@
#include <sys/resource.h>
#include <cstring>
#include <mscclpp/config.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/errors.hpp>
#include <sstream>
@@ -59,9 +58,9 @@ class TcpBootstrap::Impl {
public:
Impl(int rank, int nRanks);
~Impl();
void initialize(const UniqueId& uniqueId);
void initialize(const std::string& ifIpPortTrio);
void establishConnections();
void initialize(const UniqueId& uniqueId, int64_t timeoutSec);
void initialize(const std::string& ifIpPortTrio, int64_t timeoutSec);
void establishConnections(int64_t timeoutSec);
UniqueId createUniqueId();
UniqueId getUniqueId() const;
int getRank();
@@ -133,15 +132,15 @@ int TcpBootstrap::Impl::getRank() { return rank_; }
int TcpBootstrap::Impl::getNranks() { return nRanks_; }
void TcpBootstrap::Impl::initialize(const UniqueId& uniqueId) {
void TcpBootstrap::Impl::initialize(const UniqueId& uniqueId, int64_t timeoutSec) {
netInit("", "");
std::memcpy(&uniqueId_, &uniqueId, sizeof(uniqueId_));
establishConnections();
establishConnections(timeoutSec);
}
void TcpBootstrap::Impl::initialize(const std::string& ifIpPortTrio) {
void TcpBootstrap::Impl::initialize(const std::string& ifIpPortTrio, int64_t timeoutSec) {
// first check if it is a trio
int nColons = 0;
for (auto c : ifIpPortTrio) {
@@ -167,7 +166,7 @@ void TcpBootstrap::Impl::initialize(const std::string& ifIpPortTrio) {
bootstrapCreateRoot();
}
establishConnections();
establishConnections(timeoutSec);
}
TcpBootstrap::Impl::~Impl() {
@@ -308,8 +307,8 @@ void TcpBootstrap::Impl::netInit(std::string ipPortPair, std::string interface)
} \
} while (0);
void TcpBootstrap::Impl::establishConnections() {
const int64_t connectionTimeoutUs = (int64_t)Config::getInstance()->getBootstrapConnectionTimeoutConfig() * 1000000;
void TcpBootstrap::Impl::establishConnections(int64_t timeoutSec) {
const int64_t connectionTimeoutUs = timeoutSec * 1000000;
Timer timer;
SocketAddress nextAddr;
ExtInfo info;
@@ -317,6 +316,10 @@ void TcpBootstrap::Impl::establishConnections() {
TRACE(MSCCLPP_INIT, "rank %d nranks %d", rank_, nRanks_);
auto getLeftTime = [&]() {
if (connectionTimeoutUs < 0) {
// no timeout: always return a large number
return int64_t(1e9);
}
int64_t timeout = connectionTimeoutUs - timer.elapsed();
if (timeout <= 0) throw Error("TcpBootstrap connection timeout", ErrorCode::Timeout);
return timeout;
@@ -489,9 +492,13 @@ MSCCLPP_API_CPP void TcpBootstrap::recv(void* data, int size, int peer, int tag)
MSCCLPP_API_CPP void TcpBootstrap::allGather(void* allData, int size) { pimpl_->allGather(allData, size); }
MSCCLPP_API_CPP void TcpBootstrap::initialize(UniqueId uniqueId) { pimpl_->initialize(uniqueId); }
MSCCLPP_API_CPP void TcpBootstrap::initialize(UniqueId uniqueId, int64_t timeoutSec) {
pimpl_->initialize(uniqueId, timeoutSec);
}
MSCCLPP_API_CPP void TcpBootstrap::initialize(const std::string& ipPortPair) { pimpl_->initialize(ipPortPair); }
MSCCLPP_API_CPP void TcpBootstrap::initialize(const std::string& ipPortPair, int64_t timeoutSec) {
pimpl_->initialize(ipPortPair, timeoutSec);
}
MSCCLPP_API_CPP void TcpBootstrap::barrier() { pimpl_->barrier(); }

View File

@@ -13,7 +13,6 @@
#include <string.h>
#include <unistd.h>
#include <mscclpp/config.hpp>
#include <mscclpp/errors.hpp>
#include <mscclpp/utils.hpp>
#include <sstream>

View File

@@ -94,7 +94,11 @@ MSCCLPP_API_CPP NonblockingFuture<RegisteredMemory> Communicator::recvMemoryOnSe
return NonblockingFuture<RegisteredMemory>(memoryReceiver->memoryPromise_.get_future());
}
MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(int remoteRank, int tag, Transport transport) {
MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(int remoteRank, int tag, Transport transport,
int ibMaxCqSize /*=1024*/,
int ibMaxCqPollNum /*=1*/,
int ibMaxSendWr /*=8192*/,
int ibMaxWrPerSend /*=64*/) {
std::shared_ptr<ConnectionBase> conn;
if (transport == Transport::CudaIpc) {
// sanity check: make sure the IPC connection is being made within a node
@@ -111,7 +115,8 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(int rem
pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()], remoteRank,
pimpl->rankToHash_[remoteRank]);
} else if (AllIBTransports.has(transport)) {
auto ibConn = std::make_shared<IBConnection>(remoteRank, tag, transport, *pimpl);
auto ibConn = std::make_shared<IBConnection>(remoteRank, tag, transport, ibMaxCqSize, ibMaxCqPollNum, ibMaxSendWr,
ibMaxWrPerSend, *pimpl);
conn = ibConn;
INFO(MSCCLPP_NET, "IB connection between rank %d(%lx) via %s and remoteRank %d(%lx) created",
pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()],

View File

@@ -1,14 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <mscclpp/config.hpp>
namespace mscclpp {
Config Config::instance_;
Config* Config::getInstance() { return &instance_; }
int Config::getBootstrapConnectionTimeoutConfig() { return bootstrapConnectionTimeout; }
void Config::setBootstrapConnectionTimeoutConfig(int timeout) { bootstrapConnectionTimeout = timeout; }
} // namespace mscclpp

View File

@@ -70,21 +70,26 @@ void CudaIpcConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset,
// npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)size);
}
void CudaIpcConnection::flush() {
void CudaIpcConnection::flush(int64_t timeoutUsec) {
if (timeoutUsec >= 0) {
INFO(MSCCLPP_P2P, "CudaIpcConnection flush: timeout is not supported, ignored");
}
AvoidCudaGraphCaptureGuard guard;
MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream_));
// npkitCollectExitEvents(conn, NPKIT_EVENT_DMA_SEND_EXIT);
INFO(MSCCLPP_P2P, "CudaIpcConnection flushing connection to remote rank %d", remoteRank());
}
// IBConnection
IBConnection::IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl)
IBConnection::IBConnection(int remoteRank, int tag, Transport transport, int maxCqSize, int maxCqPollNum, int maxSendWr,
int maxWrPerSend, Communicator::Impl& commImpl)
: ConnectionBase(remoteRank, tag),
transport_(transport),
remoteTransport_(Transport::Unknown),
numSignaledSends(0),
dummyAtomicSource_(std::make_unique<uint64_t>(0)) {
qp = commImpl.getIbContext(transport)->createQp();
qp = commImpl.getIbContext(transport)->createQp(maxCqSize, maxCqPollNum, maxSendWr, 0, maxWrPerSend);
dummyAtomicSourceMem_ = RegisteredMemory(std::make_shared<RegisteredMemory::Impl>(
dummyAtomicSource_.get(), sizeof(uint64_t), commImpl.bootstrap_->getRank(), transport, commImpl));
validateTransport(dummyAtomicSourceMem_, transport);
@@ -144,7 +149,7 @@ void IBConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint6
oldValue, newValue);
}
void IBConnection::flush() {
void IBConnection::flush(int64_t timeoutUsec) {
Timer timer;
while (numSignaledSends) {
int wcNum = qp->pollCq();
@@ -153,8 +158,8 @@ void IBConnection::flush() {
}
auto elapsed = timer.elapsed();
if (elapsed > MSCCLPP_POLLING_WAIT) {
throw Error("pollCq is stuck: waited for " + std::to_string(elapsed / 1e6) + " seconds. Expected " +
if ((timeoutUsec >= 0) && (elapsed * 1e3 > timeoutUsec)) {
throw Error("pollCq is stuck: waited for " + std::to_string(elapsed / 1e3) + " seconds. Expected " +
std::to_string(numSignaledSends) + " signals",
ErrorCode::InternalError);
}
@@ -168,6 +173,7 @@ void IBConnection::flush() {
}
}
}
INFO(MSCCLPP_NET, "IBConnection flushing connection to remote rank %d", remoteRank());
// npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT);
}

View File

@@ -1,20 +1,18 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <emmintrin.h>
#include <mscclpp/cuda_utils.hpp>
#include <mscclpp/fifo.hpp>
#include <stdexcept>
#include "api.h"
namespace mscclpp {
struct HostProxyFifo::Impl {
struct Fifo::Impl {
UniqueCudaHostPtr<ProxyTrigger[]> triggers;
UniqueCudaPtr<uint64_t> head;
UniqueCudaPtr<uint64_t> tailReplica;
const int size;
// allocated on the host. Only accessed by the host. This is a copy of the
// value pointed to by fifoTailDev and the invariant is that
@@ -28,28 +26,33 @@ struct HostProxyFifo::Impl {
// for transferring fifo tail
CudaStreamWithFlags stream;
Impl()
: triggers(makeUniqueCudaHost<ProxyTrigger[]>(MSCCLPP_PROXY_FIFO_SIZE)),
Impl(int size)
: triggers(makeUniqueCudaHost<ProxyTrigger[]>(size)),
head(allocUniqueCuda<uint64_t>()),
tailReplica(allocUniqueCuda<uint64_t>()),
size(size),
hostTail(0),
stream(cudaStreamNonBlocking) {}
};
MSCCLPP_API_CPP HostProxyFifo::HostProxyFifo() : pimpl(std::make_unique<Impl>()) {}
MSCCLPP_API_CPP HostProxyFifo::~HostProxyFifo() = default;
MSCCLPP_API_CPP Fifo::Fifo(int size) : pimpl(std::make_unique<Impl>(size)) {}
MSCCLPP_API_CPP Fifo::~Fifo() = default;
MSCCLPP_API_CPP void HostProxyFifo::poll(ProxyTrigger* trigger) {
__m128i xmm0 = _mm_load_si128((__m128i*)&pimpl->triggers.get()[pimpl->hostTail % MSCCLPP_PROXY_FIFO_SIZE]);
_mm_store_si128((__m128i*)trigger, xmm0);
MSCCLPP_API_CPP ProxyTrigger Fifo::poll() {
ProxyTrigger trigger;
volatile ProxyTrigger* ptr =
reinterpret_cast<volatile ProxyTrigger*>(&pimpl->triggers.get()[pimpl->hostTail % pimpl->size]);
trigger.fst = ptr->fst;
trigger.snd = ptr->snd;
return trigger;
}
MSCCLPP_API_CPP void HostProxyFifo::pop() {
*(volatile uint64_t*)(&pimpl->triggers.get()[pimpl->hostTail % MSCCLPP_PROXY_FIFO_SIZE]) = 0;
MSCCLPP_API_CPP void Fifo::pop() {
*(volatile uint64_t*)(&pimpl->triggers.get()[pimpl->hostTail % pimpl->size]) = 0;
(pimpl->hostTail)++;
}
MSCCLPP_API_CPP void HostProxyFifo::flushTail(bool sync) {
MSCCLPP_API_CPP void Fifo::flushTail(bool sync) {
// Flush the tail to device memory. This is either triggered every ProxyFlushPeriod to make sure that the fifo can
// make progress even if there is no request mscclppSync. However, mscclppSync type is for flush request.
MSCCLPP_CUDATHROW(cudaMemcpyAsync(pimpl->tailReplica.get(), &pimpl->hostTail, sizeof(uint64_t),
@@ -59,12 +62,15 @@ MSCCLPP_API_CPP void HostProxyFifo::flushTail(bool sync) {
}
}
MSCCLPP_API_CPP DeviceProxyFifo HostProxyFifo::deviceFifo() {
DeviceProxyFifo deviceFifo;
deviceFifo.triggers = pimpl->triggers.get();
deviceFifo.head = pimpl->head.get();
deviceFifo.tailReplica = pimpl->tailReplica.get();
return deviceFifo;
MSCCLPP_API_CPP int Fifo::size() const { return pimpl->size; }
MSCCLPP_API_CPP FifoDeviceHandle Fifo::deviceHandle() {
FifoDeviceHandle deviceHandle;
deviceHandle.triggers = pimpl->triggers.get();
deviceHandle.head = pimpl->head.get();
deviceHandle.tailReplica = pimpl->tailReplica.get();
deviceHandle.size = pimpl->size;
return deviceHandle;
}
} // namespace mscclpp

View File

@@ -16,8 +16,6 @@
#include "api.h"
#include "debug.h"
#define MAXCONNECTIONS 64
namespace mscclpp {
IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) {
@@ -54,8 +52,10 @@ const void* IbMr::getBuff() const { return this->buff; }
uint32_t IbMr::getLkey() const { return this->mr->lkey; }
IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port) {
this->cq = ibv_create_cq(ctx, MSCCLPP_IB_CQ_SIZE, nullptr, nullptr, 0);
IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr,
int maxWrPerSend)
: maxCqPollNum(maxCqPollNum), maxWrPerSend(maxWrPerSend) {
this->cq = ibv_create_cq(ctx, maxCqSize, nullptr, nullptr, 0);
if (this->cq == nullptr) {
std::stringstream err;
err << "ibv_create_cq failed (errno " << errno << ")";
@@ -68,8 +68,8 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port) {
qpInitAttr.send_cq = this->cq;
qpInitAttr.recv_cq = this->cq;
qpInitAttr.qp_type = IBV_QPT_RC;
qpInitAttr.cap.max_send_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE;
qpInitAttr.cap.max_recv_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE;
qpInitAttr.cap.max_send_wr = maxSendWr;
qpInitAttr.cap.max_recv_wr = maxRecvWr;
qpInitAttr.cap.max_send_sge = 1;
qpInitAttr.cap.max_recv_sge = 1;
qpInitAttr.cap.max_inline_data = 0;
@@ -118,9 +118,9 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port) {
}
this->qp = _qp;
this->wrn = 0;
this->wrs = std::make_unique<ibv_send_wr[]>(MSCCLPP_IB_MAX_SENDS);
this->sges = std::make_unique<ibv_sge[]>(MSCCLPP_IB_MAX_SENDS);
this->wcs = std::make_unique<ibv_wc[]>(MSCCLPP_IB_CQ_POLL_NUM);
this->wrs = std::make_unique<ibv_send_wr[]>(maxWrPerSend);
this->sges = std::make_unique<ibv_sge[]>(maxWrPerSend);
this->wcs = std::make_unique<ibv_wc[]>(maxCqPollNum);
}
IbQp::~IbQp() {
@@ -182,9 +182,9 @@ void IbQp::rts() {
}
IbQp::WrInfo IbQp::getNewWrInfo() {
if (this->wrn >= MSCCLPP_IB_MAX_SENDS) {
if (this->wrn >= this->maxWrPerSend) {
std::stringstream err;
err << "too many outstanding work requests. limit is " << MSCCLPP_IB_MAX_SENDS;
err << "too many outstanding work requests. limit is " << this->maxWrPerSend;
throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage);
}
int wrn = this->wrn;
@@ -269,7 +269,7 @@ void IbQp::postRecv(uint64_t wrId) {
}
}
int IbQp::pollCq() { return ibv_poll_cq(this->cq, MSCCLPP_IB_CQ_POLL_NUM, this->wcs.get()); }
int IbQp::pollCq() { return ibv_poll_cq(this->cq, this->maxCqPollNum, this->wcs.get()); }
IbQpInfo& IbQp::getInfo() { return this->info; }
@@ -335,7 +335,8 @@ int IbCtx::getAnyActivePort() const {
return -1;
}
IbQp* IbCtx::createQp(int port /*=-1*/) {
IbQp* IbCtx::createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend,
int port /*=-1*/) {
if (port == -1) {
port = this->getAnyActivePort();
if (port == -1) {
@@ -344,7 +345,7 @@ IbQp* IbCtx::createQp(int port /*=-1*/) {
} else if (!this->isPortUsable(port)) {
throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InternalError);
}
qps.emplace_back(new IbQp(this->ctx, this->pd, port));
qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend));
return qps.back().get();
}

View File

@@ -4,9 +4,6 @@
#ifndef MSCCLPP_CONNECTION_HPP_
#define MSCCLPP_CONNECTION_HPP_
// TODO(saemal): make this configurable
#define MSCCLPP_POLLING_WAIT 3e7 // in microseconds
#include <cuda_runtime.h>
#include <mscclpp/core.hpp>
@@ -46,7 +43,7 @@ class CudaIpcConnection : public ConnectionBase {
uint64_t size) override;
void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) override;
void flush() override;
void flush(int64_t timeoutUsec) override;
};
class IBConnection : public ConnectionBase {
@@ -59,7 +56,8 @@ class IBConnection : public ConnectionBase {
mscclpp::TransportInfo dstTransportInfo_;
public:
IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl);
IBConnection(int remoteRank, int tag, Transport transport, int maxCqSize, int maxCqPollNum, int maxSendWr,
int maxWrPerSend, Communicator::Impl& commImpl);
Transport transport() override;
@@ -69,7 +67,7 @@ class IBConnection : public ConnectionBase {
uint64_t size) override;
void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) override;
void flush() override;
void flush(int64_t timeoutUsec) override;
void beginSetup(std::shared_ptr<Bootstrap> bootstrap) override;

View File

@@ -8,11 +8,6 @@
#include <memory>
#include <string>
#define MSCCLPP_IB_CQ_SIZE 1024
#define MSCCLPP_IB_CQ_POLL_NUM 1
#define MSCCLPP_IB_MAX_SENDS 64
#define MSCCLPP_IB_MAX_DEVS 8
// Forward declarations of IB structures
struct ibv_context;
struct ibv_pd;
@@ -84,7 +79,8 @@ class IbQp {
ibv_sge* sge;
};
IbQp(ibv_context* ctx, ibv_pd* pd, int port);
IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr,
int maxWrPerSend);
WrInfo getNewWrInfo();
IbQpInfo info;
@@ -96,6 +92,9 @@ class IbQp {
std::unique_ptr<ibv_sge[]> sges;
int wrn;
const int maxCqPollNum;
const int maxWrPerSend;
friend class IbCtx;
};
@@ -104,7 +103,7 @@ class IbCtx {
IbCtx(const std::string& devName);
~IbCtx();
IbQp* createQp(int port = -1);
IbQp* createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int port = -1);
const IbMr* registerMr(void* buff, std::size_t size);
const std::string& getDevName() const;
@@ -122,4 +121,4 @@ class IbCtx {
} // namespace mscclpp
#endif // MSCCLPP_IB_HPP_
#endif // MSCCLPP_IB_HPP_

View File

@@ -6,6 +6,8 @@
#include <fstream>
#include <mscclpp/cuda_utils.hpp>
#include "api.h"
// Convert a logical cudaDev index to the NVML device minor number
static const std::string getBusId(int cudaDev) {
// On most systems, the PCI bus ID comes back as in the 0000:00:00.0
@@ -22,7 +24,7 @@ static const std::string getBusId(int cudaDev) {
namespace mscclpp {
int getDeviceNumaNode(int cudaDev) {
MSCCLPP_API_CPP int getDeviceNumaNode(int cudaDev) {
std::string busId = getBusId(cudaDev);
std::string file_str = "/sys/bus/pci/devices/" + busId + "/numa_node";
std::ifstream file(file_str);
@@ -37,7 +39,7 @@ int getDeviceNumaNode(int cudaDev) {
return numaNode;
}
void numaBind(int node) {
MSCCLPP_API_CPP void numaBind(int node) {
int totalNumNumaNodes = numa_num_configured_nodes();
if (node < 0 || node >= totalNumNumaNodes) {
throw Error(

View File

@@ -15,14 +15,13 @@ namespace mscclpp {
const int ProxyStopCheckPeriod = 1000;
// Unless explicitly requested, a flush of the tail to device memory is triggered for every ProxyFlushPeriod.
// As long as MSCCLPP_PROXY_FIFO_SIZE is large enough, having a stale tail is not a problem.
// As long as the FIFO size is large enough, having a stale tail is not a problem.
const int ProxyFlushPeriod = 4;
static_assert(MSCCLPP_PROXY_FIFO_SIZE >= ProxyFlushPeriod, "MSCCLPP_PROXY_FIFO_SIZE is too small");
struct Proxy::Impl {
ProxyHandler handler;
std::function<void()> threadInit;
HostProxyFifo fifo;
Fifo fifo;
std::thread service;
std::atomic_bool running;
@@ -53,10 +52,12 @@ MSCCLPP_API_CPP void Proxy::start() {
pimpl->threadInit();
ProxyHandler handler = this->pimpl->handler;
HostProxyFifo& fifo = this->pimpl->fifo;
Fifo& fifo = this->pimpl->fifo;
std::atomic_bool& running = this->pimpl->running;
ProxyTrigger trigger;
int flushPeriod = std::min(fifo.size(), ProxyFlushPeriod);
int runCnt = ProxyStopCheckPeriod;
uint64_t flushCnt = 0;
for (;;) {
@@ -67,19 +68,19 @@ MSCCLPP_API_CPP void Proxy::start() {
}
}
// Poll to see if we are ready to send anything
fifo.poll(&trigger);
if (trigger.fst == 0) { // TODO: this check is a potential pitfall for custom triggers
continue; // there is one in progress
trigger = fifo.poll();
if (trigger.fst == 0 || trigger.snd == 0) { // TODO: this check is a potential pitfall for custom triggers
continue; // there is one in progress
}
trigger.snd ^= ((uint64_t)1 << (uint64_t)63); // this is where the last bit of snd is reverted.
ProxyHandlerResult result = handler(trigger);
// Send completion: reset only the high 64 bits
fifo.pop();
// Flush the tail to device memory. This is either triggered every ProxyFlushPeriod to make sure
// that the fifo can make progress even if there is no request mscclppSync. However, mscclppSync type is for flush
// request.
if ((++flushCnt % ProxyFlushPeriod) == 0 || result == ProxyHandlerResult::FlushFifoTailAndContinue) {
// Flush the tail to device memory. This is either triggered every flushPeriod to make sure that the fifo can make
// progress even if there is no request mscclppSync. However, mscclppSync type is for flush request.
if ((++flushCnt % flushPeriod) == 0 || result == ProxyHandlerResult::FlushFifoTailAndContinue) {
// TODO: relocate this check: || (trigger.fields.type & mscclppSync)
fifo.flushTail();
}
@@ -107,6 +108,6 @@ MSCCLPP_API_CPP void Proxy::stop() {
}
}
MSCCLPP_API_CPP HostProxyFifo& Proxy::fifo() { return pimpl->fifo; }
MSCCLPP_API_CPP Fifo& Proxy::fifo() { return pimpl->fifo; }
} // namespace mscclpp

View File

@@ -1,31 +1,36 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <mscclpp/numa.hpp>
#include <mscclpp/proxy_channel.hpp>
#include "api.h"
#include "debug.h"
#include "numa.hpp"
namespace mscclpp {
MSCCLPP_API_CPP ProxyChannel::ProxyChannel(SemaphoreId semaphoreId, Host2DeviceSemaphore::DeviceHandle semaphore,
DeviceProxyFifo fifo)
FifoDeviceHandle fifo)
: semaphoreId_(semaphoreId), semaphore_(semaphore), fifo_(fifo) {}
MSCCLPP_API_CPP SimpleProxyChannel::SimpleProxyChannel(ProxyChannel proxyChan, MemoryId dst, MemoryId src)
: proxyChan_(proxyChan), dst_(dst), src_(src) {}
MSCCLPP_API_CPP ProxyService::ProxyService(Communicator& communicator)
: communicator_(communicator),
proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, [&]() { bindThread(); }) {
MSCCLPP_API_CPP ProxyService::ProxyService()
: proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, [&]() { bindThread(); }) {
int cudaDevice;
MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice));
deviceNumaNode = getDeviceNumaNode(cudaDevice);
}
MSCCLPP_API_CPP SemaphoreId ProxyService::addSemaphore(std::shared_ptr<Connection> connection) {
semaphores_.push_back(std::make_shared<Host2DeviceSemaphore>(communicator_, connection));
MSCCLPP_API_CPP SemaphoreId ProxyService::buildAndAddSemaphore(Communicator& communicator,
std::shared_ptr<Connection> connection) {
semaphores_.push_back(std::make_shared<Host2DeviceSemaphore>(communicator, connection));
return semaphores_.size() - 1;
}
MSCCLPP_API_CPP SemaphoreId ProxyService::addSemaphore(std::shared_ptr<Host2DeviceSemaphore> semaphore) {
semaphores_.push_back(semaphore);
return semaphores_.size() - 1;
}
@@ -38,8 +43,8 @@ MSCCLPP_API_CPP std::shared_ptr<Host2DeviceSemaphore> ProxyService::semaphore(Se
return semaphores_[id];
}
MSCCLPP_API_CPP ProxyChannel ProxyService::deviceChannel(SemaphoreId id) {
return ProxyChannel(id, semaphores_[id]->deviceHandle(), proxy_.fifo().deviceFifo());
MSCCLPP_API_CPP ProxyChannel ProxyService::proxyChannel(SemaphoreId id) {
return ProxyChannel(id, semaphores_[id]->deviceHandle(), proxy_.fifo().deviceHandle());
}
MSCCLPP_API_CPP void ProxyService::startProxy() { proxy_.start(); }
@@ -78,14 +83,12 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) {
return result;
}
template <>
DeviceHandle<ProxyChannel> deviceHandle(ProxyChannel&& proxyChannel) {
return proxyChannel;
MSCCLPP_API_CPP ProxyChannel::DeviceHandle ProxyChannel::deviceHandle() const {
return ProxyChannel::DeviceHandle{.semaphoreId_ = semaphoreId_, .semaphore_ = semaphore_, .fifo_ = fifo_};
}
template <>
DeviceHandle<SimpleProxyChannel> deviceHandle(SimpleProxyChannel&& simpleProxyChannel) {
return simpleProxyChannel;
MSCCLPP_API_CPP SimpleProxyChannel::DeviceHandle SimpleProxyChannel::deviceHandle() const {
return SimpleProxyChannel::DeviceHandle{.proxyChan_ = proxyChan_.deviceHandle(), .dst_ = dst_, .src_ = src_};
}
} // namespace mscclpp

View File

@@ -207,8 +207,8 @@ void initializeAndAllocateAllGatherData(int rank, int world_size, size_t dataSiz
CUDACHECK(cudaMemcpy(*data_d, *data_h, dataSize, cudaMemcpyHostToDevice));
}
void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& comm,
mscclpp::ProxyService& channelService, int* data_d, size_t dataSize) {
void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& comm, mscclpp::ProxyService& proxyService,
int* data_d, size_t dataSize) {
int thisNode = rankToNode(rank);
int cudaNum = rankToLocalRank(rank);
std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum);
@@ -226,7 +226,7 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co
transport = ibTransport;
}
// Connect with all other ranks
semaphoreIds.push_back(channelService.addSemaphore(comm.connectOnSetup(r, 0, transport)));
semaphoreIds.push_back(proxyService.buildAndAddSemaphore(comm, comm.connectOnSetup(r, 0, transport)));
auto memory = comm.registerMemory(data_d, dataSize, mscclpp::Transport::CudaIpc | ibTransport);
localMemories.push_back(memory);
comm.sendMemoryOnSetup(memory, r, 0);
@@ -238,8 +238,8 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannels;
for (size_t i = 0; i < semaphoreIds.size(); ++i) {
proxyChannels.push_back(mscclpp::deviceHandle(mscclpp::SimpleProxyChannel(
channelService.deviceChannel(semaphoreIds[i]), channelService.addMemory(remoteMemories[i].get()),
channelService.addMemory(localMemories[i]))));
proxyService.proxyChannel(semaphoreIds[i]), proxyService.addMemory(remoteMemories[i].get()),
proxyService.addMemory(localMemories[i]))));
}
assert(proxyChannels.size() < sizeof(constProxyChans) / sizeof(DeviceHandle<mscclpp::SimpleProxyChannel>));
@@ -396,16 +396,16 @@ int main(int argc, const char* argv[]) {
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, world_size);
bootstrap->initialize(ip_port);
mscclpp::Communicator comm(bootstrap);
mscclpp::ProxyService channelService(comm);
mscclpp::ProxyService proxyService;
if (rank == 0) printf("Initializing data for allgather test\n");
initializeAndAllocateAllGatherData(rank, world_size, dataSize, nelemsPerGPU, &data_h, &data_d);
if (rank == 0) printf("Setting up the connection in MSCCL++\n");
setupMscclppConnections(rank, world_size, comm, channelService, data_d, dataSize);
setupMscclppConnections(rank, world_size, comm, proxyService, data_d, dataSize);
if (rank == 0) printf("Launching MSCCL++ proxy threads\n");
channelService.startProxy();
proxyService.startProxy();
if (rank == 0) printf("Testing the correctness of AllGather implementation\n");
cudaStream_t stream;
@@ -480,7 +480,7 @@ int main(int argc, const char* argv[]) {
bootstrap->allGather(tmp, sizeof(int));
if (rank == 0) printf("Stopping MSCCL++ proxy threads\n");
channelService.stopProxy();
proxyService.stopProxy();
} catch (std::exception& e) {
// todo: throw exceptions in the implementation and process them here

View File

@@ -4,9 +4,9 @@
#include <mscclpp/core.hpp>
#include <mscclpp/cuda_utils.hpp>
#include <mscclpp/fifo.hpp>
#include <mscclpp/numa.hpp>
#include <mscclpp/proxy.hpp>
#include <mscclpp/semaphore.hpp>
#include <numa.hpp>
#ifdef MSCCLPP_USE_MPI_FOR_TESTS
#include "mpi.h"
@@ -45,7 +45,7 @@ static double getTime(void) {
return (tspec.tv_nsec / 1.0e9) + tspec.tv_sec;
}
__global__ void kernel(int r, int nranks, mscclpp::DeviceProxyFifo fifo,
__global__ void kernel(int r, int nranks, mscclpp::FifoDeviceHandle fifo,
mscclpp::Host2DeviceSemaphore::DeviceHandle* handles, int handleIndex) {
int tid = threadIdx.x;
__syncthreads();
@@ -188,7 +188,7 @@ class MyProxyService {
void stop() { proxy_.stop(); }
mscclpp::HostProxyFifo& fifo() { return proxy_.fifo(); }
mscclpp::Fifo& fifo() { return proxy_.fifo(); }
mscclpp::Host2DeviceSemaphore::DeviceHandle getDeviceHandle1(int r) { return deviceSemaphores1_[r]->deviceHandle(); }
@@ -261,7 +261,7 @@ int main(int argc, char* argv[]) {
if (rank == 0) printf("Launching MSCCL++ proxy threads\n");
proxyService.start();
mscclpp::DeviceProxyFifo fifo = proxyService.fifo().deviceFifo();
mscclpp::FifoDeviceHandle fifo = proxyService.fifo().deviceHandle();
if (rank == 0) printf("Testing the correctness of AllGather implementation\n");
cudaStream_t stream;
CUCHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));

View File

@@ -3,8 +3,6 @@
#include <mpi.h>
#include <mscclpp/config.hpp>
#include "mp_unit_tests.hpp"
void BootstrapTest::bootstrapTestAllGather(std::shared_ptr<mscclpp::Bootstrap> bootstrap) {
@@ -88,10 +86,6 @@ TEST_F(BootstrapTest, ExitBeforeConnect) {
}
TEST_F(BootstrapTest, TimeoutWithId) {
// Set bootstrap timeout to 1 second
mscclpp::Config* cfg = mscclpp::Config::getInstance();
cfg->setBootstrapConnectionTimeoutConfig(1);
mscclpp::Timer timer;
// All ranks initialize a bootstrap with their own id (will hang)
@@ -99,7 +93,8 @@ TEST_F(BootstrapTest, TimeoutWithId) {
mscclpp::UniqueId id = bootstrap->createUniqueId();
try {
bootstrap->initialize(id);
// Set bootstrap timeout to 1 second
bootstrap->initialize(id, 1);
} catch (const mscclpp::Error& e) {
ASSERT_EQ(e.getErrorCode(), mscclpp::ErrorCode::Timeout);
}

View File

@@ -36,7 +36,7 @@ void IbPeerToPeerTest::SetUp() {
bootstrap->initialize(id);
ibCtx = std::make_shared<mscclpp::IbCtx>(ibDevName);
qp = ibCtx->createQp();
qp = ibCtx->createQp(1024, 1, 8192, 0, 64);
qpInfo[gEnv->rank] = qp->getInfo();
bootstrap->allGather(qpInfo.data(), sizeof(mscclpp::IbQpInfo));

View File

@@ -124,6 +124,9 @@ class CommunicatorTest : public CommunicatorTestBase {
std::vector<std::unordered_map<int, mscclpp::RegisteredMemory>> remoteMemory;
};
template <class T>
using DeviceHandle = mscclpp::DeviceHandle<T>;
class ProxyChannelOneToOneTest : public CommunicatorTestBase {
protected:
void SetUp() override;
@@ -134,7 +137,7 @@ class ProxyChannelOneToOneTest : public CommunicatorTestBase {
void testPacketPingPong(bool useIbOnly);
void testPacketPingPongPerf(bool useIbOnly);
std::shared_ptr<mscclpp::ProxyService> channelService;
std::shared_ptr<mscclpp::ProxyService> proxyService;
};
class SmChannelOneToOneTest : public CommunicatorTestBase {

View File

@@ -5,21 +5,18 @@
#include "mp_unit_tests.hpp"
template <class T>
using DeviceHandle = mscclpp::DeviceHandle<T>;
void ProxyChannelOneToOneTest::SetUp() {
// Use only two ranks
setNumRanksToUse(2);
CommunicatorTestBase::SetUp();
channelService = std::make_shared<mscclpp::ProxyService>(*communicator.get());
proxyService = std::make_shared<mscclpp::ProxyService>();
}
void ProxyChannelOneToOneTest::TearDown() { CommunicatorTestBase::TearDown(); }
void ProxyChannelOneToOneTest::setupMeshConnections(
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>>& proxyChannels, bool useIbOnly, void* sendBuff,
size_t sendBuffBytes, void* recvBuff, size_t recvBuffBytes) {
void ProxyChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::SimpleProxyChannel>& proxyChannels,
bool useIbOnly, void* sendBuff, size_t sendBuffBytes,
void* recvBuff, size_t recvBuffBytes) {
const int rank = communicator->bootstrap()->getRank();
const int worldSize = communicator->bootstrap()->getNranks();
const bool isInPlace = (recvBuff == nullptr);
@@ -52,12 +49,11 @@ void ProxyChannelOneToOneTest::setupMeshConnections(
communicator->setup();
mscclpp::SemaphoreId cid = channelService->addSemaphore(conn);
mscclpp::SemaphoreId cid = proxyService->buildAndAddSemaphore(*communicator, conn);
communicator->setup();
proxyChannels.emplace_back(mscclpp::deviceHandle(
mscclpp::SimpleProxyChannel(channelService->deviceChannel(cid), channelService->addMemory(remoteMemory.get()),
channelService->addMemory(sendBufRegMem))));
proxyChannels.emplace_back(proxyService->proxyChannel(cid), proxyService->addMemory(remoteMemory.get()),
proxyService->addMemory(sendBufRegMem));
}
}
@@ -121,15 +117,18 @@ TEST_F(ProxyChannelOneToOneTest, PingPongIb) {
const int nElem = 4 * 1024 * 1024;
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannels;
std::vector<mscclpp::SimpleProxyChannel> proxyChannels;
std::shared_ptr<int> buff = mscclpp::allocSharedCuda<int>(nElem);
setupMeshConnections(proxyChannels, true, buff.get(), nElem * sizeof(int));
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannelHandles;
for (auto& ch : proxyChannels) proxyChannelHandles.push_back(ch.deviceHandle());
ASSERT_EQ(proxyChannels.size(), 1);
MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstProxyChans, proxyChannels.data(),
MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstProxyChans, proxyChannelHandles.data(),
sizeof(DeviceHandle<mscclpp::SimpleProxyChannel>)));
channelService->startProxy();
proxyService->startProxy();
std::shared_ptr<int> ret = mscclpp::makeSharedCudaHost<int>(0);
@@ -153,7 +152,7 @@ TEST_F(ProxyChannelOneToOneTest, PingPongIb) {
EXPECT_EQ(*ret, 0);
channelService->stopProxy();
proxyService->stopProxy();
}
__device__ mscclpp::DeviceSyncer gChannelOneToOneTestProxyChansSyncer;
@@ -227,7 +226,7 @@ void ProxyChannelOneToOneTest::testPacketPingPong(bool useIbOnly) {
const int nElem = 4 * 1024 * 1024;
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannels;
std::vector<mscclpp::SimpleProxyChannel> proxyChannels;
std::shared_ptr<int> buff = mscclpp::allocSharedCuda<int>(nElem);
const size_t nPacket = (nElem * sizeof(int) + sizeof(uint64_t) - 1) / sizeof(uint64_t);
@@ -238,13 +237,19 @@ void ProxyChannelOneToOneTest::testPacketPingPong(bool useIbOnly) {
getPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket));
ASSERT_EQ(proxyChannels.size(), 1);
MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstProxyChans, proxyChannels.data(),
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannelHandles;
for (auto& proxyChannel : proxyChannels) {
proxyChannelHandles.push_back(proxyChannel.deviceHandle());
}
MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstProxyChans, proxyChannelHandles.data(),
sizeof(DeviceHandle<mscclpp::SimpleProxyChannel>)));
mscclpp::DeviceSyncer syncer = {};
MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestProxyChansSyncer, &syncer, sizeof(mscclpp::DeviceSyncer)));
channelService->startProxy();
proxyService->startProxy();
std::shared_ptr<int> ret = mscclpp::makeSharedCudaHost<int>(0);
@@ -280,7 +285,7 @@ void ProxyChannelOneToOneTest::testPacketPingPong(bool useIbOnly) {
communicator->bootstrap()->barrier();
channelService->stopProxy();
proxyService->stopProxy();
}
void ProxyChannelOneToOneTest::testPacketPingPongPerf(bool useIbOnly) {
@@ -288,7 +293,7 @@ void ProxyChannelOneToOneTest::testPacketPingPongPerf(bool useIbOnly) {
const int nElem = 4 * 1024 * 1024;
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannels;
std::vector<mscclpp::SimpleProxyChannel> proxyChannels;
std::shared_ptr<int> buff = mscclpp::allocSharedCuda<int>(nElem);
const size_t nPacket = (nElem * sizeof(int) + sizeof(uint64_t) - 1) / sizeof(uint64_t);
@@ -299,13 +304,19 @@ void ProxyChannelOneToOneTest::testPacketPingPongPerf(bool useIbOnly) {
getPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket));
ASSERT_EQ(proxyChannels.size(), 1);
MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstProxyChans, proxyChannels.data(),
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannelHandles;
for (auto& proxyChannel : proxyChannels) {
proxyChannelHandles.push_back(proxyChannel.deviceHandle());
}
MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstProxyChans, proxyChannelHandles.data(),
sizeof(DeviceHandle<mscclpp::SimpleProxyChannel>)));
mscclpp::DeviceSyncer syncer = {};
MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestProxyChansSyncer, &syncer, sizeof(mscclpp::DeviceSyncer)));
channelService->startProxy();
proxyService->startProxy();
auto* testInfo = ::testing::UnitTest::GetInstance()->current_test_info();
const std::string testName = std::string(testInfo->test_suite_name()) + "." + std::string(testInfo->name());
@@ -330,7 +341,7 @@ void ProxyChannelOneToOneTest::testPacketPingPongPerf(bool useIbOnly) {
std::cout << testName << ": " << std::setprecision(4) << (float)timer.elapsed() / (float)nTries << " us/iter\n";
}
channelService->stopProxy();
proxyService->stopProxy();
}
TEST_F(ProxyChannelOneToOneTest, PacketPingPong) { testPacketPingPong(false); }

View File

@@ -5,8 +5,6 @@
#include "mp_unit_tests.hpp"
template <class T>
using DeviceHandle = mscclpp::DeviceHandle<T>;
void SmChannelOneToOneTest::SetUp() {
// Need at least two ranks within a node
if (gEnv->nRanksPerNode < 2) {

View File

@@ -210,6 +210,7 @@ __global__ void allgather3(int rank, int worldSize) {
if (tid == 0) {
mscclpp::ProxyTrigger trigger;
trigger.fst = MAGIC;
trigger.snd = 0;
// offload all the work to the proxy
uint64_t currentFifoHead = proxyChan.fifo_.push(trigger);
// wait for the work to be done in cpu side
@@ -278,23 +279,24 @@ __global__ void allgather4(int rank, int worldSize, int nRanksPerNode, size_t ne
nBlocksForLocalAllGather);
}
class AllGatherChannelService : public mscclpp::BaseProxyService {
class AllGatherProxyService : public mscclpp::BaseProxyService {
public:
AllGatherChannelService(mscclpp::Communicator& communicator, int worldSize, int rank, int cudaDevice);
AllGatherProxyService(int worldSize, int rank, int cudaDevice);
void startProxy() override { proxy_.start(); }
void stopProxy() override { proxy_.stop(); }
void setSendBytes(size_t sendBytes) { this->sendBytes_ = sendBytes; }
void addRemoteMemory(mscclpp::RegisteredMemory memory) { remoteMemories_.push_back(memory); }
void setLocalMemory(mscclpp::RegisteredMemory memory) { localMemory_ = memory; }
mscclpp::SemaphoreId addSemaphore(std::shared_ptr<mscclpp::Connection> connection) {
semaphores_.push_back(std::make_shared<mscclpp::Host2DeviceSemaphore>(communicator_, connection));
mscclpp::SemaphoreId buildAndAddSemaphore(mscclpp::Communicator& communicator,
std::shared_ptr<mscclpp::Connection> connection) {
semaphores_.push_back(std::make_shared<mscclpp::Host2DeviceSemaphore>(communicator, connection));
return semaphores_.size() - 1;
}
std::vector<DeviceHandle<mscclpp::ProxyChannel>> deviceChannels() {
std::vector<DeviceHandle<mscclpp::ProxyChannel>> proxyChannels() {
std::vector<DeviceHandle<mscclpp::ProxyChannel>> result;
for (auto& semaphore : semaphores_) {
result.push_back(
mscclpp::deviceHandle(mscclpp::ProxyChannel(0, semaphore->deviceHandle(), proxy_.fifo().deviceFifo())));
mscclpp::deviceHandle(mscclpp::ProxyChannel(0, semaphore->deviceHandle(), proxy_.fifo().deviceHandle())));
}
return result;
}
@@ -306,7 +308,6 @@ class AllGatherChannelService : public mscclpp::BaseProxyService {
size_t sendBytes_;
mscclpp::Proxy proxy_;
mscclpp::Communicator& communicator_;
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> semaphores_;
std::vector<mscclpp::RegisteredMemory> remoteMemories_;
mscclpp::RegisteredMemory localMemory_;
@@ -314,10 +315,8 @@ class AllGatherChannelService : public mscclpp::BaseProxyService {
mscclpp::ProxyHandlerResult handleTrigger(mscclpp::ProxyTrigger triggerRaw);
};
AllGatherChannelService::AllGatherChannelService(mscclpp::Communicator& communicator, int worldSize, int rank,
int cudaDevice)
: communicator_(communicator),
worldSize_(worldSize),
AllGatherProxyService::AllGatherProxyService(int worldSize, int rank, int cudaDevice)
: worldSize_(worldSize),
sendBytes_(0),
rank_(rank),
cudaDevice_(cudaDevice),
@@ -327,7 +326,7 @@ AllGatherChannelService::AllGatherChannelService(mscclpp::Communicator& communic
numaBind(deviceNumaNode);
}) {}
mscclpp::ProxyHandlerResult AllGatherChannelService::handleTrigger(mscclpp::ProxyTrigger triggerRaw) {
mscclpp::ProxyHandlerResult AllGatherProxyService::handleTrigger(mscclpp::ProxyTrigger triggerRaw) {
size_t offset = rank_ * sendBytes_;
if (triggerRaw.fst != MAGIC) {
// this is not a valid trigger
@@ -432,7 +431,7 @@ void AllGatherTestColl::setupCollTest(size_t size) {
paramCount_ = base;
expectedCount_ = recvCount_;
if (isUsingHostOffload(kernelNum_)) {
auto service = std::dynamic_pointer_cast<AllGatherChannelService>(chanService_);
auto service = std::dynamic_pointer_cast<AllGatherProxyService>(chanService_);
service->setSendBytes(sendCount_ * typeSize_);
}
mscclpp::DeviceSyncer syncer = {};
@@ -459,7 +458,7 @@ class AllGatherTestEngine : public BaseTestEngine {
std::vector<void*> getSendBuff() override;
void* getRecvBuff() override;
void* getScratchBuff() override;
std::shared_ptr<mscclpp::BaseProxyService> createChannelService() override;
std::shared_ptr<mscclpp::BaseProxyService> createProxyService() override;
private:
void* getExpectedBuff() override;
@@ -492,31 +491,31 @@ void AllGatherTestEngine::setupConnections() {
CUDATHROW(cudaMemcpyToSymbol(constSmChans, smChannelHandles.data(),
sizeof(DeviceHandle<mscclpp::SmChannel>) * smChannelHandles.size()));
} else {
auto service = std::dynamic_pointer_cast<AllGatherChannelService>(chanService_);
auto service = std::dynamic_pointer_cast<AllGatherProxyService>(chanService_);
setupMeshConnections(devProxyChannels, sendBuff_.get(), args_.maxBytes, nullptr, 0,
[&](std::vector<std::shared_ptr<mscclpp::Connection>> conns,
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>>& remoteMemories,
const mscclpp::RegisteredMemory& localMemory) {
std::vector<mscclpp::SemaphoreId> semaphoreIds;
for (size_t i = 0; i < conns.size(); ++i) {
service->addSemaphore(conns[i]);
service->buildAndAddSemaphore(*comm_, conns[i]);
service->addRemoteMemory(remoteMemories[i].get());
}
service->setLocalMemory(localMemory);
comm_->setup();
});
auto proxyChannels = service->deviceChannels();
auto proxyChannels = service->proxyChannels();
assert(proxyChannels.size() < sizeof(constRawProxyChan) / sizeof(DeviceHandle<mscclpp::ProxyChannel>));
CUDATHROW(cudaMemcpyToSymbol(constRawProxyChan, proxyChannels.data(),
sizeof(DeviceHandle<mscclpp::ProxyChannel>) * proxyChannels.size()));
}
}
std::shared_ptr<mscclpp::BaseProxyService> AllGatherTestEngine::createChannelService() {
std::shared_ptr<mscclpp::BaseProxyService> AllGatherTestEngine::createProxyService() {
if (isUsingHostOffload(args_.kernelNum)) {
return std::make_shared<AllGatherChannelService>(*comm_, args_.totalRanks, args_.rank, args_.gpuNum);
return std::make_shared<AllGatherProxyService>(args_.totalRanks, args_.rank, args_.gpuNum);
} else {
return std::make_shared<mscclpp::ProxyService>(*comm_);
return std::make_shared<mscclpp::ProxyService>();
}
}

View File

@@ -109,7 +109,7 @@ __device__ void localReduceScatter(int* buff, int* scratch, int rank, int nRanks
int prePeerRecvId = (preRemoteRecvFromRank < rank) ? preRemoteRecvFromRank : preRemoteRecvFromRank - 1;
// overlap communication and computation
mscclpp::SimpleProxyChannel& preDevFstRecvChan = constDevFstRoundChans[prePeerRecvId];
DeviceHandle<mscclpp::SimpleProxyChannel>& preDevFstRecvChan = constDevFstRoundChans[prePeerRecvId];
if (isComm) {
preDevFstRecvChan.wait();
devFstSendChan.putWithSignal(dstOffset, srcOffset, nelems * sizeof(int));
@@ -563,7 +563,8 @@ __global__ void allreduce0(int* buff, int* scratch, int rank, int worldSize, siz
}
}
__global__ void allreduce1(int* buff, int* scratch, int rank, int worldSize, size_t nelems, size_t scratchDataCount) {
__global__ void __launch_bounds__(1024)
allreduce1(int* buff, int* scratch, int rank, int worldSize, size_t nelems, size_t scratchDataCount) {
int isComm = (threadIdx.x == 0) && (blockIdx.x == 0);
int remoteSendRank = (rank + 1) % worldSize;
int remoteRecvRank = (rank + worldSize - 1) % worldSize;
@@ -686,7 +687,7 @@ __global__ void allreduce2(int* buff, void* scratch, void* putPktBuf, void* getP
// Channel to a remote peer that has the same local rank as me
int localRank = rank % nRanksPerNode;
mscclpp::SimpleProxyChannel proxyChan = constDevFstRoundChans[localRank];
DeviceHandle<mscclpp::SimpleProxyChannel> proxyChan = constDevFstRoundChans[localRank];
// Flag for packets. Initially 1
uint32_t flag = (uint32_t)globalFlag;
@@ -779,8 +780,8 @@ __global__ void allreduce2(int* buff, void* scratch, void* putPktBuf, void* getP
}
}
__global__ void allreduce3(int* buff, int* scratch, void* result, int rank, int nRanksPerNode, int worldSize,
size_t nelems) {
__global__ void __launch_bounds__(1024)
allreduce3(int* buff, int* scratch, void* result, int rank, int nRanksPerNode, int worldSize, size_t nelems) {
reduceScatter(buff, scratch, rank, nRanksPerNode, worldSize, nelems);
if (threadIdx.x == 0 && blockIdx.x == 0) {
allGather(rank, worldSize, nRanksPerNode, nelems / worldSize);

View File

@@ -16,7 +16,7 @@ void* localSendBuff;
__device__ void localAlltoall(int rank, int nRanksPerNode, size_t nElements) {
int remoteRank = (blockIdx.x < rank) ? blockIdx.x : blockIdx.x + 1;
for (int i = 1; i < nRanksPerNode; i++) {
mscclpp::SimpleProxyChannel proxyChan = constProxyChans[blockIdx.x];
DeviceHandle<mscclpp::SimpleProxyChannel> proxyChan = constProxyChans[blockIdx.x];
if (threadIdx.x == 0 && remoteRank % nRanksPerNode == (rank + i) % nRanksPerNode) {
proxyChan.putWithSignalAndFlush(rank * nElements * sizeof(int), remoteRank * nElements * sizeof(int),
nElements * sizeof(int));

View File

@@ -16,9 +16,17 @@ def load_perf_file(perf_fine: str) -> dict:
"time": data["time"],
}
if "target" in data:
res[(data["name"], data["kernel"], data["ranks"], data["ranksPerNode"], data["size"])]["target"] = data[
res[
(
data["name"],
data["kernel"],
data["ranks"],
data["ranksPerNode"],
data["size"],
)
][
"target"
]
] = data["target"]
return res

View File

@@ -335,7 +335,7 @@ void BaseTestEngine::bootstrap() {
}
void BaseTestEngine::setupTest() {
this->chanService_ = this->createChannelService();
this->chanService_ = this->createProxyService();
this->setupConnections();
this->chanService_->startProxy();
this->coll_->setChanService(this->chanService_);
@@ -357,8 +357,8 @@ size_t BaseTestEngine::checkData() {
return nErrors;
}
std::shared_ptr<mscclpp::BaseProxyService> BaseTestEngine::createChannelService() {
return std::make_shared<mscclpp::ProxyService>(*comm_);
std::shared_ptr<mscclpp::BaseProxyService> BaseTestEngine::createProxyService() {
return std::make_shared<mscclpp::ProxyService>();
}
void BaseTestEngine::setupMeshConnectionsInternal(
@@ -416,8 +416,8 @@ void BaseTestEngine::setupMeshConnections(std::vector<DeviceHandle<mscclpp::Simp
auto service = std::dynamic_pointer_cast<mscclpp::ProxyService>(chanService_);
for (size_t i = 0; i < connections.size(); ++i) {
proxyChannels.push_back(mscclpp::deviceHandle(mscclpp::SimpleProxyChannel(
service->deviceChannel(service->addSemaphore(connections[i])), service->addMemory(remoteRegMemories[i].get()),
service->addMemory(inputBufRegMem))));
service->proxyChannel(service->buildAndAddSemaphore(*comm_, connections[i])),
service->addMemory(remoteRegMemories[i].get()), service->addMemory(inputBufRegMem))));
}
}
@@ -498,7 +498,7 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::SmChannel>& smCha
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
smSemaphores.emplace(cid, std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(*comm_, connections[cid]));
} else {
connIdToSemId[cid] = service->addSemaphore(connections[cid]);
connIdToSemId[cid] = service->buildAndAddSemaphore(*comm_, connections[cid]);
}
}
comm_->setup();
@@ -513,7 +513,7 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::SmChannel>& smCha
throw std::runtime_error("IB transport requires putPacketBuff and getPacketBuff");
}
proxyChannels.emplace_back(mscclpp::deviceHandle(mscclpp::SimpleProxyChannel(
service->deviceChannel(connIdToSemId[cid]), service->addMemory(remoteRegMemories[cid].get()),
service->proxyChannel(connIdToSemId[cid]), service->addMemory(remoteRegMemories[cid].get()),
service->addMemory(putPacketBufRegMem))));
}
}

View File

@@ -97,7 +97,7 @@ class BaseTestEngine {
private:
virtual void setupConnections() = 0;
virtual std::shared_ptr<mscclpp::BaseProxyService> createChannelService();
virtual std::shared_ptr<mscclpp::BaseProxyService> createProxyService();
virtual void* getExpectedBuff() = 0;
double benchTime();

View File

@@ -5,40 +5,40 @@
#include <mscclpp/cuda_utils.hpp>
#include <mscclpp/fifo.hpp>
#include <mscclpp/numa.hpp>
#include <mscclpp/utils.hpp>
#include "numa.hpp"
#define ITER 10000 // should be larger than the FIFO size for proper testing
#define FLUSH_PERIOD (MSCCLPP_PROXY_FIFO_SIZE) // should not exceed MSCCLPP_PROXY_FIFO_SIZE
#define ITER 10000 // should be larger than MSCCLPP_PROXY_FIFO_SIZE for proper testing
__constant__ mscclpp::DeviceProxyFifo gFifoTestDeviceProxyFifo;
__constant__ mscclpp::FifoDeviceHandle gFifoTestFifoDeviceHandle;
__global__ void kernelFifoTest() {
if (threadIdx.x + blockIdx.x * blockDim.x != 0) return;
mscclpp::DeviceProxyFifo& fifo = gFifoTestDeviceProxyFifo;
mscclpp::FifoDeviceHandle& fifo = gFifoTestFifoDeviceHandle;
mscclpp::ProxyTrigger trigger;
for (uint64_t i = 1; i < ITER + 1; ++i) {
trigger.fst = i;
trigger.snd = i;
uint64_t curFifoHead = fifo.push(trigger);
if (i % FLUSH_PERIOD == 0) {
if (i % fifo.size == 0) {
fifo.sync(curFifoHead);
}
}
}
TEST(FifoTest, HostProxyFifo) {
ASSERT_LE(FLUSH_PERIOD, MSCCLPP_PROXY_FIFO_SIZE);
TEST(FifoTest, Fifo) {
int cudaNum;
MSCCLPP_CUDATHROW(cudaGetDevice(&cudaNum));
int numaNode = mscclpp::getDeviceNumaNode(cudaNum);
mscclpp::numaBind(numaNode);
mscclpp::HostProxyFifo hostFifo;
mscclpp::DeviceProxyFifo devFifo = hostFifo.deviceFifo();
MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gFifoTestDeviceProxyFifo, &devFifo, sizeof(devFifo)));
mscclpp::Fifo hostFifo;
if (hostFifo.size() >= ITER) {
FAIL() << "ITER is too small for proper testing.";
}
mscclpp::FifoDeviceHandle devFifo = hostFifo.deviceHandle();
MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gFifoTestFifoDeviceHandle, &devFifo, sizeof(devFifo)));
kernelFifoTest<<<1, 1>>>();
MSCCLPP_CUDATHROW(cudaGetLastError());
@@ -51,17 +51,19 @@ TEST(FifoTest, HostProxyFifo) {
uint64_t flushCnt = 0;
mscclpp::Timer timer(3);
for (uint64_t i = 0; i < ITER; ++i) {
while (trigger.fst == 0) {
hostFifo.poll(&trigger);
while (trigger.fst == 0 || trigger.snd == 0) {
trigger = hostFifo.poll();
if (spin++ > 1000000) {
FAIL() << "Polling is stuck.";
}
}
// see `src/proxy.cc` for the reason of this line
trigger.snd ^= ((uint64_t)1 << (uint64_t)63);
ASSERT_TRUE(trigger.fst == (i + 1));
ASSERT_TRUE(trigger.snd == (i + 1));
hostFifo.pop();
if ((++flushCnt % FLUSH_PERIOD) == 0) {
if ((++flushCnt % hostFifo.size()) == 0) {
hostFifo.flushTail();
}
trigger.fst = 0;
@@ -70,7 +72,7 @@ TEST(FifoTest, HostProxyFifo) {
hostFifo.flushTail(true);
std::stringstream ss;
ss << "FifoTest.HostProxyFifo: " << (float)timer.elapsed() / ITER << " us/iter\n";
ss << "FifoTest.Fifo: " << (float)timer.elapsed() / ITER << " us/iter\n";
std::cout << ss.str();
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());

View File

@@ -4,8 +4,7 @@
#include <gtest/gtest.h>
#include <mscclpp/cuda_utils.hpp>
#include "numa.hpp"
#include <mscclpp/numa.hpp>
TEST(NumaTest, Basic) {
int num;

View File

@@ -2,10 +2,8 @@
# Licensed under the MIT License.
import argparse
import os
import json
from queue import Queue
import os
def parse_npkit_event_header(npkit_event_header_path):
@@ -118,7 +116,10 @@ def parse_gpu_event_file(npkit_dump_dir, npkit_event_def, rank, buf_idx, gpu_clo
)
event_type_to_seq[event_type] += 1
else:
gpu_events[-1]["args"] = {"size": parsed_gpu_event["size"], "rsvd": parsed_gpu_event["rsvd"]}
gpu_events[-1]["args"] = {
"size": parsed_gpu_event["size"],
"rsvd": parsed_gpu_event["rsvd"],
}
delta_time = gpu_events[-1]["ts"] - gpu_events[-2]["ts"]
gpu_events[-1]["args"]["bw (GB/s)"] = gpu_events[-1]["args"]["size"] / delta_time / 1e3
raw_content_idx += raw_event_size
@@ -238,7 +239,12 @@ def convert_npkit_dump_to_trace(npkit_dump_dir, output_dir, npkit_event_def):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--npkit_dump_dir", type=str, required=True, help="NPKit dump directory.")
parser.add_argument("--npkit_event_header_path", type=str, required=True, help="Path to npkit_event.h.")
parser.add_argument(
"--npkit_event_header_path",
type=str,
required=True,
help="Path to npkit_event.h.",
)
parser.add_argument("--output_dir", type=str, required=True, help="Path to output directory.")
args = parser.parse_args()