mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
Change device handle interfaces & others (#142)
* Changed device handle interfaces * Changed proxy service interfaces * Move device code into separate files * Fixed FIFO polling issues * Add configuration arguments in several interface functions --------- Co-authored-by: Changho Hwang <changhohwang@microsoft.com> Co-authored-by: Binyang Li <binyli@microsoft.com> Co-authored-by: root <root@a100-saemal0.qxveptpukjsuthqvv514inp03c.gx.internal.cloudapp.net>
This commit is contained in:
@@ -17,8 +17,8 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wall,-Wextra")
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake)
|
||||
|
||||
# clang-format targets
|
||||
include(${PROJECT_SOURCE_DIR}/cmake/AddClangFormatTargets.cmake)
|
||||
# Format targets
|
||||
include(${PROJECT_SOURCE_DIR}/cmake/AddFormatTargets.cmake)
|
||||
|
||||
# Options
|
||||
option(ENABLE_TRACE "Enable tracing" OFF)
|
||||
|
||||
@@ -67,7 +67,7 @@ mscclpp::Communicator comm(bootstrap);
|
||||
// Setup connections here using `comm`
|
||||
...
|
||||
// Construct the default proxy
|
||||
mscclpp::ProxyService proxyService(comm);
|
||||
mscclpp::ProxyService proxyService();
|
||||
// Start the proxy
|
||||
proxyService.startProxy();
|
||||
// Run the user application, i.e., launch GPU kernels here
|
||||
@@ -80,7 +80,7 @@ While the default implementation already enables any kinds of communication, MSC
|
||||
|
||||
```cpp
|
||||
// Proxy FIFO is obtained from mscclpp::Proxy on the host and copied to the device.
|
||||
__device__ mscclpp::DeviceProxyFifo fifo;
|
||||
__device__ mscclpp::FifoDeviceHandle fifo;
|
||||
__global__ void gpuKernel() {
|
||||
...
|
||||
// Only one thread is needed for the followings
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Add targets to run clang-format
|
||||
|
||||
find_program(CLANG_FORMAT clang-format)
|
||||
if(CLANG_FORMAT)
|
||||
message(STATUS "Found clang-format: ${CLANG_FORMAT}")
|
||||
set(FIND_DIRS ${PROJECT_SOURCE_DIR}/src ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/python ${PROJECT_SOURCE_DIR}/test)
|
||||
add_custom_target(check-format ALL
|
||||
COMMAND ${CLANG_FORMAT} -style=file --dry-run `find ${FIND_DIRS} -type f -name *.h -o -name *.hpp -o -name *.c -o -name *.cc -o -name *.cpp -o -name *.cu`
|
||||
)
|
||||
add_custom_target(format
|
||||
COMMAND ${CLANG_FORMAT} -style=file -i `find ${FIND_DIRS} -type f -name *.h -o -name *.hpp -o -name *.c -o -name *.cc -o -name *.cpp -o -name *.cu`
|
||||
)
|
||||
else()
|
||||
message(STATUS "clang-format not found.")
|
||||
endif()
|
||||
38
cmake/AddFormatTargets.cmake
Normal file
38
cmake/AddFormatTargets.cmake
Normal file
@@ -0,0 +1,38 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Add targets to run clang-format and black
|
||||
|
||||
add_custom_target(check-format)
|
||||
add_custom_target(format)
|
||||
|
||||
find_program(CLANG_FORMAT clang-format)
|
||||
if(CLANG_FORMAT)
|
||||
message(STATUS "Found clang-format: ${CLANG_FORMAT}")
|
||||
set(FIND_DIRS ${PROJECT_SOURCE_DIR}/src ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/python ${PROJECT_SOURCE_DIR}/test)
|
||||
add_custom_target(check-format-cpp ALL
|
||||
COMMAND ${CLANG_FORMAT} -style=file --dry-run `find ${FIND_DIRS} -type f -name *.h -o -name *.hpp -o -name *.c -o -name *.cc -o -name *.cpp -o -name *.cu`
|
||||
)
|
||||
add_dependencies(check-format check-format-cpp)
|
||||
add_custom_target(format-cpp
|
||||
COMMAND ${CLANG_FORMAT} -style=file -i `find ${FIND_DIRS} -type f -name *.h -o -name *.hpp -o -name *.c -o -name *.cc -o -name *.cpp -o -name *.cu`
|
||||
)
|
||||
add_dependencies(format format-cpp)
|
||||
else()
|
||||
message(STATUS "clang-format not found.")
|
||||
endif()
|
||||
|
||||
find_program(BLACK black)
|
||||
if (BLACK)
|
||||
message(STATUS "Found black: ${BLACK}")
|
||||
add_custom_target(check-format-py
|
||||
COMMAND ${BLACK} --config ${PROJECT_SOURCE_DIR}/pyproject.toml --check ${PROJECT_SOURCE_DIR}/python ${PROJECT_SOURCE_DIR}/test
|
||||
)
|
||||
add_dependencies(check-format check-format-py)
|
||||
add_custom_target(format-py
|
||||
COMMAND ${BLACK} --config ${PROJECT_SOURCE_DIR}/pyproject.toml ${PROJECT_SOURCE_DIR}/python ${PROJECT_SOURCE_DIR}/test
|
||||
)
|
||||
add_dependencies(format format-py)
|
||||
else()
|
||||
message(STATUS, "black not found.")
|
||||
endif()
|
||||
@@ -1,27 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#ifndef MSCCLPP_CONFIG_H_
|
||||
#define MSCCLPP_CONFIG_H_
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
class Config {
|
||||
public:
|
||||
int bootstrapConnectionTimeout = 30;
|
||||
|
||||
static Config* getInstance();
|
||||
int getBootstrapConnectionTimeoutConfig();
|
||||
void setBootstrapConnectionTimeoutConfig(int timeout);
|
||||
|
||||
private:
|
||||
Config() = default;
|
||||
Config(const Config&) = delete;
|
||||
Config& operator=(const Config&) = delete;
|
||||
|
||||
static Config instance_;
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // end include guard
|
||||
@@ -61,11 +61,13 @@ class TcpBootstrap : public Bootstrap {
|
||||
|
||||
/// Initialize the @ref TcpBootstrap with a given unique ID.
|
||||
/// @param uniqueId The unique ID to initialize the @ref TcpBootstrap with.
|
||||
void initialize(UniqueId uniqueId);
|
||||
/// @param timeoutSec The connection timeout in seconds.
|
||||
void initialize(UniqueId uniqueId, int64_t timeoutSec = 30);
|
||||
|
||||
/// Initialize the @ref TcpBootstrap with a string formatted as "ip:port" or "interface:ip:port".
|
||||
/// @param ifIpPortTrio The string formatted as "ip:port" or "interface:ip:port".
|
||||
void initialize(const std::string& ifIpPortTrio);
|
||||
/// @param timeoutSec The connection timeout in seconds.
|
||||
void initialize(const std::string& ifIpPortTrio, int64_t timeoutSec = 30);
|
||||
|
||||
/// Return the rank of the process.
|
||||
int getRank() override;
|
||||
@@ -384,7 +386,7 @@ class Connection {
|
||||
virtual void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) = 0;
|
||||
|
||||
/// Flush any pending writes to the remote process.
|
||||
virtual void flush() = 0;
|
||||
virtual void flush(int64_t timeoutUsec = 3e7) = 0;
|
||||
|
||||
/// Get the rank of the remote process.
|
||||
///
|
||||
@@ -533,8 +535,14 @@ class Communicator {
|
||||
/// @param remoteRank The rank of the remote process.
|
||||
/// @param tag The tag of the connection for identifying it.
|
||||
/// @param transport The type of transport to be used.
|
||||
/// @param ibMaxCqSize The maximum number of completion queue entries for IB. Unused if transport is not IB.
|
||||
/// @param ibMaxCqPollNum The maximum number of completion queue entries to poll for IB. Unused if transport is not
|
||||
/// IB.
|
||||
/// @param ibMaxSendWr The maximum number of outstanding send work requests for IB. Unused if transport is not IB.
|
||||
/// @param ibMaxWrPerSend The maximum number of work requests per send for IB. Unused if transport is not IB.
|
||||
/// @return std::shared_ptr<Connection> A shared pointer to the connection.
|
||||
std::shared_ptr<Connection> connectOnSetup(int remoteRank, int tag, Transport transport);
|
||||
std::shared_ptr<Connection> connectOnSetup(int remoteRank, int tag, Transport transport, int ibMaxCqSize = 1024,
|
||||
int ibMaxCqPollNum = 1, int ibMaxSendWr = 8192, int ibMaxWrPerSend = 64);
|
||||
|
||||
/// Add a custom Setuppable object to a list of objects to be setup later, when @ref setup() is called.
|
||||
///
|
||||
|
||||
@@ -7,90 +7,25 @@
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <mscclpp/fifo_device.hpp>
|
||||
#include <mscclpp/poll.hpp>
|
||||
|
||||
#define MSCCLPP_PROXY_FIFO_SIZE 128
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
/// A struct representing a pair of 64-bit unsigned integers used as a trigger for the proxy.
|
||||
///
|
||||
/// This struct is used as a work element in the concurrent FIFO where multiple device threads can push
|
||||
/// ProxyTrigger elements and a single host proxy thread consumes these work elements.
|
||||
///
|
||||
struct alignas(16) ProxyTrigger {
|
||||
uint64_t fst, snd;
|
||||
};
|
||||
|
||||
/// A concurrent FIFO where multiple device threads can push work elements and a single host proxy thread consumes them.
|
||||
///
|
||||
/// The FIFO has a head pointer allocated on the device which starts at 0 and goes up to 2^64-1, which is almost
|
||||
/// infinity. There are two copies of the tail, one on the device, @ref DeviceProxyFifo::tailReplica, and another on the
|
||||
/// host, namely, hostTail. The host always has the "true" tail and occasionally pushes it to the copy on the device.
|
||||
/// Therefore, most of the time, the device has a stale version. The invariants are: tailReplica <= hostTail <= head.
|
||||
/// The @ref push() function increments head, hostTail is updated in @ref HostProxyFifo::pop(), and it occasionally
|
||||
/// flushes it to tailReplica via @ref HostProxyFifo::flushTail().
|
||||
///
|
||||
/// Duplicating the tail is a good idea because the FIFO is large enough, and we do not need frequent updates for the
|
||||
/// tail as there is usually enough space for device threads to push their work into.
|
||||
///
|
||||
struct DeviceProxyFifo {
|
||||
#ifdef __CUDACC__
|
||||
/// Push a trigger to the FIFO.
|
||||
///
|
||||
/// @param trigger The trigger to push.
|
||||
/// @return The new head of the FIFO.
|
||||
__forceinline__ __device__ uint64_t push(ProxyTrigger trigger) {
|
||||
uint64_t curFifoHead = atomicAdd((unsigned long long int*)this->head, 1);
|
||||
|
||||
// Only one of two conditions need to be met to proceed. Either the tail has advanced enough or where we need to
|
||||
// write to is 0. However, the first condition is faster to check since the tail is flushed periodically anyways but
|
||||
// for the second condition we need to read CPU memory.
|
||||
// As volatile access is slow, we first check using the bare pointer and then use the volatile pointer if the
|
||||
// condition is not met.
|
||||
if (curFifoHead >= MSCCLPP_PROXY_FIFO_SIZE + *(this->tailReplica)) {
|
||||
OR_POLL_MAYBE_JAILBREAK(curFifoHead >= MSCCLPP_PROXY_FIFO_SIZE + *((volatile uint64_t*)this->tailReplica),
|
||||
*(volatile uint64_t*)&this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0,
|
||||
1000000);
|
||||
}
|
||||
|
||||
ProxyTrigger* triggerPtr = (ProxyTrigger*)&(this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE]);
|
||||
asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr), "l"(trigger.fst), "l"(trigger.snd));
|
||||
return curFifoHead;
|
||||
}
|
||||
|
||||
/// Wait until there is a place in the FIFO to push a trigger.
|
||||
///
|
||||
/// @param curFifoHead The current head of the FIFO.
|
||||
__forceinline__ __device__ void sync(uint64_t curFifoHead) {
|
||||
// Same as push but in this case checking the fist condition is probably faster since for tail to be pushed we need
|
||||
// to wait for cudaMemcpy to be done.
|
||||
OR_POLL_MAYBE_JAILBREAK(*(volatile uint64_t*)&(this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE]) != 0,
|
||||
*(volatile uint64_t*)(this->tailReplica) <= curFifoHead, 1000000);
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
|
||||
/// The FIFO buffer that is allocated on the host via `cudaHostAlloc()`.
|
||||
ProxyTrigger* triggers;
|
||||
/// Replica of the FIFO tail that is allocated on device.
|
||||
uint64_t* tailReplica;
|
||||
/// The FIFO head. Allocated on the device and only accessed by the device.
|
||||
uint64_t* head;
|
||||
};
|
||||
|
||||
/// A class representing a host proxy FIFO that can consume work elements pushed by device threads.
|
||||
class HostProxyFifo {
|
||||
class Fifo {
|
||||
public:
|
||||
/// Constructs a new @ref HostProxyFifo object.
|
||||
HostProxyFifo();
|
||||
/// Constructs a new @ref Fifo object.
|
||||
/// @param size The number of entires in the FIFO.
|
||||
Fifo(int size = 128);
|
||||
|
||||
/// Destroys the @ref HostProxyFifo object.
|
||||
~HostProxyFifo();
|
||||
/// Destroys the @ref Fifo object.
|
||||
~Fifo();
|
||||
|
||||
/// Polls the FIFO for a trigger.
|
||||
///
|
||||
/// @param trigger A pointer to the trigger to be filled.
|
||||
void poll(ProxyTrigger* trigger);
|
||||
/// Returns @ref ProxyTrigger which is the trigger at the head of fifo.
|
||||
ProxyTrigger poll();
|
||||
|
||||
/// Pops a trigger from the FIFO.
|
||||
void pop();
|
||||
@@ -100,10 +35,14 @@ class HostProxyFifo {
|
||||
/// @param sync If true, waits for the flush to complete before returning.
|
||||
void flushTail(bool sync = false);
|
||||
|
||||
/// Returns a @ref DeviceProxyFifo object representing the device FIFO.
|
||||
/// Return the FIFO size.
|
||||
/// @return The FIFO size.
|
||||
int size() const;
|
||||
|
||||
/// Returns a @ref FifoDeviceHandle object representing the device FIFO.
|
||||
///
|
||||
/// @return A @ref DeviceProxyFifo object representing the device FIFO.
|
||||
DeviceProxyFifo deviceFifo();
|
||||
/// @return A @ref FifoDeviceHandle object representing the device FIFO.
|
||||
FifoDeviceHandle deviceHandle();
|
||||
|
||||
private:
|
||||
struct Impl;
|
||||
|
||||
83
include/mscclpp/fifo_device.hpp
Normal file
83
include/mscclpp/fifo_device.hpp
Normal file
@@ -0,0 +1,83 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#ifndef MSCCLPP_FIFO_DEVICE_HPP_
|
||||
#define MSCCLPP_FIFO_DEVICE_HPP_
|
||||
|
||||
#include "poll.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
/// A struct representing a pair of 64-bit unsigned integers used as a trigger for the proxy.
|
||||
///
|
||||
/// This struct is used as a work element in the concurrent FIFO where multiple device threads can push
|
||||
/// ProxyTrigger elements and a single host proxy thread consumes these work elements.
|
||||
///
|
||||
/// Do not use the most significant bit of @ref snd as it is reserved for memory consistency purposes
|
||||
struct alignas(16) ProxyTrigger {
|
||||
uint64_t fst, snd;
|
||||
};
|
||||
|
||||
/// A concurrent FIFO where multiple device threads can push work elements and a single host proxy thread consumes them.
|
||||
///
|
||||
/// The FIFO has a head pointer allocated on the device which starts at 0 and goes up to 2^64-1, which is almost
|
||||
/// infinity. There are two copies of the tail, one on the device, @ref FifoDeviceHandle::tailReplica, and another on
|
||||
/// the host, namely, hostTail. The host always has the "true" tail and occasionally pushes it to the copy on the
|
||||
/// device. Therefore, most of the time, the device has a stale version. The invariants are: tailReplica <= hostTail <=
|
||||
/// head. The @ref push() function increments head, hostTail is updated in @ref Fifo::pop(), and it occasionally flushes
|
||||
/// it to tailReplica via @ref Fifo::flushTail().
|
||||
///
|
||||
/// Duplicating the tail is a good idea because the FIFO is large enough, and we do not need frequent updates for the
|
||||
/// tail as there is usually enough space for device threads to push their work into.
|
||||
///
|
||||
struct FifoDeviceHandle {
|
||||
#ifdef __CUDACC__
|
||||
/// Push a trigger to the FIFO.
|
||||
///
|
||||
/// @param trigger The trigger to push.
|
||||
/// @return The new head of the FIFO.
|
||||
__forceinline__ __device__ uint64_t push(ProxyTrigger trigger) {
|
||||
uint64_t curFifoHead = atomicAdd((unsigned long long int*)this->head, 1);
|
||||
// make the last bit intentionally non-zero so that we can safely poll. Don't worry, we will change it back in host
|
||||
// side
|
||||
trigger.snd ^= ((uint64_t)1 << (uint64_t)63);
|
||||
|
||||
// Only one of two conditions need to be met to proceed. Either the tail has advanced enough or where we need to
|
||||
// write to is 0. However, the first condition is faster to check since the tail is flushed periodically anyways but
|
||||
// for the second condition we need to read CPU memory.
|
||||
// As volatile access is slow, we first check using the bare pointer and then use the volatile pointer if the
|
||||
// condition is not met.
|
||||
if (curFifoHead >= size + *(this->tailReplica)) {
|
||||
OR_POLL_MAYBE_JAILBREAK(curFifoHead >= size + *((volatile uint64_t*)this->tailReplica),
|
||||
*(volatile uint64_t*)&this->triggers[curFifoHead % size] != 0, 1000000);
|
||||
}
|
||||
|
||||
ProxyTrigger* triggerPtr = (ProxyTrigger*)&(this->triggers[curFifoHead % size]);
|
||||
asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr), "l"(trigger.fst), "l"(trigger.snd));
|
||||
return curFifoHead;
|
||||
}
|
||||
|
||||
/// Wait until there is a place in the FIFO to push a trigger.
|
||||
///
|
||||
/// @param curFifoHead The current head of the FIFO.
|
||||
__forceinline__ __device__ void sync(uint64_t curFifoHead) {
|
||||
// Same as push but in this case checking the fist condition is probably faster since for tail to be pushed we need
|
||||
// to wait for cudaMemcpy to be done.
|
||||
OR_POLL_MAYBE_JAILBREAK(*(volatile uint64_t*)&(this->triggers[curFifoHead % size]) != 0,
|
||||
*(volatile uint64_t*)(this->tailReplica) <= curFifoHead, 1000000);
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
|
||||
/// The FIFO buffer that is allocated on the host via `cudaHostAlloc()`.
|
||||
ProxyTrigger* triggers;
|
||||
/// Replica of the FIFO tail that is allocated on device.
|
||||
uint64_t* tailReplica;
|
||||
/// The FIFO head. Allocated on the device and only accessed by the device.
|
||||
uint64_t* head;
|
||||
/// The FIFO size.
|
||||
int size;
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_FIFO_DEVICE_HPP_
|
||||
@@ -4,6 +4,8 @@
|
||||
#ifndef MSCCLPP_PACKET_HPP_
|
||||
#define MSCCLPP_PACKET_HPP_
|
||||
|
||||
#include "poll.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
/// LL (low latency) protocol packet.
|
||||
@@ -42,17 +44,24 @@ union LLPacket {
|
||||
"r"((uint32_t)(val >> 32)), "r"(flag));
|
||||
}
|
||||
|
||||
/// Helper of @ref read().
|
||||
/// @param flag The flag to read.
|
||||
/// @param data The 8-byte data read.
|
||||
/// @return True if the flag is not equal to the given flag.
|
||||
__forceinline__ __device__ bool readOnce(uint32_t flag, uint2& data) {
|
||||
uint32_t flag1, flag2;
|
||||
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
|
||||
: "=r"(data.x), "=r"(flag1), "=r"(data.y), "=r"(flag2)
|
||||
: "l"(v));
|
||||
return (flag1 != flag) || (flag2 != flag);
|
||||
}
|
||||
|
||||
/// Read 8 bytes of data from the packet.
|
||||
/// @param flag The flag to read.
|
||||
/// @return The 8-byte data read.
|
||||
__forceinline__ __device__ uint2 read(uint32_t flag) {
|
||||
uint2 data;
|
||||
uint32_t flag1, flag2;
|
||||
do {
|
||||
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
|
||||
: "=r"(data.x), "=r"(flag1), "=r"(data.y), "=r"(flag2)
|
||||
: "l"(v));
|
||||
} while ((flag1 != flag) || (flag2 != flag));
|
||||
POLL_MAYBE_JAILBREAK(readOnce(flag, data), 100000000);
|
||||
return data;
|
||||
}
|
||||
|
||||
@@ -80,6 +89,7 @@ __forceinline__ __device__ void putPackets(void* dst, uint64_t dstOffset, void*
|
||||
__forceinline__ __device__ void getPackets(void* dst, uint64_t dstOffset, void* src, uint64_t srcOffset,
|
||||
uint64_t dstBytes, uint32_t threadId, uint32_t numThreads, uint32_t flag) {
|
||||
// Offsets should be aligned to 8 bytes & size should be a multiple of 8 bytes
|
||||
// TODO(saemal): this is not matching sm_channel get method.
|
||||
LLPacket* srcBase = (LLPacket*)((char*)src + srcOffset);
|
||||
uint2* dstBase = (uint2*)((char*)dst + dstOffset);
|
||||
size_t nElem = dstBytes / sizeof(uint2);
|
||||
|
||||
@@ -6,14 +6,8 @@
|
||||
|
||||
#ifdef __CUDACC__
|
||||
|
||||
#ifndef NDEBUG
|
||||
// TODO(chhwang): https://github.com/microsoft/mscclpp/issues/99
|
||||
#define POLL_PRINT_ON_STUCK(__cond)
|
||||
// #include <stdio.h>
|
||||
// #define POLL_PRINT_ON_STUCK(__cond) do { printf("mscclpp: spin is stuck. condition: " #__cond "\n"); } while (0);
|
||||
#else // NDEBUG
|
||||
#define POLL_PRINT_ON_STUCK(__cond)
|
||||
#endif // NDEBUG
|
||||
extern __device__ void __assert_fail(const char *__assertion, const char *__file, unsigned int __line,
|
||||
const char *__function) __THROW;
|
||||
|
||||
// If a spin is stuck, escape from it and set status to 1.
|
||||
#define POLL_MAYBE_JAILBREAK_ESCAPE(__cond, __max_spin_cnt, __status) \
|
||||
@@ -22,7 +16,6 @@
|
||||
__status = 0; \
|
||||
while (__cond) { \
|
||||
if (__spin_cnt++ == __max_spin_cnt) { \
|
||||
POLL_PRINT_ON_STUCK(__cond); \
|
||||
__status = 1; \
|
||||
break; \
|
||||
} \
|
||||
@@ -30,31 +23,31 @@
|
||||
} while (0);
|
||||
|
||||
// If a spin is stuck, print a warning and keep spinning.
|
||||
#define POLL_MAYBE_JAILBREAK(__cond, __max_spin_cnt) \
|
||||
do { \
|
||||
uint64_t __spin_cnt = 0; \
|
||||
while (__cond) { \
|
||||
if (__spin_cnt++ == __max_spin_cnt) { \
|
||||
POLL_PRINT_ON_STUCK(__cond); \
|
||||
} \
|
||||
} \
|
||||
#define POLL_MAYBE_JAILBREAK(__cond, __max_spin_cnt) \
|
||||
do { \
|
||||
uint64_t __spin_cnt = 0; \
|
||||
while (__cond) { \
|
||||
if (__spin_cnt++ == __max_spin_cnt) { \
|
||||
__assert_fail(#__cond, __FILE__, __LINE__, __PRETTY_FUNCTION__); \
|
||||
} \
|
||||
} \
|
||||
} while (0);
|
||||
|
||||
// the as POLL_MAYBE_JAILBREAK except that __cond1 is checked before __cond2
|
||||
// this is specially useful when __cond1 is faster to check
|
||||
#define OR_POLL_MAYBE_JAILBREAK(__cond1, __cond2, __max_spin_cnt) \
|
||||
do { \
|
||||
uint64_t __spin_cnt = 0; \
|
||||
while (true) { \
|
||||
if (!(__cond1)) { \
|
||||
break; \
|
||||
} else if (!(__cond2)) { \
|
||||
break; \
|
||||
} \
|
||||
if (__spin_cnt++ == __max_spin_cnt) { \
|
||||
POLL_PRINT_ON_STUCK(__cond); \
|
||||
} \
|
||||
} \
|
||||
#define OR_POLL_MAYBE_JAILBREAK(__cond1, __cond2, __max_spin_cnt) \
|
||||
do { \
|
||||
uint64_t __spin_cnt = 0; \
|
||||
while (true) { \
|
||||
if (!(__cond1)) { \
|
||||
break; \
|
||||
} else if (!(__cond2)) { \
|
||||
break; \
|
||||
} \
|
||||
if (__spin_cnt++ == __max_spin_cnt) { \
|
||||
__assert_fail(#__cond1 #__cond2, __FILE__, __LINE__, __PRETTY_FUNCTION__); \
|
||||
} \
|
||||
} \
|
||||
} while (0);
|
||||
|
||||
#endif // __CUDACC__
|
||||
|
||||
@@ -28,7 +28,7 @@ class Proxy {
|
||||
void start();
|
||||
void stop();
|
||||
|
||||
HostProxyFifo& fifo();
|
||||
Fifo& fifo();
|
||||
|
||||
private:
|
||||
struct Impl;
|
||||
|
||||
@@ -5,18 +5,12 @@
|
||||
#define MSCCLPP_PROXY_CHANNEL_HPP_
|
||||
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/fifo.hpp>
|
||||
#include <mscclpp/proxy.hpp>
|
||||
#include <mscclpp/proxy_channel_device.hpp>
|
||||
#include <mscclpp/semaphore.hpp>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
using SemaphoreId = uint32_t;
|
||||
|
||||
/// Numeric ID of @ref RegisteredMemory. @ref ProxyService has an internal array indexed by these handles mapping to the
|
||||
/// actual.
|
||||
using MemoryId = uint32_t;
|
||||
|
||||
struct ProxyChannel;
|
||||
|
||||
/// Base class for proxy services. Proxy services are used to proxy data between devices.
|
||||
@@ -32,13 +26,17 @@ class BaseProxyService {
|
||||
class ProxyService : public BaseProxyService {
|
||||
public:
|
||||
/// Constructor.
|
||||
/// @param communicator The communicator to use.
|
||||
ProxyService(Communicator& communicator);
|
||||
ProxyService();
|
||||
|
||||
/// Add a semaphore to the proxy service.
|
||||
/// Build and add a semaphore to the proxy service.
|
||||
/// @param connection The connection associated with the semaphore.
|
||||
/// @return The ID of the semaphore.
|
||||
SemaphoreId addSemaphore(std::shared_ptr<Connection> connection);
|
||||
SemaphoreId buildAndAddSemaphore(Communicator& communicator, std::shared_ptr<Connection> connection);
|
||||
|
||||
/// Add a semaphore to the proxy service.
|
||||
/// @param semaphore The semaphore to be added
|
||||
/// @return The ID of the semaphore.
|
||||
SemaphoreId addSemaphore(std::shared_ptr<Host2DeviceSemaphore> semaphore);
|
||||
|
||||
/// Register a memory region with the proxy service.
|
||||
/// @param memory The memory region to register.
|
||||
@@ -53,7 +51,7 @@ class ProxyService : public BaseProxyService {
|
||||
/// Get a proxy channel by semaphore ID.
|
||||
/// @param id The ID of the semaphore.
|
||||
/// @return The proxy channel.
|
||||
ProxyChannel deviceChannel(SemaphoreId id);
|
||||
ProxyChannel proxyChannel(SemaphoreId id);
|
||||
|
||||
/// Start the proxy service.
|
||||
void startProxy();
|
||||
@@ -62,7 +60,6 @@ class ProxyService : public BaseProxyService {
|
||||
void stopProxy();
|
||||
|
||||
private:
|
||||
Communicator& communicator_;
|
||||
std::vector<std::shared_ptr<Host2DeviceSemaphore>> semaphores_;
|
||||
std::vector<RegisteredMemory> memories_;
|
||||
Proxy proxy_;
|
||||
@@ -73,170 +70,44 @@ class ProxyService : public BaseProxyService {
|
||||
ProxyHandlerResult handleTrigger(ProxyTrigger triggerRaw);
|
||||
};
|
||||
|
||||
using TriggerType = uint64_t;
|
||||
const TriggerType TriggerData = 0x1; // Trigger a data transfer.
|
||||
const TriggerType TriggerFlag = 0x2; // Trigger a signaling.
|
||||
const TriggerType TriggerSync = 0x4; // Trigger a flush.
|
||||
|
||||
#define MSCCLPP_BITS_SIZE 32
|
||||
#define MSCCLPP_BITS_OFFSET 32
|
||||
#define MSCCLPP_BITS_REGMEM_HANDLE 8
|
||||
#define MSCCLPP_BITS_TYPE 3
|
||||
#define MSCCLPP_BITS_CONNID 10
|
||||
|
||||
/// Basic structure of each work element in the FIFO.
|
||||
union ChannelTrigger {
|
||||
ProxyTrigger value;
|
||||
// The summation of number of bits must be 128 or less.
|
||||
struct {
|
||||
// First 64 bits: value[0]
|
||||
uint64_t size : MSCCLPP_BITS_SIZE;
|
||||
uint64_t srcOffset : MSCCLPP_BITS_OFFSET;
|
||||
uint64_t : (64 - MSCCLPP_BITS_SIZE - MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment
|
||||
// Second 64 bits: value[1]
|
||||
uint64_t dstOffset : MSCCLPP_BITS_OFFSET;
|
||||
uint64_t srcMemoryId : MSCCLPP_BITS_REGMEM_HANDLE;
|
||||
uint64_t dstMemoryId : MSCCLPP_BITS_REGMEM_HANDLE;
|
||||
uint64_t type : MSCCLPP_BITS_TYPE;
|
||||
uint64_t chanId : MSCCLPP_BITS_CONNID;
|
||||
uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_REGMEM_HANDLE - MSCCLPP_BITS_REGMEM_HANDLE -
|
||||
MSCCLPP_BITS_TYPE); // ensure 64-bit alignment
|
||||
} fields;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
/// Default constructor.
|
||||
__device__ ChannelTrigger() {}
|
||||
|
||||
/// Copy constructor.
|
||||
__device__ ChannelTrigger(ProxyTrigger value) : value(value) {}
|
||||
|
||||
/// Constructor.
|
||||
/// @param type The type of the trigger.
|
||||
/// @param dst The destination memory region.
|
||||
/// @param dstOffset The offset into the destination memory region.
|
||||
/// @param src The source memory region.
|
||||
/// @param srcOffset The offset into the source memory region.
|
||||
/// @param bytes The bytes of the transfer.
|
||||
/// @param semaphoreId The ID of the semaphore.
|
||||
__device__ ChannelTrigger(TriggerType type, MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
|
||||
uint64_t bytes, int semaphoreId) {
|
||||
value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + bytes);
|
||||
value.snd = ((((((((semaphoreId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_REGMEM_HANDLE) + dst)
|
||||
<< MSCCLPP_BITS_REGMEM_HANDLE) +
|
||||
src)
|
||||
<< MSCCLPP_BITS_OFFSET) +
|
||||
dstOffset);
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
};
|
||||
|
||||
/// Proxy channel.
|
||||
struct ProxyChannel {
|
||||
// Use DeviceHandle<ProxyChannel> in device code.
|
||||
typedef ProxyChannel DeviceHandle;
|
||||
|
||||
ProxyChannel() = default;
|
||||
|
||||
ProxyChannel(SemaphoreId semaphoreId, Host2DeviceSemaphore::DeviceHandle semaphore, DeviceProxyFifo fifo);
|
||||
|
||||
ProxyChannel(const ProxyChannel& other) = default;
|
||||
|
||||
ProxyChannel& operator=(ProxyChannel& other) = default;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
/// Push a @ref TriggerData to the FIFO.
|
||||
/// @param dst The destination memory region.
|
||||
/// @param dstOffset The offset into the destination memory region.
|
||||
/// @param src The source memory region.
|
||||
/// @param srcOffset The offset into the source memory region.
|
||||
/// @param size The size of the transfer.
|
||||
__forceinline__ __device__ void put(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
|
||||
uint64_t size) {
|
||||
fifo_.push(ChannelTrigger(TriggerData, dst, dstOffset, src, srcOffset, size, semaphoreId_).value);
|
||||
}
|
||||
|
||||
/// Push a @ref TriggerData to the FIFO.
|
||||
/// @param dst The destination memory region.
|
||||
/// @param src The source memory region.
|
||||
/// @param offset The common offset into the destination and source memory regions.
|
||||
/// @param size The size of the transfer.
|
||||
__forceinline__ __device__ void put(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) {
|
||||
put(dst, offset, src, offset, size);
|
||||
}
|
||||
|
||||
/// Push a @ref TriggerFlag to the FIFO.
|
||||
__forceinline__ __device__ void signal() {
|
||||
fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, semaphoreId_).value);
|
||||
}
|
||||
|
||||
/// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO.
|
||||
/// @param dst The destination memory region.
|
||||
/// @param dstOffset The offset into the destination memory region.
|
||||
/// @param src The source memory region.
|
||||
/// @param srcOffset The offset into the source memory region.
|
||||
/// @param size The size of the transfer.
|
||||
__forceinline__ __device__ void putWithSignal(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
|
||||
uint64_t size) {
|
||||
fifo_.push(ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, size, semaphoreId_).value);
|
||||
}
|
||||
|
||||
/// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO.
|
||||
/// @param dst The destination memory region.
|
||||
/// @param src The source memory region.
|
||||
/// @param offset The common offset into the destination and source memory regions.
|
||||
/// @param size The size of the transfer.
|
||||
__forceinline__ __device__ void putWithSignal(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) {
|
||||
putWithSignal(dst, offset, src, offset, size);
|
||||
}
|
||||
|
||||
/// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO.
|
||||
/// @param dst The destination memory region.
|
||||
/// @param dstOffset The offset into the destination memory region.
|
||||
/// @param src The source memory region.
|
||||
/// @param srcOffset The offset into the source memory region.
|
||||
/// @param size The size of the transfer.
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, uint64_t dstOffset, MemoryId src,
|
||||
uint64_t srcOffset, uint64_t size) {
|
||||
uint64_t curFifoHead = fifo_.push(
|
||||
ChannelTrigger(TriggerData | TriggerFlag | TriggerSync, dst, dstOffset, src, srcOffset, size, semaphoreId_)
|
||||
.value);
|
||||
fifo_.sync(curFifoHead);
|
||||
}
|
||||
|
||||
/// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO.
|
||||
/// @param dst The destination memory region.
|
||||
/// @param src The source memory region.
|
||||
/// @param offset The common offset into the destination and source memory regions.
|
||||
/// @param size The size of the transfer.
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) {
|
||||
putWithSignalAndFlush(dst, offset, src, offset, size);
|
||||
}
|
||||
|
||||
/// Push a @ref TriggerSync to the FIFO.
|
||||
__forceinline__ __device__ void flush() {
|
||||
uint64_t curFifoHead = fifo_.push(ChannelTrigger(TriggerSync, 0, 0, 0, 0, 1, semaphoreId_).value);
|
||||
fifo_.sync(curFifoHead);
|
||||
}
|
||||
|
||||
/// Wait for the proxy channel to be signaled.
|
||||
__forceinline__ __device__ void wait() { semaphore_.wait(); }
|
||||
|
||||
#endif // __CUDACC__
|
||||
|
||||
private:
|
||||
SemaphoreId semaphoreId_;
|
||||
|
||||
Host2DeviceSemaphore::DeviceHandle semaphore_;
|
||||
|
||||
// this is a concurrent fifo which is multiple threads from the device
|
||||
// can produce for and the sole proxy thread consumes it.
|
||||
DeviceProxyFifo fifo_;
|
||||
FifoDeviceHandle fifo_;
|
||||
|
||||
public:
|
||||
ProxyChannel() = default;
|
||||
|
||||
ProxyChannel(SemaphoreId semaphoreId, Host2DeviceSemaphore::DeviceHandle semaphore, FifoDeviceHandle fifo);
|
||||
|
||||
ProxyChannel(const ProxyChannel& other) = default;
|
||||
|
||||
ProxyChannel& operator=(ProxyChannel& other) = default;
|
||||
|
||||
/// Device-side handle for @ref ProxyChannel.
|
||||
using DeviceHandle = ProxyChannelDeviceHandle;
|
||||
|
||||
/// Returns the device-side handle.
|
||||
///
|
||||
/// User should make sure the ProxyChannel is not released when using the returned handle.
|
||||
///
|
||||
DeviceHandle deviceHandle() const;
|
||||
};
|
||||
|
||||
/// Simple proxy channel with a single destination and source memory region.
|
||||
struct SimpleProxyChannel {
|
||||
// Use DeviceHandle<SimpleProxyChannel> in device code.
|
||||
typedef SimpleProxyChannel DeviceHandle;
|
||||
private:
|
||||
ProxyChannel proxyChan_;
|
||||
MemoryId dst_;
|
||||
MemoryId src_;
|
||||
|
||||
public:
|
||||
/// Default constructor.
|
||||
SimpleProxyChannel() = default;
|
||||
|
||||
@@ -256,69 +127,16 @@ struct SimpleProxyChannel {
|
||||
/// Assignment operator.
|
||||
SimpleProxyChannel& operator=(SimpleProxyChannel& other) = default;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
/// Push a @ref TriggerData to the FIFO.
|
||||
/// @param dstOffset The offset into the destination memory region.
|
||||
/// @param srcOffset The offset into the source memory region.
|
||||
/// @param size The size of the transfer.
|
||||
__forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
|
||||
proxyChan_.put(dst_, dstOffset, src_, srcOffset, size);
|
||||
}
|
||||
/// Device-side handle for @ref SimpleProxyChannel.
|
||||
using DeviceHandle = SimpleProxyChannelDeviceHandle;
|
||||
|
||||
/// Push a @ref TriggerData to the FIFO.
|
||||
/// @param offset The common offset into the destination and source memory regions.
|
||||
/// @param size The size of the transfer.
|
||||
__forceinline__ __device__ void put(uint64_t offset, uint64_t size) { put(offset, offset, size); }
|
||||
|
||||
/// Push a @ref TriggerFlag to the FIFO.
|
||||
__forceinline__ __device__ void signal() { proxyChan_.signal(); }
|
||||
|
||||
/// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO.
|
||||
/// @param dstOffset The offset into the destination memory region.
|
||||
/// @param srcOffset The offset into the source memory region.
|
||||
/// @param size The size of the transfer.
|
||||
__forceinline__ __device__ void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
|
||||
proxyChan_.putWithSignal(dst_, dstOffset, src_, srcOffset, size);
|
||||
}
|
||||
|
||||
/// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO.
|
||||
/// @param offset The common offset into the destination and source memory regions.
|
||||
/// @param size The size of the transfer.
|
||||
__forceinline__ __device__ void putWithSignal(uint64_t offset, uint64_t size) { putWithSignal(offset, offset, size); }
|
||||
|
||||
/// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO.
|
||||
/// @param dstOffset The offset into the destination memory region.
|
||||
/// @param srcOffset The offset into the source memory region.
|
||||
/// @param size The size of the transfer.
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
|
||||
proxyChan_.putWithSignalAndFlush(dst_, dstOffset, src_, srcOffset, size);
|
||||
}
|
||||
|
||||
/// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO.
|
||||
/// @param offset The common offset into the destination and source memory regions.
|
||||
/// @param size The size of the transfer.
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t offset, uint64_t size) {
|
||||
putWithSignalAndFlush(offset, offset, size);
|
||||
}
|
||||
|
||||
/// Push a @ref TriggerSync to the FIFO.
|
||||
__forceinline__ __device__ void flush() { proxyChan_.flush(); }
|
||||
|
||||
/// Wait for the proxy channel to be signaled.
|
||||
__forceinline__ __device__ void wait() { proxyChan_.wait(); }
|
||||
|
||||
#endif // __CUDACC__
|
||||
|
||||
ProxyChannel proxyChan_;
|
||||
MemoryId dst_;
|
||||
MemoryId src_;
|
||||
/// Returns the device-side handle.
|
||||
///
|
||||
/// User should make sure the SimpleProxyChannel is not released when using the returned handle.
|
||||
///
|
||||
DeviceHandle deviceHandle() const;
|
||||
};
|
||||
|
||||
template <>
|
||||
DeviceHandle<ProxyChannel> deviceHandle(ProxyChannel&& proxyChannel);
|
||||
|
||||
template <>
|
||||
DeviceHandle<SimpleProxyChannel> deviceHandle(SimpleProxyChannel&& simpleProxyChannel);
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_PROXY_CHANNEL_HPP_
|
||||
|
||||
227
include/mscclpp/proxy_channel_device.hpp
Normal file
227
include/mscclpp/proxy_channel_device.hpp
Normal file
@@ -0,0 +1,227 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#ifndef MSCCLPP_PROXY_CHANNEL_DEVICE_HPP_
|
||||
#define MSCCLPP_PROXY_CHANNEL_DEVICE_HPP_
|
||||
|
||||
#include "fifo_device.hpp"
|
||||
#include "semaphore_device.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
using SemaphoreId = uint32_t;
|
||||
|
||||
/// Numeric ID of @ref RegisteredMemory. @ref ProxyService has an internal array indexed by these handles mapping to the
|
||||
/// actual.
|
||||
using MemoryId = uint32_t;
|
||||
|
||||
using TriggerType = uint64_t;
|
||||
const TriggerType TriggerData = 0x1; // Trigger a data transfer.
|
||||
const TriggerType TriggerFlag = 0x2; // Trigger a signaling.
|
||||
const TriggerType TriggerSync = 0x4; // Trigger a flush.
|
||||
|
||||
#define MSCCLPP_BITS_SIZE 32
|
||||
#define MSCCLPP_BITS_OFFSET 32
|
||||
#define MSCCLPP_BITS_REGMEM_HANDLE 8
|
||||
#define MSCCLPP_BITS_TYPE 3
|
||||
#define MSCCLPP_BITS_CONNID 10
|
||||
#define MSCCLPP_BITS_FIFO_RESERVED 1
|
||||
|
||||
/// Basic structure of each work element in the FIFO.
|
||||
union ChannelTrigger {
|
||||
ProxyTrigger value;
|
||||
// The summation of number of bits must be 128 or less.
|
||||
struct {
|
||||
// First 64 bits: value[0]
|
||||
uint64_t size : MSCCLPP_BITS_SIZE;
|
||||
uint64_t srcOffset : MSCCLPP_BITS_OFFSET;
|
||||
uint64_t : (64 - MSCCLPP_BITS_SIZE - MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment
|
||||
// Second 64 bits: value[1]
|
||||
uint64_t dstOffset : MSCCLPP_BITS_OFFSET;
|
||||
uint64_t srcMemoryId : MSCCLPP_BITS_REGMEM_HANDLE;
|
||||
uint64_t dstMemoryId : MSCCLPP_BITS_REGMEM_HANDLE;
|
||||
uint64_t type : MSCCLPP_BITS_TYPE;
|
||||
uint64_t chanId : MSCCLPP_BITS_CONNID;
|
||||
uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_REGMEM_HANDLE - MSCCLPP_BITS_REGMEM_HANDLE - MSCCLPP_BITS_TYPE -
|
||||
MSCCLPP_BITS_CONNID - MSCCLPP_BITS_FIFO_RESERVED); // ensure 64-bit alignment
|
||||
uint64_t reserved : MSCCLPP_BITS_FIFO_RESERVED;
|
||||
} fields;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
/// Default constructor.
|
||||
__forceinline__ __device__ ChannelTrigger() {}
|
||||
|
||||
/// Copy constructor.
|
||||
__forceinline__ __device__ ChannelTrigger(ProxyTrigger value) : value(value) {}
|
||||
|
||||
/// Constructor.
|
||||
/// @param type The type of the trigger.
|
||||
/// @param dst The destination memory region.
|
||||
/// @param dstOffset The offset into the destination memory region.
|
||||
/// @param src The source memory region.
|
||||
/// @param srcOffset The offset into the source memory region.
|
||||
/// @param bytes The bytes of the transfer.
|
||||
/// @param semaphoreId The ID of the semaphore.
|
||||
__forceinline__ __device__ ChannelTrigger(TriggerType type, MemoryId dst, uint64_t dstOffset, MemoryId src,
|
||||
uint64_t srcOffset, uint64_t bytes, int semaphoreId) {
|
||||
value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + bytes);
|
||||
value.snd = ((((((((semaphoreId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_REGMEM_HANDLE) + dst)
|
||||
<< MSCCLPP_BITS_REGMEM_HANDLE) +
|
||||
src)
|
||||
<< MSCCLPP_BITS_OFFSET) +
|
||||
dstOffset);
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
};
|
||||
|
||||
struct ProxyChannelDeviceHandle {
|
||||
SemaphoreId semaphoreId_;
|
||||
|
||||
Host2DeviceSemaphoreDeviceHandle semaphore_;
|
||||
|
||||
// this is a concurrent fifo which is multiple threads from the device
|
||||
// can produce for and the sole proxy thread consumes it.
|
||||
FifoDeviceHandle fifo_;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
/// Push a @ref TriggerData to the FIFO.
|
||||
/// @param dst The destination memory region.
|
||||
/// @param dstOffset The offset into the destination memory region.
|
||||
/// @param src The source memory region.
|
||||
/// @param srcOffset The offset into the source memory region.
|
||||
/// @param size The size of the transfer.
|
||||
__forceinline__ __device__ void put(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
|
||||
uint64_t size) {
|
||||
fifo_.push(ChannelTrigger(TriggerData, dst, dstOffset, src, srcOffset, size, semaphoreId_).value);
|
||||
}
|
||||
|
||||
/// Push a @ref TriggerData to the FIFO.
|
||||
/// @param dst The destination memory region.
|
||||
/// @param src The source memory region.
|
||||
/// @param offset The common offset into the destination and source memory regions.
|
||||
/// @param size The size of the transfer.
|
||||
__forceinline__ __device__ void put(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) {
|
||||
put(dst, offset, src, offset, size);
|
||||
}
|
||||
|
||||
/// Push a @ref TriggerFlag to the FIFO.
|
||||
__forceinline__ __device__ void signal() {
|
||||
fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, semaphoreId_).value);
|
||||
}
|
||||
|
||||
/// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO.
|
||||
/// @param dst The destination memory region.
|
||||
/// @param dstOffset The offset into the destination memory region.
|
||||
/// @param src The source memory region.
|
||||
/// @param srcOffset The offset into the source memory region.
|
||||
/// @param size The size of the transfer.
|
||||
__forceinline__ __device__ void putWithSignal(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
|
||||
uint64_t size) {
|
||||
fifo_.push(ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, size, semaphoreId_).value);
|
||||
}
|
||||
|
||||
/// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO.
|
||||
/// @param dst The destination memory region.
|
||||
/// @param src The source memory region.
|
||||
/// @param offset The common offset into the destination and source memory regions.
|
||||
/// @param size The size of the transfer.
|
||||
__forceinline__ __device__ void putWithSignal(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) {
|
||||
putWithSignal(dst, offset, src, offset, size);
|
||||
}
|
||||
|
||||
/// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO.
|
||||
/// @param dst The destination memory region.
|
||||
/// @param dstOffset The offset into the destination memory region.
|
||||
/// @param src The source memory region.
|
||||
/// @param srcOffset The offset into the source memory region.
|
||||
/// @param size The size of the transfer.
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, uint64_t dstOffset, MemoryId src,
|
||||
uint64_t srcOffset, uint64_t size) {
|
||||
uint64_t curFifoHead = fifo_.push(
|
||||
ChannelTrigger(TriggerData | TriggerFlag | TriggerSync, dst, dstOffset, src, srcOffset, size, semaphoreId_)
|
||||
.value);
|
||||
fifo_.sync(curFifoHead);
|
||||
}
|
||||
|
||||
/// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO.
|
||||
/// @param dst The destination memory region.
|
||||
/// @param src The source memory region.
|
||||
/// @param offset The common offset into the destination and source memory regions.
|
||||
/// @param size The size of the transfer.
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) {
|
||||
putWithSignalAndFlush(dst, offset, src, offset, size);
|
||||
}
|
||||
|
||||
/// Push a @ref TriggerSync to the FIFO.
|
||||
__forceinline__ __device__ void flush() {
|
||||
uint64_t curFifoHead = fifo_.push(ChannelTrigger(TriggerSync, 0, 0, 0, 0, 1, semaphoreId_).value);
|
||||
fifo_.sync(curFifoHead);
|
||||
}
|
||||
|
||||
/// Wait for the proxy channel to be signaled.
|
||||
__forceinline__ __device__ void wait() { semaphore_.wait(); }
|
||||
|
||||
#endif // __CUDACC__
|
||||
};
|
||||
|
||||
struct SimpleProxyChannelDeviceHandle {
|
||||
ProxyChannelDeviceHandle proxyChan_;
|
||||
MemoryId dst_;
|
||||
MemoryId src_;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
/// Push a @ref TriggerData to the FIFO.
|
||||
/// @param dstOffset The offset into the destination memory region.
|
||||
/// @param srcOffset The offset into the source memory region.
|
||||
/// @param size The size of the transfer.
|
||||
__forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
|
||||
proxyChan_.put(dst_, dstOffset, src_, srcOffset, size);
|
||||
}
|
||||
|
||||
/// Push a @ref TriggerData to the FIFO.
|
||||
/// @param offset The common offset into the destination and source memory regions.
|
||||
/// @param size The size of the transfer.
|
||||
__forceinline__ __device__ void put(uint64_t offset, uint64_t size) { put(offset, offset, size); }
|
||||
|
||||
/// Push a @ref TriggerFlag to the FIFO.
|
||||
__forceinline__ __device__ void signal() { proxyChan_.signal(); }
|
||||
|
||||
/// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO.
|
||||
/// @param dstOffset The offset into the destination memory region.
|
||||
/// @param srcOffset The offset into the source memory region.
|
||||
/// @param size The size of the transfer.
|
||||
__forceinline__ __device__ void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
|
||||
proxyChan_.putWithSignal(dst_, dstOffset, src_, srcOffset, size);
|
||||
}
|
||||
|
||||
/// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO.
|
||||
/// @param offset The common offset into the destination and source memory regions.
|
||||
/// @param size The size of the transfer.
|
||||
__forceinline__ __device__ void putWithSignal(uint64_t offset, uint64_t size) { putWithSignal(offset, offset, size); }
|
||||
|
||||
/// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO.
|
||||
/// @param dstOffset The offset into the destination memory region.
|
||||
/// @param srcOffset The offset into the source memory region.
|
||||
/// @param size The size of the transfer.
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
|
||||
proxyChan_.putWithSignalAndFlush(dst_, dstOffset, src_, srcOffset, size);
|
||||
}
|
||||
|
||||
/// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO.
|
||||
/// @param offset The common offset into the destination and source memory regions.
|
||||
/// @param size The size of the transfer.
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t offset, uint64_t size) {
|
||||
putWithSignalAndFlush(offset, offset, size);
|
||||
}
|
||||
|
||||
/// Push a @ref TriggerSync to the FIFO.
|
||||
__forceinline__ __device__ void flush() { proxyChan_.flush(); }
|
||||
|
||||
/// Wait for the proxy channel to be signaled.
|
||||
__forceinline__ __device__ void wait() { proxyChan_.wait(); }
|
||||
#endif // __CUDACC__
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_PROXY_CHANNEL_DEVICE_HPP_
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/cuda_utils.hpp>
|
||||
#include <mscclpp/poll.hpp>
|
||||
#include <mscclpp/semaphore_device.hpp>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
@@ -81,18 +82,7 @@ class Host2DeviceSemaphore : public BaseSemaphore<CudaDeleter, std::default_dele
|
||||
void signal();
|
||||
|
||||
/// Device-side handle for @ref Host2DeviceSemaphore.
|
||||
struct DeviceHandle {
|
||||
#ifdef __CUDACC__
|
||||
/// Wait for the host to signal.
|
||||
__forceinline__ __device__ void wait() {
|
||||
(*expectedInboundSemaphoreId) += 1;
|
||||
POLL_MAYBE_JAILBREAK(*(volatile uint64_t*)(inboundSemaphoreId) < (*expectedInboundSemaphoreId), 1000000);
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
|
||||
uint64_t* inboundSemaphoreId;
|
||||
uint64_t* expectedInboundSemaphoreId;
|
||||
};
|
||||
using DeviceHandle = Host2DeviceSemaphoreDeviceHandle;
|
||||
|
||||
/// Returns the device-side handle.
|
||||
DeviceHandle deviceHandle();
|
||||
@@ -133,50 +123,7 @@ class SmDevice2DeviceSemaphore : public BaseSemaphore<CudaDeleter, CudaDeleter>
|
||||
SmDevice2DeviceSemaphore() = default;
|
||||
|
||||
/// Device-side handle for @ref SmDevice2DeviceSemaphore.
|
||||
struct DeviceHandle {
|
||||
#ifdef __CUDACC__
|
||||
/// Wait for the remote device to signal.
|
||||
__forceinline__ __device__ void wait() {
|
||||
(*expectedInboundSemaphoreId) += 1;
|
||||
POLL_MAYBE_JAILBREAK(*inboundSemaphoreId < (*expectedInboundSemaphoreId), 1000000);
|
||||
}
|
||||
|
||||
/// Signal the remote device.
|
||||
///
|
||||
/// This function guarantees that all the memory operation before this function is completed before the remote
|
||||
/// semaphore is signaled.
|
||||
///
|
||||
__forceinline__ __device__ void signal() {
|
||||
// This fence ensures that preceding writes are visible on the peer GPU before the incremented
|
||||
// `outboundSemaphoreId` is visible.
|
||||
__threadfence_system();
|
||||
semaphoreIncrement();
|
||||
*remoteInboundSemaphoreId = semaphoreGetLocal();
|
||||
}
|
||||
|
||||
/// Signal the remote device for copied packets.
|
||||
///
|
||||
/// Unlike @ref signal(), this function provides no guarantee on the completion of memory operations. This is
|
||||
/// intended to be used with @ref putPackets() and @ref getPackets() that use flags inside packets to indicate the
|
||||
/// completion of copies.
|
||||
///
|
||||
__forceinline__ __device__ void signalPacket() {
|
||||
semaphoreIncrement();
|
||||
*remoteInboundSemaphoreId = semaphoreGetLocal();
|
||||
}
|
||||
|
||||
/// Increase the counter of the local semaphore.
|
||||
__forceinline__ __device__ void semaphoreIncrement() { *outboundSemaphoreId += 1; }
|
||||
|
||||
/// Get the value of the local semaphore.
|
||||
__forceinline__ __device__ uint64_t semaphoreGetLocal() const { return *outboundSemaphoreId; }
|
||||
#endif // __CUDACC__
|
||||
|
||||
volatile uint64_t* inboundSemaphoreId;
|
||||
uint64_t* outboundSemaphoreId;
|
||||
volatile uint64_t* remoteInboundSemaphoreId;
|
||||
uint64_t* expectedInboundSemaphoreId;
|
||||
};
|
||||
using DeviceHandle = SmDevice2DeviceSemaphoreDeviceHandle;
|
||||
|
||||
/// Returns the device-side handle.
|
||||
DeviceHandle deviceHandle() const;
|
||||
|
||||
73
include/mscclpp/semaphore_device.hpp
Normal file
73
include/mscclpp/semaphore_device.hpp
Normal file
@@ -0,0 +1,73 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#ifndef MSCCLPP_SEMAPHORE_DEVICE_HPP_
|
||||
#define MSCCLPP_SEMAPHORE_DEVICE_HPP_
|
||||
|
||||
#include "poll.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
/// Device-side handle for @ref Host2DeviceSemaphore.
|
||||
struct Host2DeviceSemaphoreDeviceHandle {
|
||||
#ifdef __CUDACC__
|
||||
/// Wait for the host to signal.
|
||||
__forceinline__ __device__ void wait() {
|
||||
(*expectedInboundSemaphoreId) += 1;
|
||||
POLL_MAYBE_JAILBREAK(*(volatile uint64_t*)(inboundSemaphoreId) < (*expectedInboundSemaphoreId), 100000000);
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
|
||||
uint64_t* inboundSemaphoreId;
|
||||
uint64_t* expectedInboundSemaphoreId;
|
||||
};
|
||||
|
||||
/// Device-side handle for @ref SmDevice2DeviceSemaphore.
|
||||
struct SmDevice2DeviceSemaphoreDeviceHandle {
|
||||
#ifdef __CUDACC__
|
||||
/// Wait for the remote device to signal.
|
||||
__forceinline__ __device__ void wait() {
|
||||
(*expectedInboundSemaphoreId) += 1;
|
||||
POLL_MAYBE_JAILBREAK(*inboundSemaphoreId < (*expectedInboundSemaphoreId), 100000000);
|
||||
}
|
||||
|
||||
/// Signal the remote device.
|
||||
///
|
||||
/// This function guarantees that all the memory operation before this function is completed before the remote
|
||||
/// semaphore is signaled.
|
||||
///
|
||||
__forceinline__ __device__ void signal() {
|
||||
// This fence ensures that preceding writes are visible on the peer GPU before the incremented
|
||||
// `outboundSemaphoreId` is visible.
|
||||
__threadfence_system();
|
||||
semaphoreIncrement();
|
||||
*remoteInboundSemaphoreId = semaphoreGetLocal();
|
||||
}
|
||||
|
||||
/// Signal the remote device for copied packets.
|
||||
///
|
||||
/// Unlike @ref signal(), this function provides no guarantee on the completion of memory operations. This is
|
||||
/// intended to be used with @ref putPackets() and @ref getPackets() that use flags inside packets to indicate the
|
||||
/// completion of copies.
|
||||
///
|
||||
__forceinline__ __device__ void signalPacket() {
|
||||
semaphoreIncrement();
|
||||
*remoteInboundSemaphoreId = semaphoreGetLocal();
|
||||
}
|
||||
|
||||
/// Increase the counter of the local semaphore.
|
||||
__forceinline__ __device__ void semaphoreIncrement() { *outboundSemaphoreId += 1; }
|
||||
|
||||
/// Get the value of the local semaphore.
|
||||
__forceinline__ __device__ uint64_t semaphoreGetLocal() const { return *outboundSemaphoreId; }
|
||||
#endif // __CUDACC__
|
||||
|
||||
volatile uint64_t* inboundSemaphoreId;
|
||||
uint64_t* outboundSemaphoreId;
|
||||
volatile uint64_t* remoteInboundSemaphoreId;
|
||||
uint64_t* expectedInboundSemaphoreId;
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_SEMAPHORE_DEVICE_HPP_
|
||||
@@ -5,8 +5,8 @@
|
||||
#define MSCCLPP_SM_CHANNEL_HPP_
|
||||
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/packet.hpp>
|
||||
#include <mscclpp/semaphore.hpp>
|
||||
#include <mscclpp/sm_channel_device.hpp>
|
||||
#include <type_traits>
|
||||
|
||||
namespace mscclpp {
|
||||
@@ -31,305 +31,8 @@ struct SmChannel {
|
||||
SmChannel(std::shared_ptr<SmDevice2DeviceSemaphore> semaphore, RegisteredMemory dst, void* src,
|
||||
void* getPacketBuffer = nullptr);
|
||||
|
||||
struct DeviceHandle {
|
||||
SmDevice2DeviceSemaphore::DeviceHandle semaphore_;
|
||||
void* src_;
|
||||
void* dst_;
|
||||
void* getPacketBuffer_;
|
||||
|
||||
private:
|
||||
#ifdef __CUDACC__
|
||||
/// Helper for aligned data type access.
|
||||
/// @tparam T The data type.
|
||||
template <typename T>
|
||||
struct Element {
|
||||
static constexpr bool is4B = (sizeof(T) == 4);
|
||||
static constexpr bool is8B = (sizeof(T) == 8);
|
||||
static constexpr bool is4Bx2 =
|
||||
(std::is_same<T, int2>::value || std::is_same<T, uint2>::value || std::is_same<T, float2>::value);
|
||||
static constexpr bool is4Bx4 =
|
||||
(std::is_same<T, int4>::value || std::is_same<T, uint4>::value || std::is_same<T, float4>::value);
|
||||
static constexpr bool is8Bx2 =
|
||||
(std::is_same<T, longlong2>::value || std::is_same<T, ulonglong2>::value || std::is_same<T, double2>::value);
|
||||
// Note: we do not support long2 and ulong2 as their size may differ on different platforms.
|
||||
static constexpr bool isValid = (is4B || is8B || is4Bx2 || is4Bx4 || is8Bx2);
|
||||
|
||||
/// Load an element from DRAM.
|
||||
///
|
||||
/// This is a warpper of ld.volatile.global.* PTX instruction. Address alignment is not this function's
|
||||
/// responsibility.
|
||||
///
|
||||
/// @param v The value to be loaded.
|
||||
/// @param p The address of the value to be loaded.
|
||||
///
|
||||
static __forceinline__ __device__ void load(T& v, const T* p) {
|
||||
if constexpr (is4B) {
|
||||
asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(v) : "l"(p) : "memory");
|
||||
} else if constexpr (is8B) {
|
||||
asm volatile("ld.volatile.global.u64 %0, [%1];" : "=l"(v) : "l"(p) : "memory");
|
||||
} else if constexpr (is4Bx2) {
|
||||
asm volatile("ld.volatile.global.v2.u32 {%0,%1}, [%2];" : "=r"(v.x), "=r"(v.y) : "l"(p) : "memory");
|
||||
} else if constexpr (is4Bx4) {
|
||||
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
|
||||
: "=r"(v.w), "=r"(v.x), "=r"(v.y), "=r"(v.z)
|
||||
: "l"(p)
|
||||
: "memory");
|
||||
} else if constexpr (is8Bx2) {
|
||||
asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" : "=l"(v.x), "=l"(v.y) : "l"(p) : "memory");
|
||||
}
|
||||
static_assert(isValid, "Unsupported type T");
|
||||
}
|
||||
|
||||
/// Write an element on DRAM.
|
||||
///
|
||||
/// This is a wrapper of st.volatile.global.* PTX instruction. Address alignment is not this function's
|
||||
/// responsibility.
|
||||
///
|
||||
/// @param p The address of the value to be written.
|
||||
/// @param v The value to be written.
|
||||
///
|
||||
static __forceinline__ __device__ void store(T* p, const T& v) {
|
||||
if constexpr (is4B) {
|
||||
asm volatile("st.volatile.global.u32 [%0], %1;" : : "l"(p), "r"(v) : "memory");
|
||||
} else if constexpr (is8B) {
|
||||
asm volatile("st.volatile.global.u64 [%0], %1;" : : "l"(p), "l"(v) : "memory");
|
||||
} else if constexpr (is4Bx2) {
|
||||
asm volatile("st.volatile.global.v2.u32 [%0], {%1,%2};" : : "l"(p), "r"(v.x), "r"(v.y) : "memory");
|
||||
} else if constexpr (is4Bx4) {
|
||||
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
|
||||
:
|
||||
: "l"(p), "r"(v.w), "r"(v.x), "r"(v.y), "r"(v.z)
|
||||
: "memory");
|
||||
} else if constexpr (is8Bx2) {
|
||||
asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" : : "l"(p), "l"(v.x), "l"(v.y) : "memory");
|
||||
}
|
||||
static_assert(isValid, "Unsupported type T");
|
||||
}
|
||||
|
||||
/// Copy aligned elements from the source memory to the destination memory.
|
||||
///
|
||||
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of
|
||||
/// elements.
|
||||
///
|
||||
/// @param dst The destination address.
|
||||
/// @param src The source address.
|
||||
/// @param numElems The number of elements to be copied.
|
||||
/// @param threadId The index of the current thread among all threads running this function. This is different
|
||||
/// from the `threadIdx` in CUDA.
|
||||
/// @param numThreads The total number of threads that run this function.
|
||||
///
|
||||
static __forceinline__ __device__ void copy(T* dst, T* src, uint64_t numElems, uint32_t threadId,
|
||||
uint32_t numThreads) {
|
||||
T reg;
|
||||
for (size_t i = threadId; i < numElems; i += numThreads) {
|
||||
// Load to register first.
|
||||
load(reg, src + i);
|
||||
store(dst + i, reg);
|
||||
}
|
||||
}
|
||||
};
|
||||
#endif // __CUDACC__
|
||||
public:
|
||||
#ifdef __CUDACC__
|
||||
/// Load a value from the remote memory.
|
||||
/// @tparam T The type of the value to be loaded.
|
||||
/// @param index The index of the value to be loaded. The offset in bytes is calculated as index * sizeof(T).
|
||||
/// @return The value loaded.
|
||||
template <typename T>
|
||||
__forceinline__ __device__ T read(uint64_t index) {
|
||||
T v;
|
||||
Element<T>::load(v, (T*)dst_ + index);
|
||||
return v;
|
||||
}
|
||||
|
||||
/// Write a value to the remote memory.
|
||||
/// @tparam T The type of the value to be written.
|
||||
/// @param index The index of the value to be written. The offset in bytes is calculated as index * sizeof(T).
|
||||
/// @param v The value to be written.
|
||||
template <typename T>
|
||||
__forceinline__ __device__ void write(uint64_t index, const T& v) {
|
||||
Element<T>::store((T*)dst_ + index, v);
|
||||
}
|
||||
|
||||
/// Copy aligned data from the source memory to the destination memory.
|
||||
///
|
||||
/// This function is a warpper of Element<T>::copy(). Unlike Element<T>::copy(), this function can copy remainder
|
||||
/// bytes when @p CopyRemainder is true. Still, the copying bytes must be a multiple of 4.
|
||||
///
|
||||
/// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16.
|
||||
/// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
|
||||
/// Alignment.
|
||||
/// @param dst The destination address. Should be aligned to @p Alignment in the same way as @p src.
|
||||
/// @param src The source address. Should be aligned to @p Alignment in the same way as @p dst.
|
||||
/// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment.
|
||||
/// @param threadId The index of the current thread among all threads running this function. This is different from
|
||||
/// the `threadIdx` in CUDA.
|
||||
/// @param numThreads The total number of threads that run this function.
|
||||
///
|
||||
template <int Alignment = 4, bool CopyRemainder = true>
|
||||
__forceinline__ __device__ void copy(void* dst, void* src, uint64_t bytes, uint32_t threadId, uint32_t numThreads) {
|
||||
static_assert(Alignment == 4 || Alignment == 8 || Alignment % 16 == 0, "Unsupported alignment");
|
||||
using Type =
|
||||
typename std::conditional<Alignment == 4, int,
|
||||
typename std::conditional<Alignment == 8, long long, longlong2>::type>::type;
|
||||
int* dstInt = reinterpret_cast<int*>(dst);
|
||||
int* srcInt = reinterpret_cast<int*>(src);
|
||||
const uintptr_t dstPtr = reinterpret_cast<uintptr_t>(dst);
|
||||
const uintptr_t srcPtr = reinterpret_cast<uintptr_t>(src);
|
||||
const uint64_t numInt = bytes / sizeof(int);
|
||||
Type* dstElem = reinterpret_cast<Type*>((dstPtr + sizeof(Type) - 1) / sizeof(Type) * sizeof(Type));
|
||||
Type* srcElem = reinterpret_cast<Type*>((srcPtr + sizeof(Type) - 1) / sizeof(Type) * sizeof(Type));
|
||||
uint64_t nFirstInt = (reinterpret_cast<uintptr_t>(dstElem) - dstPtr) / sizeof(int);
|
||||
if (CopyRemainder) {
|
||||
// Copy the remainder integers at the beginning.
|
||||
Element<int>::copy(dstInt, srcInt, nFirstInt, threadId, numThreads);
|
||||
}
|
||||
// Copy elements.
|
||||
constexpr uint64_t nIntPerElem = sizeof(Type) / sizeof(int);
|
||||
uint64_t nElem = (numInt - nFirstInt) / nIntPerElem;
|
||||
Element<Type>::copy(dstElem, srcElem, nElem, threadId, numThreads);
|
||||
if (CopyRemainder && nIntPerElem > 1) {
|
||||
// Copy the remainder integers at the end.
|
||||
uint64_t nLastInt = (numInt - nFirstInt) % nIntPerElem;
|
||||
Element<int>::copy(dstInt + nFirstInt + nElem * nIntPerElem, srcInt + nFirstInt + nElem * nIntPerElem, nLastInt,
|
||||
threadId, numThreads);
|
||||
}
|
||||
}
|
||||
|
||||
/// Copy data from the local memory to the remote memory.
|
||||
///
|
||||
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
|
||||
///
|
||||
/// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16.
|
||||
/// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
|
||||
/// Alignment.
|
||||
/// @param dstOffset The offset in bytes of the remote address. Should be a multiple of @p Alignment.
|
||||
/// @param srcOffset The offset in bytes of the local address. Should be a multiple of @p Alignment.
|
||||
/// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment.
|
||||
/// @param threadId The index of the current thread among all threads running this function. This is different from
|
||||
/// the `threadIdx` in CUDA.
|
||||
/// @param numThreads The total number of threads that run this function.
|
||||
///
|
||||
template <int Alignment = 16, bool CopyRemainder = true>
|
||||
__forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t bytes, uint32_t threadId,
|
||||
uint32_t numThreads) {
|
||||
copy<Alignment, CopyRemainder>((char*)dst_ + dstOffset, (char*)src_ + srcOffset, bytes, threadId, numThreads);
|
||||
}
|
||||
|
||||
/// Copy data from the remote memory to the local memory.
|
||||
///
|
||||
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
|
||||
///
|
||||
/// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16.
|
||||
/// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
|
||||
/// Alignment.
|
||||
/// @param dstOffset The offset in bytes of the remote address. Should be a multiple of @p Alignment.
|
||||
/// @param srcOffset The offset in bytes of the local address. Should be a multiple of @p Alignment.
|
||||
/// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment.
|
||||
/// @param threadId The index of the current thread among all threads running this function. This is different from
|
||||
/// the `threadIdx` in CUDA.
|
||||
/// @param numThreads The total number of threads that run this function.
|
||||
///
|
||||
template <int Alignment = 16, bool CopyRemainder = true>
|
||||
__forceinline__ __device__ void get(uint64_t dstOffset, uint64_t srcOffset, uint64_t bytes, uint32_t threadId,
|
||||
uint32_t numThreads) {
|
||||
// Note that `dst` and `src` are swapped for `get()`.
|
||||
copy<Alignment, CopyRemainder>((char*)src_ + srcOffset, (char*)dst_ + dstOffset, bytes, threadId, numThreads);
|
||||
}
|
||||
|
||||
/// Copy data from the local memory to the remote memory.
|
||||
///
|
||||
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
|
||||
///
|
||||
/// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16.
|
||||
/// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
|
||||
/// Alignment.
|
||||
/// @param offset The offset in bytes of the local and remote addresses. Should be a multiple of @p Alignment.
|
||||
/// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment.
|
||||
/// @param threadId The index of the current thread among all threads running this function. This is different from
|
||||
/// the `threadIdx` in CUDA.
|
||||
/// @param numThreads The total number of threads that run this function.
|
||||
///
|
||||
template <int Alignment = 16, bool CopyRemainder = true>
|
||||
__forceinline__ __device__ void put(uint64_t offset, uint64_t size, uint32_t threadId, uint32_t numThreads) {
|
||||
put<Alignment, CopyRemainder>(offset, offset, size, threadId, numThreads);
|
||||
}
|
||||
|
||||
/// Copy data from the remote memory to the local memory.
|
||||
///
|
||||
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
|
||||
///
|
||||
/// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16.
|
||||
/// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
|
||||
/// Alignment.
|
||||
/// @param offset The offset in bytes of the local and remote addresses. Should be a multiple of @p Alignment.
|
||||
/// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment.
|
||||
/// @param threadId The index of the current thread among all threads running this function. This is different from
|
||||
/// the `threadIdx` in CUDA.
|
||||
/// @param numThreads The total number of threads that run this function.
|
||||
///
|
||||
template <int Alignment = 16, bool CopyRemainder = true>
|
||||
__forceinline__ __device__ void get(uint64_t offset, uint64_t size, uint32_t threadId, uint32_t numThreads) {
|
||||
get<Alignment, CopyRemainder>(offset, offset, size, threadId, numThreads);
|
||||
}
|
||||
|
||||
/// Construct @ref LLPacket from the data in the local memory and write it on the remote memory.
|
||||
///
|
||||
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of packets.
|
||||
///
|
||||
/// @param dstOffset The offset in bytes of the remote address.
|
||||
/// @param srcOffset The offset in bytes of the local address.
|
||||
/// @param bytes Bytes of the data to be copied.
|
||||
/// @param threadId The index of the current thread among all threads running this function. This is different from
|
||||
/// the `threadIdx` in CUDA.
|
||||
/// @param numThreads The total number of threads that run this function.
|
||||
///
|
||||
__forceinline__ __device__ void putPackets(uint64_t dstOffset, uint64_t srcOffset, uint64_t bytes,
|
||||
uint32_t threadId, uint32_t numThreads, uint32_t flag) {
|
||||
mscclpp::putPackets(dst_, dstOffset, src_, srcOffset, bytes, threadId, numThreads, flag);
|
||||
}
|
||||
|
||||
/// Retrieve data from @ref LLPacket in the local packet buffer and write it on the local memory.
|
||||
///
|
||||
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
|
||||
///
|
||||
/// @param dstOffset The offset in bytes of the local memory.
|
||||
/// @param srcOffset The offset in bytes of the local packet buffer.
|
||||
/// @param bytes Bytes of the data to be copied.
|
||||
/// @param threadId The index of the current thread among all threads running this function. This is different from
|
||||
/// the `threadIdx` in CUDA.
|
||||
/// @param numThreads The total number of threads that run this function.
|
||||
///
|
||||
__forceinline__ __device__ void getPackets(uint64_t dstOffset, uint64_t srcOffset, uint64_t bytes,
|
||||
uint32_t threadId, uint32_t numThreads, uint32_t flag) {
|
||||
mscclpp::getPackets(src_, dstOffset, getPacketBuffer_, srcOffset, bytes, threadId, numThreads, flag);
|
||||
}
|
||||
|
||||
/// Signal the remote semaphore.
|
||||
///
|
||||
/// This function guarantees that all the memory operation before this function is completed before the remote
|
||||
/// semaphore is signaled.
|
||||
///
|
||||
__forceinline__ __device__ void signal() { semaphore_.signal(); }
|
||||
|
||||
/// Signal the remote semaphore for copied packets.
|
||||
///
|
||||
/// Unlike @ref signal(), this function provides no guarantee on the completion of memory operations. This is
|
||||
/// intended to be used with @ref putPackets() and @ref getPackets() that use flags inside packets to indicate the
|
||||
/// completion of copies.
|
||||
///
|
||||
__forceinline__ __device__ void signalPacket() { semaphore_.signalPacket(); }
|
||||
|
||||
/// Increase the counter of the local semaphore.
|
||||
__forceinline__ __device__ void semaphoreIncrement() { semaphore_.semaphoreIncrement(); }
|
||||
|
||||
/// Read the counter of the local semaphore.
|
||||
__forceinline__ __device__ uint64_t semaphoreGetLocal() const { return semaphore_.semaphoreGetLocal(); }
|
||||
|
||||
/// Wait for the remote semaphore to send a signal.
|
||||
__forceinline__ __device__ void wait() { semaphore_.wait(); }
|
||||
#endif // __CUDACC__
|
||||
};
|
||||
/// Device-side handle for @ref SmChannel.
|
||||
using DeviceHandle = SmChannelDeviceHandle;
|
||||
|
||||
/// Returns the device-side handle.
|
||||
///
|
||||
|
||||
336
include/mscclpp/sm_channel_device.hpp
Normal file
336
include/mscclpp/sm_channel_device.hpp
Normal file
@@ -0,0 +1,336 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#ifndef MSCCLPP_SM_CHANNEL_DEVICE_HPP_
|
||||
#define MSCCLPP_SM_CHANNEL_DEVICE_HPP_
|
||||
|
||||
#include "packet.hpp"
|
||||
#include "poll.hpp"
|
||||
#include "semaphore_device.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
#ifdef __CUDACC__
|
||||
|
||||
namespace Element {
|
||||
|
||||
/// Load an element from DRAM.
|
||||
///
|
||||
/// This is a warpper of ld.volatile.global.* PTX instruction. Address alignment is not this function's
|
||||
/// responsibility.
|
||||
///
|
||||
/// @param v The value to be loaded.
|
||||
/// @param p The address of the value to be loaded.
|
||||
///
|
||||
template <typename T>
|
||||
__forceinline__ __device__ void load(T& v, const T* p) {
|
||||
// We should only use the specialized functions.
|
||||
__assert_fail("Unsupported type", __FILE__, __LINE__, __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
/// Write an element on DRAM.
|
||||
///
|
||||
/// This is a wrapper of st.volatile.global.* PTX instruction. Address alignment is not this function's
|
||||
/// responsibility.
|
||||
///
|
||||
/// @param p The address of the value to be written.
|
||||
/// @param v The value to be written.
|
||||
///
|
||||
template <typename T>
|
||||
__forceinline__ __device__ void store(T* p, const T& v) {
|
||||
// We should only use the specialized functions.
|
||||
__assert_fail("Unsupported type", __FILE__, __LINE__, __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
/// Copy aligned elements from the source memory to the destination memory.
|
||||
///
|
||||
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of
|
||||
/// elements.
|
||||
///
|
||||
/// @param dst The destination address.
|
||||
/// @param src The source address.
|
||||
/// @param numElems The number of elements to be copied.
|
||||
/// @param threadId The index of the current thread among all threads running this function. This is different
|
||||
/// from the `threadIdx` in CUDA.
|
||||
/// @param numThreads The total number of threads that run this function.
|
||||
///
|
||||
template <typename T>
|
||||
__forceinline__ __device__ void copy(T* dst, T* src, uint64_t numElems, uint32_t threadId, uint32_t numThreads) {
|
||||
T reg;
|
||||
for (size_t i = threadId; i < numElems; i += numThreads) {
|
||||
// Load to register first.
|
||||
load(reg, src + i);
|
||||
store(dst + i, reg);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ void load<long long>(long long& v, const long long* p) {
|
||||
asm volatile("ld.volatile.global.u64 %0, [%1];" : "=l"(v) : "l"(p) : "memory");
|
||||
}
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ void store<long long>(long long* p, const long long& v) {
|
||||
asm volatile("st.volatile.global.u64 [%0], %1;" : : "l"(p), "l"(v) : "memory");
|
||||
}
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ void load<int>(int& v, const int* p) {
|
||||
asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(v) : "l"(p) : "memory");
|
||||
}
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ void store<int>(int* p, const int& v) {
|
||||
asm volatile("st.volatile.global.u32 [%0], %1;" : : "l"(p), "r"(v) : "memory");
|
||||
}
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ void load<longlong2>(longlong2& v, const longlong2* p) {
|
||||
asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" : "=l"(v.x), "=l"(v.y) : "l"(p) : "memory");
|
||||
}
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ void store<longlong2>(longlong2* p, const longlong2& v) {
|
||||
asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" : : "l"(p), "l"(v.x), "l"(v.y) : "memory");
|
||||
}
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ void load<int4>(int4& v, const int4* p) {
|
||||
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
|
||||
: "=r"(v.w), "=r"(v.x), "=r"(v.y), "=r"(v.z)
|
||||
: "l"(p)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ void store<int4>(int4* p, const int4& v) {
|
||||
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
|
||||
:
|
||||
: "l"(p), "r"(v.w), "r"(v.x), "r"(v.y), "r"(v.z)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
} // namespace Element
|
||||
|
||||
#endif // __CUDACC__
|
||||
|
||||
/// Channel for accessing peer memory directly from SM.
|
||||
struct SmChannelDeviceHandle {
|
||||
SmDevice2DeviceSemaphoreDeviceHandle semaphore_;
|
||||
void* src_;
|
||||
void* dst_;
|
||||
void* getPacketBuffer_;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
/// Load a value from the remote memory.
|
||||
/// @tparam T The type of the value to be loaded.
|
||||
/// @param index The index of the value to be loaded. The offset in bytes is calculated as index * sizeof(T).
|
||||
/// @return The value loaded.
|
||||
template <typename T>
|
||||
__forceinline__ __device__ T read(uint64_t index) {
|
||||
T v;
|
||||
Element::load<T>(v, (T*)dst_ + index);
|
||||
return v;
|
||||
}
|
||||
|
||||
/// Write a value to the remote memory.
|
||||
/// @tparam T The type of the value to be written.
|
||||
/// @param index The index of the value to be written. The offset in bytes is calculated as index * sizeof(T).
|
||||
/// @param v The value to be written.
|
||||
template <typename T>
|
||||
__forceinline__ __device__ void write(uint64_t index, const T& v) {
|
||||
Element::store<T>((T*)dst_ + index, v);
|
||||
}
|
||||
|
||||
/// this is a helper for copy function
|
||||
template <typename T, bool CopyRemainder = true>
|
||||
__forceinline__ __device__ void copy_helper(void* dst, void* src, uint64_t bytes, uint32_t threadId,
|
||||
uint32_t numThreads) {
|
||||
int* dstInt = reinterpret_cast<int*>(dst);
|
||||
int* srcInt = reinterpret_cast<int*>(src);
|
||||
const uintptr_t dstPtr = reinterpret_cast<uintptr_t>(dst);
|
||||
const uintptr_t srcPtr = reinterpret_cast<uintptr_t>(src);
|
||||
const uint64_t numInt = bytes / sizeof(int);
|
||||
T* dstElem = reinterpret_cast<T*>((dstPtr + sizeof(T) - 1) / sizeof(T) * sizeof(T));
|
||||
T* srcElem = reinterpret_cast<T*>((srcPtr + sizeof(T) - 1) / sizeof(T) * sizeof(T));
|
||||
uint64_t nFirstInt = (reinterpret_cast<uintptr_t>(dstElem) - dstPtr) / sizeof(int);
|
||||
if (CopyRemainder) {
|
||||
// Copy the remainder integers at the beginning.
|
||||
Element::copy<int>(dstInt, srcInt, nFirstInt, threadId, numThreads);
|
||||
}
|
||||
// Copy elements.
|
||||
constexpr uint64_t nIntPerElem = sizeof(T) / sizeof(int);
|
||||
uint64_t nElem = (numInt - nFirstInt) / nIntPerElem;
|
||||
Element::copy<T>(dstElem, srcElem, nElem, threadId, numThreads);
|
||||
if (CopyRemainder && nIntPerElem > 1) {
|
||||
// Copy the remainder integers at the end.
|
||||
uint64_t nLastInt = (numInt - nFirstInt) % nIntPerElem;
|
||||
Element::copy<int>(dstInt + nFirstInt + nElem * nIntPerElem, srcInt + nFirstInt + nElem * nIntPerElem, nLastInt,
|
||||
threadId, numThreads);
|
||||
}
|
||||
}
|
||||
|
||||
/// Copy aligned data from the source memory to the destination memory.
|
||||
///
|
||||
/// This function is a warpper of Element<T>::copy(). Unlike Element<T>::copy(), this function can copy remainder
|
||||
/// bytes when @p CopyRemainder is true. Still, the 16.
|
||||
/// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
|
||||
/// Alignment.
|
||||
/// @param dst The destination address. Should be aligned to @p Alignment in the same way as @p src.
|
||||
/// @param src The source address. Should be aligned to @p Alignment in the same way as @p dst.
|
||||
/// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment.
|
||||
/// @param threadId The index of the current thread among all threads running this function. This is different from
|
||||
/// the `threadIdx` in CUDA.
|
||||
/// @param numThreads The total number of threads that run this function.
|
||||
///
|
||||
template <int Alignment = 16, bool CopyRemainder = true>
|
||||
__forceinline__ __device__ void copy(void* dst, void* src, uint64_t bytes, uint32_t threadId, uint32_t numThreads) {
|
||||
if (Alignment == 4) {
|
||||
copy_helper<int, CopyRemainder>(dst, src, bytes, threadId, numThreads);
|
||||
} else if (Alignment == 8) {
|
||||
copy_helper<long long, CopyRemainder>(dst, src, bytes, threadId, numThreads);
|
||||
} else if (Alignment == 16) {
|
||||
copy_helper<longlong2, CopyRemainder>(dst, src, bytes, threadId, numThreads);
|
||||
} else {
|
||||
static_assert(Alignment == 4 || Alignment == 8 || Alignment == 16, "Unsupported alignment");
|
||||
}
|
||||
}
|
||||
|
||||
/// Copy data from the local memory to the remote memory.
|
||||
///
|
||||
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
|
||||
///
|
||||
/// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16.
|
||||
/// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
|
||||
/// Alignment.
|
||||
/// @param dstOffset The offset in bytes of the remote address. Should be a multiple of @p Alignment.
|
||||
/// @param srcOffset The offset in bytes of the local address. Should be a multiple of @p Alignment.
|
||||
/// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment.
|
||||
/// @param threadId The index of the current thread among all threads running this function. This is different from
|
||||
/// the `threadIdx` in CUDA.
|
||||
/// @param numThreads The total number of threads that run this function.
|
||||
///
|
||||
template <int Alignment = 16, bool CopyRemainder = true>
|
||||
__forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t bytes, uint32_t threadId,
|
||||
uint32_t numThreads) {
|
||||
copy<Alignment, CopyRemainder>((char*)dst_ + dstOffset, (char*)src_ + srcOffset, bytes, threadId, numThreads);
|
||||
}
|
||||
|
||||
/// Copy data from the remote memory to the local memory.
|
||||
///
|
||||
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
|
||||
///
|
||||
/// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16.
|
||||
/// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
|
||||
/// Alignment.
|
||||
/// @param dstOffset The offset in bytes of the remote address. Should be a multiple of @p Alignment.
|
||||
/// @param srcOffset The offset in bytes of the local address. Should be a multiple of @p Alignment.
|
||||
/// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment.
|
||||
/// @param threadId The index of the current thread among all threads running this function. This is different from
|
||||
/// the `threadIdx` in CUDA.
|
||||
/// @param numThreads The total number of threads that run this function.
|
||||
///
|
||||
template <int Alignment = 16, bool CopyRemainder = true>
|
||||
__forceinline__ __device__ void get(uint64_t dstOffset, uint64_t srcOffset, uint64_t bytes, uint32_t threadId,
|
||||
uint32_t numThreads) {
|
||||
// Note that `dst` and `src` are swapped for `get()`.
|
||||
copy<Alignment, CopyRemainder>((char*)src_ + srcOffset, (char*)dst_ + dstOffset, bytes, threadId, numThreads);
|
||||
}
|
||||
|
||||
/// Copy data from the local memory to the remote memory.
|
||||
///
|
||||
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
|
||||
///
|
||||
/// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16.
|
||||
/// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
|
||||
/// Alignment.
|
||||
/// @param offset The offset in bytes of the local and remote addresses. Should be a multiple of @p Alignment.
|
||||
/// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment.
|
||||
/// @param threadId The index of the current thread among all threads running this function. This is different from
|
||||
/// the `threadIdx` in CUDA.
|
||||
/// @param numThreads The total number of threads that run this function.
|
||||
///
|
||||
template <int Alignment = 16, bool CopyRemainder = true>
|
||||
__forceinline__ __device__ void put(uint64_t offset, uint64_t size, uint32_t threadId, uint32_t numThreads) {
|
||||
put<Alignment, CopyRemainder>(offset, offset, size, threadId, numThreads);
|
||||
}
|
||||
|
||||
/// Copy data from the remote memory to the local memory.
|
||||
///
|
||||
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
|
||||
///
|
||||
/// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16.
|
||||
/// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
|
||||
/// Alignment.
|
||||
/// @param offset The offset in bytes of the local and remote addresses. Should be a multiple of @p Alignment.
|
||||
/// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment.
|
||||
/// @param threadId The index of the current thread among all threads running this function. This is different from
|
||||
/// the `threadIdx` in CUDA.
|
||||
/// @param numThreads The total number of threads that run this function.
|
||||
///
|
||||
template <int Alignment = 16, bool CopyRemainder = true>
|
||||
__forceinline__ __device__ void get(uint64_t offset, uint64_t size, uint32_t threadId, uint32_t numThreads) {
|
||||
get<Alignment, CopyRemainder>(offset, offset, size, threadId, numThreads);
|
||||
}
|
||||
|
||||
/// Construct @ref LLPacket from the data in the local memory and write it on the remote memory.
|
||||
///
|
||||
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of packets.
|
||||
///
|
||||
/// @param dstOffset The offset in bytes of the remote address.
|
||||
/// @param srcOffset The offset in bytes of the local address.
|
||||
/// @param bytes Bytes of the data to be copied.
|
||||
/// @param threadId The index of the current thread among all threads running this function. This is different from
|
||||
/// the `threadIdx` in CUDA.
|
||||
/// @param numThreads The total number of threads that run this function.
|
||||
///
|
||||
__forceinline__ __device__ void putPackets(uint64_t dstOffset, uint64_t srcOffset, uint64_t bytes, uint32_t threadId,
|
||||
uint32_t numThreads, uint32_t flag) {
|
||||
mscclpp::putPackets(dst_, dstOffset, src_, srcOffset, bytes, threadId, numThreads, flag);
|
||||
}
|
||||
|
||||
/// Retrieve data from @ref LLPacket in the local packet buffer and write it on the local memory.
|
||||
///
|
||||
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
|
||||
///
|
||||
/// @param dstOffset The offset in bytes of the local memory.
|
||||
/// @param srcOffset The offset in bytes of the local packet buffer.
|
||||
/// @param bytes Bytes of the data to be copied.
|
||||
/// @param threadId The index of the current thread among all threads running this function. This is different from
|
||||
/// the `threadIdx` in CUDA.
|
||||
/// @param numThreads The total number of threads that run this function.
|
||||
///
|
||||
__forceinline__ __device__ void getPackets(uint64_t dstOffset, uint64_t srcOffset, uint64_t bytes, uint32_t threadId,
|
||||
uint32_t numThreads, uint32_t flag) {
|
||||
mscclpp::getPackets(src_, dstOffset, getPacketBuffer_, srcOffset, bytes, threadId, numThreads, flag);
|
||||
}
|
||||
|
||||
/// Signal the remote semaphore.
|
||||
///
|
||||
/// This function guarantees that all the memory operation before this function is completed before the remote
|
||||
/// semaphore is signaled.
|
||||
///
|
||||
__forceinline__ __device__ void signal() { semaphore_.signal(); }
|
||||
|
||||
/// Signal the remote semaphore for copied packets.
|
||||
///
|
||||
/// Unlike @ref signal(), this function provides no guarantee on the completion of memory operations. This is
|
||||
/// intended to be used with @ref putPackets() and @ref getPackets() that use flags inside packets to indicate the
|
||||
/// completion of copies.
|
||||
///
|
||||
__forceinline__ __device__ void signalPacket() { semaphore_.signalPacket(); }
|
||||
|
||||
/// Increase the counter of the local semaphore.
|
||||
__forceinline__ __device__ void semaphoreIncrement() { semaphore_.semaphoreIncrement(); }
|
||||
|
||||
/// Read the counter of the local semaphore.
|
||||
__forceinline__ __device__ uint64_t semaphoreGetLocal() const { return semaphore_.semaphoreGetLocal(); }
|
||||
|
||||
/// Wait for the remote semaphore to send a signal.
|
||||
__forceinline__ __device__ void wait() { semaphore_.wait(); }
|
||||
#endif // __CUDACC__
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_SM_CHANNEL_DEVICE_HPP_
|
||||
@@ -17,6 +17,7 @@ struct Timer {
|
||||
|
||||
~Timer();
|
||||
|
||||
/// Returns the elapsed time in milliseconds.
|
||||
int64_t elapsed() const;
|
||||
|
||||
void set(int timeout);
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
|
||||
#include <mscclpp/config.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
/// Register the Python bindings for mscclpp::Config on the given module.
void register_config(nb::module_& m) {
  auto config = nb::class_<Config>(m, "Config");
  // The singleton is owned by the C++ side; expose it as a non-owning reference.
  config.def_static("get_instance", &Config::getInstance, nb::rv_policy::reference);
  config.def("get_bootstrap_connection_timeout_config", &Config::getBootstrapConnectionTimeoutConfig);
  config.def("set_bootstrap_connection_timeout_config", &Config::setBootstrapConnectionTimeoutConfig);
}
|
||||
@@ -17,8 +17,8 @@ extern void register_proxy_channel(nb::module_& m);
|
||||
extern void register_sm_channel(nb::module_& m);
|
||||
extern void register_fifo(nb::module_& m);
|
||||
extern void register_semaphore(nb::module_& m);
|
||||
extern void register_config(nb::module_& m);
|
||||
extern void register_utils(nb::module_& m);
|
||||
extern void register_numa(nb::module_& m);
|
||||
|
||||
template <typename T>
|
||||
void def_nonblocking_future(nb::handle& m, const std::string& typestr) {
|
||||
@@ -62,9 +62,10 @@ void register_core(nb::module_& m) {
|
||||
nb::arg("nRanks"))
|
||||
.def("create_unique_id", &TcpBootstrap::createUniqueId)
|
||||
.def("get_unique_id", &TcpBootstrap::getUniqueId)
|
||||
.def("initialize", (void (TcpBootstrap::*)(UniqueId)) & TcpBootstrap::initialize, nb::arg("uniqueId"))
|
||||
.def("initialize", (void (TcpBootstrap::*)(const std::string&)) & TcpBootstrap::initialize,
|
||||
nb::arg("ifIpPortTrio"));
|
||||
.def("initialize", (void (TcpBootstrap::*)(UniqueId, int64_t)) & TcpBootstrap::initialize, nb::arg("uniqueId"),
|
||||
nb::arg("timeoutSec") = 30)
|
||||
.def("initialize", (void (TcpBootstrap::*)(const std::string&, int64_t)) & TcpBootstrap::initialize,
|
||||
nb::arg("ifIpPortTrio"), nb::arg("timeoutSec") = 30);
|
||||
|
||||
nb::enum_<Transport>(m, "Transport")
|
||||
.value("Unknown", Transport::Unknown)
|
||||
@@ -118,7 +119,7 @@ void register_core(nb::module_& m) {
|
||||
self->updateAndSync(dst, dstOffset, (uint64_t*)src, newValue);
|
||||
},
|
||||
nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("newValue"))
|
||||
.def("flush", &Connection::flush)
|
||||
.def("flush", &Connection::flush, nb::arg("timeoutUsec") = (int64_t)3e7)
|
||||
.def("remote_rank", &Connection::remoteRank)
|
||||
.def("tag", &Connection::tag)
|
||||
.def("transport", &Connection::transport)
|
||||
@@ -139,7 +140,8 @@ void register_core(nb::module_& m) {
|
||||
nb::arg("tag"))
|
||||
.def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
|
||||
.def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
|
||||
nb::arg("transport"))
|
||||
nb::arg("transport"), nb::arg("ibMaxCqSize") = 1024, nb::arg("ibMaxCqPollNum") = 1,
|
||||
nb::arg("ibMaxSendWr") = 8192, nb::arg("ibMaxWrPerSend") = 64)
|
||||
.def("setup", &Communicator::setup);
|
||||
}
|
||||
|
||||
@@ -149,7 +151,7 @@ NB_MODULE(_mscclpp, m) {
|
||||
register_sm_channel(m);
|
||||
register_fifo(m);
|
||||
register_semaphore(m);
|
||||
register_config(m);
|
||||
register_utils(m);
|
||||
register_core(m);
|
||||
register_numa(m);
|
||||
}
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import mscclpp
|
||||
import argparse
|
||||
import multiprocessing as mp
|
||||
import logging
|
||||
import torch
|
||||
import multiprocessing as mp
|
||||
import sys
|
||||
|
||||
import mscclpp
|
||||
import torch
|
||||
|
||||
IB_TRANSPORTS = [
|
||||
mscclpp.Transport.IB0,
|
||||
mscclpp.Transport.IB1,
|
||||
@@ -19,15 +20,19 @@ IB_TRANSPORTS = [
|
||||
mscclpp.Transport.IB7,
|
||||
]
|
||||
|
||||
# Use to hold the sm channels so they don't get garbage collected
|
||||
sm_channels = []
|
||||
|
||||
|
||||
def setup_connections(comm, rank, world_size, element_size, proxy_service):
|
||||
simple_proxy_channels = []
|
||||
sm_semaphores = []
|
||||
connections = []
|
||||
remote_memories = []
|
||||
memory = torch.zeros(element_size, dtype=torch.int32)
|
||||
memory = memory.to("cuda")
|
||||
|
||||
transport_flag = IB_TRANSPORTS[rank] or mscclpp.Transport.CudaIpc
|
||||
transport_flag = mscclpp.TransportFlags(IB_TRANSPORTS[rank]) | mscclpp.Transport.CudaIpc
|
||||
ptr = memory.data_ptr()
|
||||
size = memory.numel() * memory.element_size()
|
||||
reg_mem = comm.register_memory(ptr, size, transport_flag)
|
||||
@@ -42,15 +47,26 @@ def setup_connections(comm, rank, world_size, element_size, proxy_service):
|
||||
remote_memories.append(remote_mem)
|
||||
comm.setup()
|
||||
|
||||
# Create simple proxy channels
|
||||
for i, conn in enumerate(connections):
|
||||
proxy_channel = mscclpp.SimpleProxyChannel(
|
||||
proxy_service.device_channel(proxy_service.add_semaphore(conn)),
|
||||
proxy_service.proxy_channel(proxy_service.build_and_add_semaphore(conn)),
|
||||
proxy_service.add_memory(remote_memories[i].get()),
|
||||
proxy_service.add_memory(reg_mem),
|
||||
)
|
||||
simple_proxy_channels.append(mscclpp.device_handle(proxy_channel))
|
||||
comm.setup()
|
||||
return simple_proxy_channels
|
||||
|
||||
# Create sm channels
|
||||
for i, conn in enumerate(connections):
|
||||
sm_chan = mscclpp.SmDevice2DeviceSemaphore(comm, conn)
|
||||
sm_semaphores.append(sm_chan)
|
||||
comm.setup()
|
||||
|
||||
for i, conn in enumerate(sm_semaphores):
|
||||
sm_chan = mscclpp.SmChannel(sm_semaphores[i], remote_memories[i].get(), ptr)
|
||||
sm_channels.append(sm_chan)
|
||||
return simple_proxy_channels, [mscclpp.device_handle(sm_chan) for sm_chan in sm_channels]
|
||||
|
||||
|
||||
def run(rank, args):
|
||||
@@ -60,7 +76,7 @@ def run(rank, args):
|
||||
boot = mscclpp.TcpBootstrap.create(rank, world_size)
|
||||
boot.initialize(args.if_ip_port_trio)
|
||||
comm = mscclpp.Communicator(boot)
|
||||
proxy_service = mscclpp.ProxyService(comm)
|
||||
proxy_service = mscclpp.ProxyService()
|
||||
|
||||
logging.info("Rank: %d, setting up connections", rank)
|
||||
setup_connections(comm, rank, world_size, args.num_elements, proxy_service)
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
import mscclpp
|
||||
|
||||
|
||||
def main():
|
||||
config = mscclpp.Config.get_instance()
|
||||
config.set_bootstrap_connection_timeout_config(15)
|
||||
timeout = config.get_bootstrap_connection_timeout_config()
|
||||
assert timeout == 15
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,10 +1,11 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import mscclpp
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import mscclpp
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.root:
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import mscclpp
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import time
|
||||
|
||||
import mscclpp
|
||||
|
||||
|
||||
def main():
|
||||
timer = mscclpp.Timer()
|
||||
|
||||
@@ -11,15 +11,20 @@ using namespace mscclpp;
|
||||
void register_fifo(nb::module_& m) {
|
||||
nb::class_<ProxyTrigger>(m, "ProxyTrigger").def_rw("fst", &ProxyTrigger::fst).def_rw("snd", &ProxyTrigger::snd);
|
||||
|
||||
nb::class_<DeviceProxyFifo>(m, "DeviceProxyFifo")
|
||||
.def_rw("triggers", &DeviceProxyFifo::triggers)
|
||||
.def_rw("tail_replica", &DeviceProxyFifo::tailReplica)
|
||||
.def_rw("head", &DeviceProxyFifo::head);
|
||||
nb::class_<FifoDeviceHandle>(m, "FifoDeviceHandle")
|
||||
.def_rw("triggers", &FifoDeviceHandle::triggers)
|
||||
.def_rw("tail_replica", &FifoDeviceHandle::tailReplica)
|
||||
.def_rw("head", &FifoDeviceHandle::head)
|
||||
.def_rw("size", &FifoDeviceHandle::size)
|
||||
.def_prop_ro("raw", [](const FifoDeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<HostProxyFifo>(m, "HostProxyFifo")
|
||||
.def(nb::init<>())
|
||||
.def("poll", &HostProxyFifo::poll, nb::arg("trigger"))
|
||||
.def("pop", &HostProxyFifo::pop)
|
||||
.def("flush_tail", &HostProxyFifo::flushTail, nb::arg("sync") = false)
|
||||
.def("device_fifo", &HostProxyFifo::deviceFifo);
|
||||
nb::class_<Fifo>(m, "Fifo")
|
||||
.def(nb::init<int>(), nb::arg("size") = 128)
|
||||
.def("poll", &Fifo::poll)
|
||||
.def("pop", &Fifo::pop)
|
||||
.def("flush_tail", &Fifo::flushTail, nb::arg("sync") = false)
|
||||
.def("size", &Fifo::size)
|
||||
.def("device_handle", &Fifo::deviceHandle);
|
||||
}
|
||||
|
||||
@@ -1,10 +1,31 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from ._mscclpp import *
|
||||
import os as _os
|
||||
|
||||
from ._mscclpp import (
|
||||
Communicator,
|
||||
Connection,
|
||||
Fifo,
|
||||
Host2DeviceSemaphore,
|
||||
Host2HostSemaphore,
|
||||
numa,
|
||||
ProxyService,
|
||||
RegisteredMemory,
|
||||
SimpleProxyChannel,
|
||||
SmChannel,
|
||||
SmDevice2DeviceSemaphore,
|
||||
TcpBootstrap,
|
||||
Transport,
|
||||
TransportFlags,
|
||||
)
|
||||
|
||||
|
||||
def get_include():
|
||||
"""Return the directory that contains the MSCCL++ headers."""
|
||||
return _os.path.join(_os.path.dirname(__file__), "include")
|
||||
|
||||
|
||||
def get_lib():
|
||||
"""Return the directory that contains the MSCCL++ headers."""
|
||||
return _os.path.join(_os.path.dirname(__file__), "lib")
|
||||
|
||||
13
python/numa_py.cpp
Normal file
13
python/numa_py.cpp
Normal file
@@ -0,0 +1,13 @@
|
||||
#include <nanobind/nanobind.h>
|
||||
namespace nb = nanobind;
|
||||
|
||||
namespace mscclpp {
|
||||
int getDeviceNumaNode(int cudaDev);
|
||||
void numaBind(int node);
|
||||
}; // namespace mscclpp
|
||||
|
||||
void register_numa(nb::module_ &m) {
|
||||
nb::module_ sub_m = m.def_submodule("numa", "numa functions");
|
||||
sub_m.def("get_device_numa_node", &mscclpp::getDeviceNumaNode);
|
||||
sub_m.def("numa_bind", &mscclpp::numaBind);
|
||||
}
|
||||
@@ -16,22 +16,40 @@ void register_proxy_channel(nb::module_& m) {
|
||||
.def("stop_proxy", &BaseProxyService::stopProxy);
|
||||
|
||||
nb::class_<ProxyService, BaseProxyService>(m, "ProxyService")
|
||||
.def(nb::init<Communicator&>(), nb::arg("comm"))
|
||||
.def(nb::init<>())
|
||||
.def("start_proxy", &ProxyService::startProxy)
|
||||
.def("stop_proxy", &ProxyService::stopProxy)
|
||||
.def("add_semaphore", &ProxyService::addSemaphore, nb::arg("connection"))
|
||||
.def("build_and_add_semaphore", &ProxyService::buildAndAddSemaphore, nb::arg("comm"), nb::arg("connection"))
|
||||
.def("add_semaphore", &ProxyService::addSemaphore, nb::arg("semaphore"))
|
||||
.def("add_memory", &ProxyService::addMemory, nb::arg("memory"))
|
||||
.def("semaphore", &ProxyService::semaphore, nb::arg("id"))
|
||||
.def("device_channel", &ProxyService::deviceChannel, nb::arg("id"));
|
||||
.def("proxy_channel", &ProxyService::proxyChannel, nb::arg("id"));
|
||||
|
||||
nb::class_<ProxyChannel>(m, "ProxyChannel")
|
||||
.def(nb::init<SemaphoreId, Host2DeviceSemaphore::DeviceHandle, DeviceProxyFifo>(), nb::arg("semaphoreId"),
|
||||
nb::arg("semaphore"), nb::arg("fifo"));
|
||||
.def(nb::init<SemaphoreId, Host2DeviceSemaphore::DeviceHandle, FifoDeviceHandle>(), nb::arg("semaphoreId"),
|
||||
nb::arg("semaphore"), nb::arg("fifo"))
|
||||
.def("device_handle", &ProxyChannel::deviceHandle);
|
||||
|
||||
nb::class_<ProxyChannel::DeviceHandle>(m, "ProxyChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphoreId_", &ProxyChannel::DeviceHandle::semaphoreId_)
|
||||
.def_rw("semaphore_", &ProxyChannel::DeviceHandle::semaphore_)
|
||||
.def_rw("fifo_", &ProxyChannel::DeviceHandle::fifo_)
|
||||
.def_prop_ro("raw", [](const ProxyChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<SimpleProxyChannel>(m, "SimpleProxyChannel")
|
||||
.def(nb::init<ProxyChannel, MemoryId, MemoryId>(), nb::arg("proxyChan"), nb::arg("dst"), nb::arg("src"))
|
||||
.def(nb::init<SimpleProxyChannel>(), nb::arg("proxyChan"));
|
||||
.def(nb::init<SimpleProxyChannel>(), nb::arg("proxyChan"))
|
||||
.def("device_handle", &SimpleProxyChannel::deviceHandle);
|
||||
|
||||
m.def("device_handle", &deviceHandle<ProxyChannel>, nb::arg("proxyChannel"));
|
||||
m.def("device_handle", &deviceHandle<SimpleProxyChannel>, nb::arg("simpleProxyChannel"));
|
||||
nb::class_<SimpleProxyChannel::DeviceHandle>(m, "SimpleProxyChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("proxyChan_", &SimpleProxyChannel::DeviceHandle::proxyChan_)
|
||||
.def_rw("src_", &SimpleProxyChannel::DeviceHandle::src_)
|
||||
.def_rw("dst_", &SimpleProxyChannel::DeviceHandle::dst_)
|
||||
.def_prop_ro("raw", [](const SimpleProxyChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
};
|
||||
|
||||
@@ -20,7 +20,10 @@ void register_semaphore(nb::module_& m) {
|
||||
nb::class_<Host2DeviceSemaphore::DeviceHandle>(host2DeviceSemaphore, "DeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("inbound_semaphore_id", &Host2DeviceSemaphore::DeviceHandle::inboundSemaphoreId)
|
||||
.def_rw("expected_inbound_semaphore_id", &Host2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId);
|
||||
.def_rw("expected_inbound_semaphore_id", &Host2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId)
|
||||
.def_prop_ro("raw", [](const Host2DeviceSemaphore::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<Host2HostSemaphore>(m, "Host2HostSemaphore")
|
||||
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
@@ -38,5 +41,8 @@ void register_semaphore(nb::module_& m) {
|
||||
.def_rw("inboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::inboundSemaphoreId)
|
||||
.def_rw("outboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::outboundSemaphoreId)
|
||||
.def_rw("remoteInboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::remoteInboundSemaphoreId)
|
||||
.def_rw("expectedInboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId);
|
||||
.def_rw("expectedInboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId)
|
||||
.def_prop_ro("raw", [](const SmDevice2DeviceSemaphore::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
}
|
||||
|
||||
@@ -13,11 +13,23 @@ using namespace mscclpp;
|
||||
void register_sm_channel(nb::module_& m) {
|
||||
nb::class_<SmChannel> smChannel(m, "SmChannel");
|
||||
smChannel
|
||||
.def(nb::init<std::shared_ptr<SmDevice2DeviceSemaphore>, RegisteredMemory, void*, void*>(), nb::arg("semaphore"),
|
||||
nb::arg("dst"), nb::arg("src"), nb::arg("getPacketBuffer"))
|
||||
.def("__init__",
|
||||
[](SmChannel* smChannel, std::shared_ptr<SmDevice2DeviceSemaphore> semaphore, RegisteredMemory dst,
|
||||
uintptr_t src) { new (smChannel) SmChannel(semaphore, dst, (void*)src); })
|
||||
.def("__init__",
|
||||
[](SmChannel* smChannel, std::shared_ptr<SmDevice2DeviceSemaphore> semaphore, RegisteredMemory dst,
|
||||
uintptr_t src, uintptr_t get_packet_buffer) {
|
||||
new (smChannel) SmChannel(semaphore, dst, (void*)src, (void*)get_packet_buffer);
|
||||
})
|
||||
.def("device_handle", &SmChannel::deviceHandle);
|
||||
|
||||
nb::class_<SmChannel::DeviceHandle>(smChannel, "DeviceHandle");
|
||||
|
||||
m.def("device_handle", &deviceHandle<SmChannel>, nb::arg("smChannel"));
|
||||
nb::class_<SmChannel::DeviceHandle>(m, "SmChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphore_", &SmChannel::DeviceHandle::semaphore_)
|
||||
.def_rw("src_", &SmChannel::DeviceHandle::src_)
|
||||
.def_rw("dst_", &SmChannel::DeviceHandle::dst_)
|
||||
.def_rw("getPacketBuffer_", &SmChannel::DeviceHandle::getPacketBuffer_)
|
||||
.def_prop_ro("raw", [](const SmChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
};
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
#include <sys/resource.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <mscclpp/config.hpp>
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/errors.hpp>
|
||||
#include <sstream>
|
||||
@@ -59,9 +58,9 @@ class TcpBootstrap::Impl {
|
||||
public:
|
||||
Impl(int rank, int nRanks);
|
||||
~Impl();
|
||||
void initialize(const UniqueId& uniqueId);
|
||||
void initialize(const std::string& ifIpPortTrio);
|
||||
void establishConnections();
|
||||
void initialize(const UniqueId& uniqueId, int64_t timeoutSec);
|
||||
void initialize(const std::string& ifIpPortTrio, int64_t timeoutSec);
|
||||
void establishConnections(int64_t timeoutSec);
|
||||
UniqueId createUniqueId();
|
||||
UniqueId getUniqueId() const;
|
||||
int getRank();
|
||||
@@ -133,15 +132,15 @@ int TcpBootstrap::Impl::getRank() { return rank_; }
|
||||
|
||||
int TcpBootstrap::Impl::getNranks() { return nRanks_; }
|
||||
|
||||
void TcpBootstrap::Impl::initialize(const UniqueId& uniqueId) {
|
||||
void TcpBootstrap::Impl::initialize(const UniqueId& uniqueId, int64_t timeoutSec) {
|
||||
netInit("", "");
|
||||
|
||||
std::memcpy(&uniqueId_, &uniqueId, sizeof(uniqueId_));
|
||||
|
||||
establishConnections();
|
||||
establishConnections(timeoutSec);
|
||||
}
|
||||
|
||||
void TcpBootstrap::Impl::initialize(const std::string& ifIpPortTrio) {
|
||||
void TcpBootstrap::Impl::initialize(const std::string& ifIpPortTrio, int64_t timeoutSec) {
|
||||
// first check if it is a trio
|
||||
int nColons = 0;
|
||||
for (auto c : ifIpPortTrio) {
|
||||
@@ -167,7 +166,7 @@ void TcpBootstrap::Impl::initialize(const std::string& ifIpPortTrio) {
|
||||
bootstrapCreateRoot();
|
||||
}
|
||||
|
||||
establishConnections();
|
||||
establishConnections(timeoutSec);
|
||||
}
|
||||
|
||||
TcpBootstrap::Impl::~Impl() {
|
||||
@@ -308,8 +307,8 @@ void TcpBootstrap::Impl::netInit(std::string ipPortPair, std::string interface)
|
||||
} \
|
||||
} while (0);
|
||||
|
||||
void TcpBootstrap::Impl::establishConnections() {
|
||||
const int64_t connectionTimeoutUs = (int64_t)Config::getInstance()->getBootstrapConnectionTimeoutConfig() * 1000000;
|
||||
void TcpBootstrap::Impl::establishConnections(int64_t timeoutSec) {
|
||||
const int64_t connectionTimeoutUs = timeoutSec * 1000000;
|
||||
Timer timer;
|
||||
SocketAddress nextAddr;
|
||||
ExtInfo info;
|
||||
@@ -317,6 +316,10 @@ void TcpBootstrap::Impl::establishConnections() {
|
||||
TRACE(MSCCLPP_INIT, "rank %d nranks %d", rank_, nRanks_);
|
||||
|
||||
auto getLeftTime = [&]() {
|
||||
if (connectionTimeoutUs < 0) {
|
||||
// no timeout: always return a large number
|
||||
return int64_t(1e9);
|
||||
}
|
||||
int64_t timeout = connectionTimeoutUs - timer.elapsed();
|
||||
if (timeout <= 0) throw Error("TcpBootstrap connection timeout", ErrorCode::Timeout);
|
||||
return timeout;
|
||||
@@ -489,9 +492,13 @@ MSCCLPP_API_CPP void TcpBootstrap::recv(void* data, int size, int peer, int tag)
|
||||
|
||||
MSCCLPP_API_CPP void TcpBootstrap::allGather(void* allData, int size) { pimpl_->allGather(allData, size); }
|
||||
|
||||
MSCCLPP_API_CPP void TcpBootstrap::initialize(UniqueId uniqueId) { pimpl_->initialize(uniqueId); }
|
||||
MSCCLPP_API_CPP void TcpBootstrap::initialize(UniqueId uniqueId, int64_t timeoutSec) {
|
||||
pimpl_->initialize(uniqueId, timeoutSec);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void TcpBootstrap::initialize(const std::string& ipPortPair) { pimpl_->initialize(ipPortPair); }
|
||||
MSCCLPP_API_CPP void TcpBootstrap::initialize(const std::string& ipPortPair, int64_t timeoutSec) {
|
||||
pimpl_->initialize(ipPortPair, timeoutSec);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void TcpBootstrap::barrier() { pimpl_->barrier(); }
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
#include <string.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <mscclpp/config.hpp>
|
||||
#include <mscclpp/errors.hpp>
|
||||
#include <mscclpp/utils.hpp>
|
||||
#include <sstream>
|
||||
|
||||
@@ -94,7 +94,11 @@ MSCCLPP_API_CPP NonblockingFuture<RegisteredMemory> Communicator::recvMemoryOnSe
|
||||
return NonblockingFuture<RegisteredMemory>(memoryReceiver->memoryPromise_.get_future());
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(int remoteRank, int tag, Transport transport) {
|
||||
MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(int remoteRank, int tag, Transport transport,
|
||||
int ibMaxCqSize /*=1024*/,
|
||||
int ibMaxCqPollNum /*=1*/,
|
||||
int ibMaxSendWr /*=8192*/,
|
||||
int ibMaxWrPerSend /*=64*/) {
|
||||
std::shared_ptr<ConnectionBase> conn;
|
||||
if (transport == Transport::CudaIpc) {
|
||||
// sanity check: make sure the IPC connection is being made within a node
|
||||
@@ -111,7 +115,8 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(int rem
|
||||
pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()], remoteRank,
|
||||
pimpl->rankToHash_[remoteRank]);
|
||||
} else if (AllIBTransports.has(transport)) {
|
||||
auto ibConn = std::make_shared<IBConnection>(remoteRank, tag, transport, *pimpl);
|
||||
auto ibConn = std::make_shared<IBConnection>(remoteRank, tag, transport, ibMaxCqSize, ibMaxCqPollNum, ibMaxSendWr,
|
||||
ibMaxWrPerSend, *pimpl);
|
||||
conn = ibConn;
|
||||
INFO(MSCCLPP_NET, "IB connection between rank %d(%lx) via %s and remoteRank %d(%lx) created",
|
||||
pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()],
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <mscclpp/config.hpp>
|
||||
|
||||
namespace mscclpp {
|
||||
Config Config::instance_;
|
||||
|
||||
Config* Config::getInstance() { return &instance_; }
|
||||
|
||||
int Config::getBootstrapConnectionTimeoutConfig() { return bootstrapConnectionTimeout; }
|
||||
|
||||
void Config::setBootstrapConnectionTimeoutConfig(int timeout) { bootstrapConnectionTimeout = timeout; }
|
||||
} // namespace mscclpp
|
||||
@@ -70,21 +70,26 @@ void CudaIpcConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset,
|
||||
// npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)size);
|
||||
}
|
||||
|
||||
void CudaIpcConnection::flush() {
|
||||
void CudaIpcConnection::flush(int64_t timeoutUsec) {
|
||||
if (timeoutUsec >= 0) {
|
||||
INFO(MSCCLPP_P2P, "CudaIpcConnection flush: timeout is not supported, ignored");
|
||||
}
|
||||
AvoidCudaGraphCaptureGuard guard;
|
||||
MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream_));
|
||||
// npkitCollectExitEvents(conn, NPKIT_EVENT_DMA_SEND_EXIT);
|
||||
INFO(MSCCLPP_P2P, "CudaIpcConnection flushing connection to remote rank %d", remoteRank());
|
||||
}
|
||||
|
||||
// IBConnection
|
||||
|
||||
IBConnection::IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl)
|
||||
IBConnection::IBConnection(int remoteRank, int tag, Transport transport, int maxCqSize, int maxCqPollNum, int maxSendWr,
|
||||
int maxWrPerSend, Communicator::Impl& commImpl)
|
||||
: ConnectionBase(remoteRank, tag),
|
||||
transport_(transport),
|
||||
remoteTransport_(Transport::Unknown),
|
||||
numSignaledSends(0),
|
||||
dummyAtomicSource_(std::make_unique<uint64_t>(0)) {
|
||||
qp = commImpl.getIbContext(transport)->createQp();
|
||||
qp = commImpl.getIbContext(transport)->createQp(maxCqSize, maxCqPollNum, maxSendWr, 0, maxWrPerSend);
|
||||
dummyAtomicSourceMem_ = RegisteredMemory(std::make_shared<RegisteredMemory::Impl>(
|
||||
dummyAtomicSource_.get(), sizeof(uint64_t), commImpl.bootstrap_->getRank(), transport, commImpl));
|
||||
validateTransport(dummyAtomicSourceMem_, transport);
|
||||
@@ -144,7 +149,7 @@ void IBConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint6
|
||||
oldValue, newValue);
|
||||
}
|
||||
|
||||
void IBConnection::flush() {
|
||||
void IBConnection::flush(int64_t timeoutUsec) {
|
||||
Timer timer;
|
||||
while (numSignaledSends) {
|
||||
int wcNum = qp->pollCq();
|
||||
@@ -153,8 +158,8 @@ void IBConnection::flush() {
|
||||
}
|
||||
|
||||
auto elapsed = timer.elapsed();
|
||||
if (elapsed > MSCCLPP_POLLING_WAIT) {
|
||||
throw Error("pollCq is stuck: waited for " + std::to_string(elapsed / 1e6) + " seconds. Expected " +
|
||||
if ((timeoutUsec >= 0) && (elapsed * 1e3 > timeoutUsec)) {
|
||||
throw Error("pollCq is stuck: waited for " + std::to_string(elapsed / 1e3) + " seconds. Expected " +
|
||||
std::to_string(numSignaledSends) + " signals",
|
||||
ErrorCode::InternalError);
|
||||
}
|
||||
@@ -168,6 +173,7 @@ void IBConnection::flush() {
|
||||
}
|
||||
}
|
||||
}
|
||||
INFO(MSCCLPP_NET, "IBConnection flushing connection to remote rank %d", remoteRank());
|
||||
// npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT);
|
||||
}
|
||||
|
||||
|
||||
46
src/fifo.cc
46
src/fifo.cc
@@ -1,20 +1,18 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <emmintrin.h>
|
||||
|
||||
#include <mscclpp/cuda_utils.hpp>
|
||||
#include <mscclpp/fifo.hpp>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "api.h"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
struct HostProxyFifo::Impl {
|
||||
struct Fifo::Impl {
|
||||
UniqueCudaHostPtr<ProxyTrigger[]> triggers;
|
||||
UniqueCudaPtr<uint64_t> head;
|
||||
UniqueCudaPtr<uint64_t> tailReplica;
|
||||
const int size;
|
||||
|
||||
// allocated on the host. Only accessed by the host. This is a copy of the
|
||||
// value pointed to by fifoTailDev and the invariant is that
|
||||
@@ -28,28 +26,33 @@ struct HostProxyFifo::Impl {
|
||||
// for transferring fifo tail
|
||||
CudaStreamWithFlags stream;
|
||||
|
||||
Impl()
|
||||
: triggers(makeUniqueCudaHost<ProxyTrigger[]>(MSCCLPP_PROXY_FIFO_SIZE)),
|
||||
Impl(int size)
|
||||
: triggers(makeUniqueCudaHost<ProxyTrigger[]>(size)),
|
||||
head(allocUniqueCuda<uint64_t>()),
|
||||
tailReplica(allocUniqueCuda<uint64_t>()),
|
||||
size(size),
|
||||
hostTail(0),
|
||||
stream(cudaStreamNonBlocking) {}
|
||||
};
|
||||
|
||||
MSCCLPP_API_CPP HostProxyFifo::HostProxyFifo() : pimpl(std::make_unique<Impl>()) {}
|
||||
MSCCLPP_API_CPP HostProxyFifo::~HostProxyFifo() = default;
|
||||
MSCCLPP_API_CPP Fifo::Fifo(int size) : pimpl(std::make_unique<Impl>(size)) {}
|
||||
MSCCLPP_API_CPP Fifo::~Fifo() = default;
|
||||
|
||||
MSCCLPP_API_CPP void HostProxyFifo::poll(ProxyTrigger* trigger) {
|
||||
__m128i xmm0 = _mm_load_si128((__m128i*)&pimpl->triggers.get()[pimpl->hostTail % MSCCLPP_PROXY_FIFO_SIZE]);
|
||||
_mm_store_si128((__m128i*)trigger, xmm0);
|
||||
MSCCLPP_API_CPP ProxyTrigger Fifo::poll() {
|
||||
ProxyTrigger trigger;
|
||||
volatile ProxyTrigger* ptr =
|
||||
reinterpret_cast<volatile ProxyTrigger*>(&pimpl->triggers.get()[pimpl->hostTail % pimpl->size]);
|
||||
trigger.fst = ptr->fst;
|
||||
trigger.snd = ptr->snd;
|
||||
return trigger;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void HostProxyFifo::pop() {
|
||||
*(volatile uint64_t*)(&pimpl->triggers.get()[pimpl->hostTail % MSCCLPP_PROXY_FIFO_SIZE]) = 0;
|
||||
MSCCLPP_API_CPP void Fifo::pop() {
|
||||
*(volatile uint64_t*)(&pimpl->triggers.get()[pimpl->hostTail % pimpl->size]) = 0;
|
||||
(pimpl->hostTail)++;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void HostProxyFifo::flushTail(bool sync) {
|
||||
MSCCLPP_API_CPP void Fifo::flushTail(bool sync) {
|
||||
// Flush the tail to device memory. This is either triggered every ProxyFlushPeriod to make sure that the fifo can
|
||||
// make progress even if there is no request mscclppSync. However, mscclppSync type is for flush request.
|
||||
MSCCLPP_CUDATHROW(cudaMemcpyAsync(pimpl->tailReplica.get(), &pimpl->hostTail, sizeof(uint64_t),
|
||||
@@ -59,12 +62,15 @@ MSCCLPP_API_CPP void HostProxyFifo::flushTail(bool sync) {
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP DeviceProxyFifo HostProxyFifo::deviceFifo() {
|
||||
DeviceProxyFifo deviceFifo;
|
||||
deviceFifo.triggers = pimpl->triggers.get();
|
||||
deviceFifo.head = pimpl->head.get();
|
||||
deviceFifo.tailReplica = pimpl->tailReplica.get();
|
||||
return deviceFifo;
|
||||
MSCCLPP_API_CPP int Fifo::size() const { return pimpl->size; }
|
||||
|
||||
MSCCLPP_API_CPP FifoDeviceHandle Fifo::deviceHandle() {
|
||||
FifoDeviceHandle deviceHandle;
|
||||
deviceHandle.triggers = pimpl->triggers.get();
|
||||
deviceHandle.head = pimpl->head.get();
|
||||
deviceHandle.tailReplica = pimpl->tailReplica.get();
|
||||
deviceHandle.size = pimpl->size;
|
||||
return deviceHandle;
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
29
src/ib.cc
29
src/ib.cc
@@ -16,8 +16,6 @@
|
||||
#include "api.h"
|
||||
#include "debug.h"
|
||||
|
||||
#define MAXCONNECTIONS 64
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) {
|
||||
@@ -54,8 +52,10 @@ const void* IbMr::getBuff() const { return this->buff; }
|
||||
|
||||
uint32_t IbMr::getLkey() const { return this->mr->lkey; }
|
||||
|
||||
IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port) {
|
||||
this->cq = ibv_create_cq(ctx, MSCCLPP_IB_CQ_SIZE, nullptr, nullptr, 0);
|
||||
IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr,
|
||||
int maxWrPerSend)
|
||||
: maxCqPollNum(maxCqPollNum), maxWrPerSend(maxWrPerSend) {
|
||||
this->cq = ibv_create_cq(ctx, maxCqSize, nullptr, nullptr, 0);
|
||||
if (this->cq == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_create_cq failed (errno " << errno << ")";
|
||||
@@ -68,8 +68,8 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port) {
|
||||
qpInitAttr.send_cq = this->cq;
|
||||
qpInitAttr.recv_cq = this->cq;
|
||||
qpInitAttr.qp_type = IBV_QPT_RC;
|
||||
qpInitAttr.cap.max_send_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE;
|
||||
qpInitAttr.cap.max_recv_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE;
|
||||
qpInitAttr.cap.max_send_wr = maxSendWr;
|
||||
qpInitAttr.cap.max_recv_wr = maxRecvWr;
|
||||
qpInitAttr.cap.max_send_sge = 1;
|
||||
qpInitAttr.cap.max_recv_sge = 1;
|
||||
qpInitAttr.cap.max_inline_data = 0;
|
||||
@@ -118,9 +118,9 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port) {
|
||||
}
|
||||
this->qp = _qp;
|
||||
this->wrn = 0;
|
||||
this->wrs = std::make_unique<ibv_send_wr[]>(MSCCLPP_IB_MAX_SENDS);
|
||||
this->sges = std::make_unique<ibv_sge[]>(MSCCLPP_IB_MAX_SENDS);
|
||||
this->wcs = std::make_unique<ibv_wc[]>(MSCCLPP_IB_CQ_POLL_NUM);
|
||||
this->wrs = std::make_unique<ibv_send_wr[]>(maxWrPerSend);
|
||||
this->sges = std::make_unique<ibv_sge[]>(maxWrPerSend);
|
||||
this->wcs = std::make_unique<ibv_wc[]>(maxCqPollNum);
|
||||
}
|
||||
|
||||
IbQp::~IbQp() {
|
||||
@@ -182,9 +182,9 @@ void IbQp::rts() {
|
||||
}
|
||||
|
||||
IbQp::WrInfo IbQp::getNewWrInfo() {
|
||||
if (this->wrn >= MSCCLPP_IB_MAX_SENDS) {
|
||||
if (this->wrn >= this->maxWrPerSend) {
|
||||
std::stringstream err;
|
||||
err << "too many outstanding work requests. limit is " << MSCCLPP_IB_MAX_SENDS;
|
||||
err << "too many outstanding work requests. limit is " << this->maxWrPerSend;
|
||||
throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage);
|
||||
}
|
||||
int wrn = this->wrn;
|
||||
@@ -269,7 +269,7 @@ void IbQp::postRecv(uint64_t wrId) {
|
||||
}
|
||||
}
|
||||
|
||||
int IbQp::pollCq() { return ibv_poll_cq(this->cq, MSCCLPP_IB_CQ_POLL_NUM, this->wcs.get()); }
|
||||
int IbQp::pollCq() { return ibv_poll_cq(this->cq, this->maxCqPollNum, this->wcs.get()); }
|
||||
|
||||
IbQpInfo& IbQp::getInfo() { return this->info; }
|
||||
|
||||
@@ -335,7 +335,8 @@ int IbCtx::getAnyActivePort() const {
|
||||
return -1;
|
||||
}
|
||||
|
||||
IbQp* IbCtx::createQp(int port /*=-1*/) {
|
||||
IbQp* IbCtx::createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend,
|
||||
int port /*=-1*/) {
|
||||
if (port == -1) {
|
||||
port = this->getAnyActivePort();
|
||||
if (port == -1) {
|
||||
@@ -344,7 +345,7 @@ IbQp* IbCtx::createQp(int port /*=-1*/) {
|
||||
} else if (!this->isPortUsable(port)) {
|
||||
throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InternalError);
|
||||
}
|
||||
qps.emplace_back(new IbQp(this->ctx, this->pd, port));
|
||||
qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend));
|
||||
return qps.back().get();
|
||||
}
|
||||
|
||||
|
||||
@@ -4,9 +4,6 @@
|
||||
#ifndef MSCCLPP_CONNECTION_HPP_
|
||||
#define MSCCLPP_CONNECTION_HPP_
|
||||
|
||||
// TODO(saemal): make this configurable
|
||||
#define MSCCLPP_POLLING_WAIT 3e7 // in microseconds
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <mscclpp/core.hpp>
|
||||
@@ -46,7 +43,7 @@ class CudaIpcConnection : public ConnectionBase {
|
||||
uint64_t size) override;
|
||||
void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) override;
|
||||
|
||||
void flush() override;
|
||||
void flush(int64_t timeoutUsec) override;
|
||||
};
|
||||
|
||||
class IBConnection : public ConnectionBase {
|
||||
@@ -59,7 +56,8 @@ class IBConnection : public ConnectionBase {
|
||||
mscclpp::TransportInfo dstTransportInfo_;
|
||||
|
||||
public:
|
||||
IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl);
|
||||
IBConnection(int remoteRank, int tag, Transport transport, int maxCqSize, int maxCqPollNum, int maxSendWr,
|
||||
int maxWrPerSend, Communicator::Impl& commImpl);
|
||||
|
||||
Transport transport() override;
|
||||
|
||||
@@ -69,7 +67,7 @@ class IBConnection : public ConnectionBase {
|
||||
uint64_t size) override;
|
||||
void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) override;
|
||||
|
||||
void flush() override;
|
||||
void flush(int64_t timeoutUsec) override;
|
||||
|
||||
void beginSetup(std::shared_ptr<Bootstrap> bootstrap) override;
|
||||
|
||||
|
||||
@@ -8,11 +8,6 @@
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#define MSCCLPP_IB_CQ_SIZE 1024
|
||||
#define MSCCLPP_IB_CQ_POLL_NUM 1
|
||||
#define MSCCLPP_IB_MAX_SENDS 64
|
||||
#define MSCCLPP_IB_MAX_DEVS 8
|
||||
|
||||
// Forward declarations of IB structures
|
||||
struct ibv_context;
|
||||
struct ibv_pd;
|
||||
@@ -84,7 +79,8 @@ class IbQp {
|
||||
ibv_sge* sge;
|
||||
};
|
||||
|
||||
IbQp(ibv_context* ctx, ibv_pd* pd, int port);
|
||||
IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr,
|
||||
int maxWrPerSend);
|
||||
WrInfo getNewWrInfo();
|
||||
|
||||
IbQpInfo info;
|
||||
@@ -96,6 +92,9 @@ class IbQp {
|
||||
std::unique_ptr<ibv_sge[]> sges;
|
||||
int wrn;
|
||||
|
||||
const int maxCqPollNum;
|
||||
const int maxWrPerSend;
|
||||
|
||||
friend class IbCtx;
|
||||
};
|
||||
|
||||
@@ -104,7 +103,7 @@ class IbCtx {
|
||||
IbCtx(const std::string& devName);
|
||||
~IbCtx();
|
||||
|
||||
IbQp* createQp(int port = -1);
|
||||
IbQp* createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int port = -1);
|
||||
const IbMr* registerMr(void* buff, std::size_t size);
|
||||
|
||||
const std::string& getDevName() const;
|
||||
@@ -122,4 +121,4 @@ class IbCtx {
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_IB_HPP_
|
||||
#endif // MSCCLPP_IB_HPP_
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
#include <fstream>
|
||||
#include <mscclpp/cuda_utils.hpp>
|
||||
|
||||
#include "api.h"
|
||||
|
||||
// Convert a logical cudaDev index to the NVML device minor number
|
||||
static const std::string getBusId(int cudaDev) {
|
||||
// On most systems, the PCI bus ID comes back as in the 0000:00:00.0
|
||||
@@ -22,7 +24,7 @@ static const std::string getBusId(int cudaDev) {
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
int getDeviceNumaNode(int cudaDev) {
|
||||
MSCCLPP_API_CPP int getDeviceNumaNode(int cudaDev) {
|
||||
std::string busId = getBusId(cudaDev);
|
||||
std::string file_str = "/sys/bus/pci/devices/" + busId + "/numa_node";
|
||||
std::ifstream file(file_str);
|
||||
@@ -37,7 +39,7 @@ int getDeviceNumaNode(int cudaDev) {
|
||||
return numaNode;
|
||||
}
|
||||
|
||||
void numaBind(int node) {
|
||||
MSCCLPP_API_CPP void numaBind(int node) {
|
||||
int totalNumNumaNodes = numa_num_configured_nodes();
|
||||
if (node < 0 || node >= totalNumNumaNodes) {
|
||||
throw Error(
|
||||
|
||||
25
src/proxy.cc
25
src/proxy.cc
@@ -15,14 +15,13 @@ namespace mscclpp {
|
||||
const int ProxyStopCheckPeriod = 1000;
|
||||
|
||||
// Unless explicitly requested, a flush of the tail to device memory is triggered for every ProxyFlushPeriod.
|
||||
// As long as MSCCLPP_PROXY_FIFO_SIZE is large enough, having a stale tail is not a problem.
|
||||
// As long as the FIFO size is large enough, having a stale tail is not a problem.
|
||||
const int ProxyFlushPeriod = 4;
|
||||
static_assert(MSCCLPP_PROXY_FIFO_SIZE >= ProxyFlushPeriod, "MSCCLPP_PROXY_FIFO_SIZE is too small");
|
||||
|
||||
struct Proxy::Impl {
|
||||
ProxyHandler handler;
|
||||
std::function<void()> threadInit;
|
||||
HostProxyFifo fifo;
|
||||
Fifo fifo;
|
||||
std::thread service;
|
||||
std::atomic_bool running;
|
||||
|
||||
@@ -53,10 +52,12 @@ MSCCLPP_API_CPP void Proxy::start() {
|
||||
pimpl->threadInit();
|
||||
|
||||
ProxyHandler handler = this->pimpl->handler;
|
||||
HostProxyFifo& fifo = this->pimpl->fifo;
|
||||
Fifo& fifo = this->pimpl->fifo;
|
||||
std::atomic_bool& running = this->pimpl->running;
|
||||
ProxyTrigger trigger;
|
||||
|
||||
int flushPeriod = std::min(fifo.size(), ProxyFlushPeriod);
|
||||
|
||||
int runCnt = ProxyStopCheckPeriod;
|
||||
uint64_t flushCnt = 0;
|
||||
for (;;) {
|
||||
@@ -67,19 +68,19 @@ MSCCLPP_API_CPP void Proxy::start() {
|
||||
}
|
||||
}
|
||||
// Poll to see if we are ready to send anything
|
||||
fifo.poll(&trigger);
|
||||
if (trigger.fst == 0) { // TODO: this check is a potential pitfall for custom triggers
|
||||
continue; // there is one in progress
|
||||
trigger = fifo.poll();
|
||||
if (trigger.fst == 0 || trigger.snd == 0) { // TODO: this check is a potential pitfall for custom triggers
|
||||
continue; // there is one in progress
|
||||
}
|
||||
trigger.snd ^= ((uint64_t)1 << (uint64_t)63); // this is where the last bit of snd is reverted.
|
||||
|
||||
ProxyHandlerResult result = handler(trigger);
|
||||
|
||||
// Send completion: reset only the high 64 bits
|
||||
fifo.pop();
|
||||
// Flush the tail to device memory. This is either triggered every ProxyFlushPeriod to make sure
|
||||
// that the fifo can make progress even if there is no request mscclppSync. However, mscclppSync type is for flush
|
||||
// request.
|
||||
if ((++flushCnt % ProxyFlushPeriod) == 0 || result == ProxyHandlerResult::FlushFifoTailAndContinue) {
|
||||
// Flush the tail to device memory. This is either triggered every flushPeriod to make sure that the fifo can make
|
||||
// progress even if there is no request mscclppSync. However, mscclppSync type is for flush request.
|
||||
if ((++flushCnt % flushPeriod) == 0 || result == ProxyHandlerResult::FlushFifoTailAndContinue) {
|
||||
// TODO: relocate this check: || (trigger.fields.type & mscclppSync)
|
||||
fifo.flushTail();
|
||||
}
|
||||
@@ -107,6 +108,6 @@ MSCCLPP_API_CPP void Proxy::stop() {
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP HostProxyFifo& Proxy::fifo() { return pimpl->fifo; }
|
||||
MSCCLPP_API_CPP Fifo& Proxy::fifo() { return pimpl->fifo; }
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -1,31 +1,36 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <mscclpp/numa.hpp>
|
||||
#include <mscclpp/proxy_channel.hpp>
|
||||
|
||||
#include "api.h"
|
||||
#include "debug.h"
|
||||
#include "numa.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
MSCCLPP_API_CPP ProxyChannel::ProxyChannel(SemaphoreId semaphoreId, Host2DeviceSemaphore::DeviceHandle semaphore,
|
||||
DeviceProxyFifo fifo)
|
||||
FifoDeviceHandle fifo)
|
||||
: semaphoreId_(semaphoreId), semaphore_(semaphore), fifo_(fifo) {}
|
||||
|
||||
MSCCLPP_API_CPP SimpleProxyChannel::SimpleProxyChannel(ProxyChannel proxyChan, MemoryId dst, MemoryId src)
|
||||
: proxyChan_(proxyChan), dst_(dst), src_(src) {}
|
||||
|
||||
MSCCLPP_API_CPP ProxyService::ProxyService(Communicator& communicator)
|
||||
: communicator_(communicator),
|
||||
proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, [&]() { bindThread(); }) {
|
||||
MSCCLPP_API_CPP ProxyService::ProxyService()
|
||||
: proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, [&]() { bindThread(); }) {
|
||||
int cudaDevice;
|
||||
MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice));
|
||||
deviceNumaNode = getDeviceNumaNode(cudaDevice);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP SemaphoreId ProxyService::addSemaphore(std::shared_ptr<Connection> connection) {
|
||||
semaphores_.push_back(std::make_shared<Host2DeviceSemaphore>(communicator_, connection));
|
||||
MSCCLPP_API_CPP SemaphoreId ProxyService::buildAndAddSemaphore(Communicator& communicator,
|
||||
std::shared_ptr<Connection> connection) {
|
||||
semaphores_.push_back(std::make_shared<Host2DeviceSemaphore>(communicator, connection));
|
||||
return semaphores_.size() - 1;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP SemaphoreId ProxyService::addSemaphore(std::shared_ptr<Host2DeviceSemaphore> semaphore) {
|
||||
semaphores_.push_back(semaphore);
|
||||
return semaphores_.size() - 1;
|
||||
}
|
||||
|
||||
@@ -38,8 +43,8 @@ MSCCLPP_API_CPP std::shared_ptr<Host2DeviceSemaphore> ProxyService::semaphore(Se
|
||||
return semaphores_[id];
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP ProxyChannel ProxyService::deviceChannel(SemaphoreId id) {
|
||||
return ProxyChannel(id, semaphores_[id]->deviceHandle(), proxy_.fifo().deviceFifo());
|
||||
MSCCLPP_API_CPP ProxyChannel ProxyService::proxyChannel(SemaphoreId id) {
|
||||
return ProxyChannel(id, semaphores_[id]->deviceHandle(), proxy_.fifo().deviceHandle());
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void ProxyService::startProxy() { proxy_.start(); }
|
||||
@@ -78,14 +83,12 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) {
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
DeviceHandle<ProxyChannel> deviceHandle(ProxyChannel&& proxyChannel) {
|
||||
return proxyChannel;
|
||||
MSCCLPP_API_CPP ProxyChannel::DeviceHandle ProxyChannel::deviceHandle() const {
|
||||
return ProxyChannel::DeviceHandle{.semaphoreId_ = semaphoreId_, .semaphore_ = semaphore_, .fifo_ = fifo_};
|
||||
}
|
||||
|
||||
template <>
|
||||
DeviceHandle<SimpleProxyChannel> deviceHandle(SimpleProxyChannel&& simpleProxyChannel) {
|
||||
return simpleProxyChannel;
|
||||
MSCCLPP_API_CPP SimpleProxyChannel::DeviceHandle SimpleProxyChannel::deviceHandle() const {
|
||||
return SimpleProxyChannel::DeviceHandle{.proxyChan_ = proxyChan_.deviceHandle(), .dst_ = dst_, .src_ = src_};
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -207,8 +207,8 @@ void initializeAndAllocateAllGatherData(int rank, int world_size, size_t dataSiz
|
||||
CUDACHECK(cudaMemcpy(*data_d, *data_h, dataSize, cudaMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& comm,
|
||||
mscclpp::ProxyService& channelService, int* data_d, size_t dataSize) {
|
||||
void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& comm, mscclpp::ProxyService& proxyService,
|
||||
int* data_d, size_t dataSize) {
|
||||
int thisNode = rankToNode(rank);
|
||||
int cudaNum = rankToLocalRank(rank);
|
||||
std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum);
|
||||
@@ -226,7 +226,7 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co
|
||||
transport = ibTransport;
|
||||
}
|
||||
// Connect with all other ranks
|
||||
semaphoreIds.push_back(channelService.addSemaphore(comm.connectOnSetup(r, 0, transport)));
|
||||
semaphoreIds.push_back(proxyService.buildAndAddSemaphore(comm, comm.connectOnSetup(r, 0, transport)));
|
||||
auto memory = comm.registerMemory(data_d, dataSize, mscclpp::Transport::CudaIpc | ibTransport);
|
||||
localMemories.push_back(memory);
|
||||
comm.sendMemoryOnSetup(memory, r, 0);
|
||||
@@ -238,8 +238,8 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co
|
||||
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannels;
|
||||
for (size_t i = 0; i < semaphoreIds.size(); ++i) {
|
||||
proxyChannels.push_back(mscclpp::deviceHandle(mscclpp::SimpleProxyChannel(
|
||||
channelService.deviceChannel(semaphoreIds[i]), channelService.addMemory(remoteMemories[i].get()),
|
||||
channelService.addMemory(localMemories[i]))));
|
||||
proxyService.proxyChannel(semaphoreIds[i]), proxyService.addMemory(remoteMemories[i].get()),
|
||||
proxyService.addMemory(localMemories[i]))));
|
||||
}
|
||||
|
||||
assert(proxyChannels.size() < sizeof(constProxyChans) / sizeof(DeviceHandle<mscclpp::SimpleProxyChannel>));
|
||||
@@ -396,16 +396,16 @@ int main(int argc, const char* argv[]) {
|
||||
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, world_size);
|
||||
bootstrap->initialize(ip_port);
|
||||
mscclpp::Communicator comm(bootstrap);
|
||||
mscclpp::ProxyService channelService(comm);
|
||||
mscclpp::ProxyService proxyService;
|
||||
|
||||
if (rank == 0) printf("Initializing data for allgather test\n");
|
||||
initializeAndAllocateAllGatherData(rank, world_size, dataSize, nelemsPerGPU, &data_h, &data_d);
|
||||
|
||||
if (rank == 0) printf("Setting up the connection in MSCCL++\n");
|
||||
setupMscclppConnections(rank, world_size, comm, channelService, data_d, dataSize);
|
||||
setupMscclppConnections(rank, world_size, comm, proxyService, data_d, dataSize);
|
||||
|
||||
if (rank == 0) printf("Launching MSCCL++ proxy threads\n");
|
||||
channelService.startProxy();
|
||||
proxyService.startProxy();
|
||||
|
||||
if (rank == 0) printf("Testing the correctness of AllGather implementation\n");
|
||||
cudaStream_t stream;
|
||||
@@ -480,7 +480,7 @@ int main(int argc, const char* argv[]) {
|
||||
bootstrap->allGather(tmp, sizeof(int));
|
||||
|
||||
if (rank == 0) printf("Stopping MSCCL++ proxy threads\n");
|
||||
channelService.stopProxy();
|
||||
proxyService.stopProxy();
|
||||
|
||||
} catch (std::exception& e) {
|
||||
// todo: throw exceptions in the implementation and process them here
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/cuda_utils.hpp>
|
||||
#include <mscclpp/fifo.hpp>
|
||||
#include <mscclpp/numa.hpp>
|
||||
#include <mscclpp/proxy.hpp>
|
||||
#include <mscclpp/semaphore.hpp>
|
||||
#include <numa.hpp>
|
||||
|
||||
#ifdef MSCCLPP_USE_MPI_FOR_TESTS
|
||||
#include "mpi.h"
|
||||
@@ -45,7 +45,7 @@ static double getTime(void) {
|
||||
return (tspec.tv_nsec / 1.0e9) + tspec.tv_sec;
|
||||
}
|
||||
|
||||
__global__ void kernel(int r, int nranks, mscclpp::DeviceProxyFifo fifo,
|
||||
__global__ void kernel(int r, int nranks, mscclpp::FifoDeviceHandle fifo,
|
||||
mscclpp::Host2DeviceSemaphore::DeviceHandle* handles, int handleIndex) {
|
||||
int tid = threadIdx.x;
|
||||
__syncthreads();
|
||||
@@ -188,7 +188,7 @@ class MyProxyService {
|
||||
|
||||
void stop() { proxy_.stop(); }
|
||||
|
||||
mscclpp::HostProxyFifo& fifo() { return proxy_.fifo(); }
|
||||
mscclpp::Fifo& fifo() { return proxy_.fifo(); }
|
||||
|
||||
mscclpp::Host2DeviceSemaphore::DeviceHandle getDeviceHandle1(int r) { return deviceSemaphores1_[r]->deviceHandle(); }
|
||||
|
||||
@@ -261,7 +261,7 @@ int main(int argc, char* argv[]) {
|
||||
|
||||
if (rank == 0) printf("Launching MSCCL++ proxy threads\n");
|
||||
proxyService.start();
|
||||
mscclpp::DeviceProxyFifo fifo = proxyService.fifo().deviceFifo();
|
||||
mscclpp::FifoDeviceHandle fifo = proxyService.fifo().deviceHandle();
|
||||
if (rank == 0) printf("Testing the correctness of AllGather implementation\n");
|
||||
cudaStream_t stream;
|
||||
CUCHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
|
||||
|
||||
@@ -3,8 +3,6 @@
|
||||
|
||||
#include <mpi.h>
|
||||
|
||||
#include <mscclpp/config.hpp>
|
||||
|
||||
#include "mp_unit_tests.hpp"
|
||||
|
||||
void BootstrapTest::bootstrapTestAllGather(std::shared_ptr<mscclpp::Bootstrap> bootstrap) {
|
||||
@@ -88,10 +86,6 @@ TEST_F(BootstrapTest, ExitBeforeConnect) {
|
||||
}
|
||||
|
||||
TEST_F(BootstrapTest, TimeoutWithId) {
|
||||
// Set bootstrap timeout to 1 second
|
||||
mscclpp::Config* cfg = mscclpp::Config::getInstance();
|
||||
cfg->setBootstrapConnectionTimeoutConfig(1);
|
||||
|
||||
mscclpp::Timer timer;
|
||||
|
||||
// All ranks initialize a bootstrap with their own id (will hang)
|
||||
@@ -99,7 +93,8 @@ TEST_F(BootstrapTest, TimeoutWithId) {
|
||||
mscclpp::UniqueId id = bootstrap->createUniqueId();
|
||||
|
||||
try {
|
||||
bootstrap->initialize(id);
|
||||
// Set bootstrap timeout to 1 second
|
||||
bootstrap->initialize(id, 1);
|
||||
} catch (const mscclpp::Error& e) {
|
||||
ASSERT_EQ(e.getErrorCode(), mscclpp::ErrorCode::Timeout);
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ void IbPeerToPeerTest::SetUp() {
|
||||
bootstrap->initialize(id);
|
||||
|
||||
ibCtx = std::make_shared<mscclpp::IbCtx>(ibDevName);
|
||||
qp = ibCtx->createQp();
|
||||
qp = ibCtx->createQp(1024, 1, 8192, 0, 64);
|
||||
|
||||
qpInfo[gEnv->rank] = qp->getInfo();
|
||||
bootstrap->allGather(qpInfo.data(), sizeof(mscclpp::IbQpInfo));
|
||||
|
||||
@@ -124,6 +124,9 @@ class CommunicatorTest : public CommunicatorTestBase {
|
||||
std::vector<std::unordered_map<int, mscclpp::RegisteredMemory>> remoteMemory;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
using DeviceHandle = mscclpp::DeviceHandle<T>;
|
||||
|
||||
class ProxyChannelOneToOneTest : public CommunicatorTestBase {
|
||||
protected:
|
||||
void SetUp() override;
|
||||
@@ -134,7 +137,7 @@ class ProxyChannelOneToOneTest : public CommunicatorTestBase {
|
||||
void testPacketPingPong(bool useIbOnly);
|
||||
void testPacketPingPongPerf(bool useIbOnly);
|
||||
|
||||
std::shared_ptr<mscclpp::ProxyService> channelService;
|
||||
std::shared_ptr<mscclpp::ProxyService> proxyService;
|
||||
};
|
||||
|
||||
class SmChannelOneToOneTest : public CommunicatorTestBase {
|
||||
|
||||
@@ -5,21 +5,18 @@
|
||||
|
||||
#include "mp_unit_tests.hpp"
|
||||
|
||||
template <class T>
|
||||
using DeviceHandle = mscclpp::DeviceHandle<T>;
|
||||
|
||||
void ProxyChannelOneToOneTest::SetUp() {
|
||||
// Use only two ranks
|
||||
setNumRanksToUse(2);
|
||||
CommunicatorTestBase::SetUp();
|
||||
channelService = std::make_shared<mscclpp::ProxyService>(*communicator.get());
|
||||
proxyService = std::make_shared<mscclpp::ProxyService>();
|
||||
}
|
||||
|
||||
void ProxyChannelOneToOneTest::TearDown() { CommunicatorTestBase::TearDown(); }
|
||||
|
||||
void ProxyChannelOneToOneTest::setupMeshConnections(
|
||||
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>>& proxyChannels, bool useIbOnly, void* sendBuff,
|
||||
size_t sendBuffBytes, void* recvBuff, size_t recvBuffBytes) {
|
||||
void ProxyChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::SimpleProxyChannel>& proxyChannels,
|
||||
bool useIbOnly, void* sendBuff, size_t sendBuffBytes,
|
||||
void* recvBuff, size_t recvBuffBytes) {
|
||||
const int rank = communicator->bootstrap()->getRank();
|
||||
const int worldSize = communicator->bootstrap()->getNranks();
|
||||
const bool isInPlace = (recvBuff == nullptr);
|
||||
@@ -52,12 +49,11 @@ void ProxyChannelOneToOneTest::setupMeshConnections(
|
||||
|
||||
communicator->setup();
|
||||
|
||||
mscclpp::SemaphoreId cid = channelService->addSemaphore(conn);
|
||||
mscclpp::SemaphoreId cid = proxyService->buildAndAddSemaphore(*communicator, conn);
|
||||
communicator->setup();
|
||||
|
||||
proxyChannels.emplace_back(mscclpp::deviceHandle(
|
||||
mscclpp::SimpleProxyChannel(channelService->deviceChannel(cid), channelService->addMemory(remoteMemory.get()),
|
||||
channelService->addMemory(sendBufRegMem))));
|
||||
proxyChannels.emplace_back(proxyService->proxyChannel(cid), proxyService->addMemory(remoteMemory.get()),
|
||||
proxyService->addMemory(sendBufRegMem));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,15 +117,18 @@ TEST_F(ProxyChannelOneToOneTest, PingPongIb) {
|
||||
|
||||
const int nElem = 4 * 1024 * 1024;
|
||||
|
||||
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannels;
|
||||
std::vector<mscclpp::SimpleProxyChannel> proxyChannels;
|
||||
std::shared_ptr<int> buff = mscclpp::allocSharedCuda<int>(nElem);
|
||||
setupMeshConnections(proxyChannels, true, buff.get(), nElem * sizeof(int));
|
||||
|
||||
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannelHandles;
|
||||
for (auto& ch : proxyChannels) proxyChannelHandles.push_back(ch.deviceHandle());
|
||||
|
||||
ASSERT_EQ(proxyChannels.size(), 1);
|
||||
MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstProxyChans, proxyChannels.data(),
|
||||
MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstProxyChans, proxyChannelHandles.data(),
|
||||
sizeof(DeviceHandle<mscclpp::SimpleProxyChannel>)));
|
||||
|
||||
channelService->startProxy();
|
||||
proxyService->startProxy();
|
||||
|
||||
std::shared_ptr<int> ret = mscclpp::makeSharedCudaHost<int>(0);
|
||||
|
||||
@@ -153,7 +152,7 @@ TEST_F(ProxyChannelOneToOneTest, PingPongIb) {
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
|
||||
channelService->stopProxy();
|
||||
proxyService->stopProxy();
|
||||
}
|
||||
|
||||
__device__ mscclpp::DeviceSyncer gChannelOneToOneTestProxyChansSyncer;
|
||||
@@ -227,7 +226,7 @@ void ProxyChannelOneToOneTest::testPacketPingPong(bool useIbOnly) {
|
||||
|
||||
const int nElem = 4 * 1024 * 1024;
|
||||
|
||||
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannels;
|
||||
std::vector<mscclpp::SimpleProxyChannel> proxyChannels;
|
||||
std::shared_ptr<int> buff = mscclpp::allocSharedCuda<int>(nElem);
|
||||
|
||||
const size_t nPacket = (nElem * sizeof(int) + sizeof(uint64_t) - 1) / sizeof(uint64_t);
|
||||
@@ -238,13 +237,19 @@ void ProxyChannelOneToOneTest::testPacketPingPong(bool useIbOnly) {
|
||||
getPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket));
|
||||
|
||||
ASSERT_EQ(proxyChannels.size(), 1);
|
||||
MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstProxyChans, proxyChannels.data(),
|
||||
|
||||
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannelHandles;
|
||||
for (auto& proxyChannel : proxyChannels) {
|
||||
proxyChannelHandles.push_back(proxyChannel.deviceHandle());
|
||||
}
|
||||
|
||||
MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstProxyChans, proxyChannelHandles.data(),
|
||||
sizeof(DeviceHandle<mscclpp::SimpleProxyChannel>)));
|
||||
|
||||
mscclpp::DeviceSyncer syncer = {};
|
||||
MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestProxyChansSyncer, &syncer, sizeof(mscclpp::DeviceSyncer)));
|
||||
|
||||
channelService->startProxy();
|
||||
proxyService->startProxy();
|
||||
|
||||
std::shared_ptr<int> ret = mscclpp::makeSharedCudaHost<int>(0);
|
||||
|
||||
@@ -280,7 +285,7 @@ void ProxyChannelOneToOneTest::testPacketPingPong(bool useIbOnly) {
|
||||
|
||||
communicator->bootstrap()->barrier();
|
||||
|
||||
channelService->stopProxy();
|
||||
proxyService->stopProxy();
|
||||
}
|
||||
|
||||
void ProxyChannelOneToOneTest::testPacketPingPongPerf(bool useIbOnly) {
|
||||
@@ -288,7 +293,7 @@ void ProxyChannelOneToOneTest::testPacketPingPongPerf(bool useIbOnly) {
|
||||
|
||||
const int nElem = 4 * 1024 * 1024;
|
||||
|
||||
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannels;
|
||||
std::vector<mscclpp::SimpleProxyChannel> proxyChannels;
|
||||
std::shared_ptr<int> buff = mscclpp::allocSharedCuda<int>(nElem);
|
||||
|
||||
const size_t nPacket = (nElem * sizeof(int) + sizeof(uint64_t) - 1) / sizeof(uint64_t);
|
||||
@@ -299,13 +304,19 @@ void ProxyChannelOneToOneTest::testPacketPingPongPerf(bool useIbOnly) {
|
||||
getPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket));
|
||||
|
||||
ASSERT_EQ(proxyChannels.size(), 1);
|
||||
MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstProxyChans, proxyChannels.data(),
|
||||
|
||||
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannelHandles;
|
||||
for (auto& proxyChannel : proxyChannels) {
|
||||
proxyChannelHandles.push_back(proxyChannel.deviceHandle());
|
||||
}
|
||||
|
||||
MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstProxyChans, proxyChannelHandles.data(),
|
||||
sizeof(DeviceHandle<mscclpp::SimpleProxyChannel>)));
|
||||
|
||||
mscclpp::DeviceSyncer syncer = {};
|
||||
MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestProxyChansSyncer, &syncer, sizeof(mscclpp::DeviceSyncer)));
|
||||
|
||||
channelService->startProxy();
|
||||
proxyService->startProxy();
|
||||
|
||||
auto* testInfo = ::testing::UnitTest::GetInstance()->current_test_info();
|
||||
const std::string testName = std::string(testInfo->test_suite_name()) + "." + std::string(testInfo->name());
|
||||
@@ -330,7 +341,7 @@ void ProxyChannelOneToOneTest::testPacketPingPongPerf(bool useIbOnly) {
|
||||
std::cout << testName << ": " << std::setprecision(4) << (float)timer.elapsed() / (float)nTries << " us/iter\n";
|
||||
}
|
||||
|
||||
channelService->stopProxy();
|
||||
proxyService->stopProxy();
|
||||
}
|
||||
|
||||
TEST_F(ProxyChannelOneToOneTest, PacketPingPong) { testPacketPingPong(false); }
|
||||
|
||||
@@ -5,8 +5,6 @@
|
||||
|
||||
#include "mp_unit_tests.hpp"
|
||||
|
||||
template <class T>
|
||||
using DeviceHandle = mscclpp::DeviceHandle<T>;
|
||||
void SmChannelOneToOneTest::SetUp() {
|
||||
// Need at least two ranks within a node
|
||||
if (gEnv->nRanksPerNode < 2) {
|
||||
|
||||
@@ -210,6 +210,7 @@ __global__ void allgather3(int rank, int worldSize) {
|
||||
if (tid == 0) {
|
||||
mscclpp::ProxyTrigger trigger;
|
||||
trigger.fst = MAGIC;
|
||||
trigger.snd = 0;
|
||||
// offload all the work to the proxy
|
||||
uint64_t currentFifoHead = proxyChan.fifo_.push(trigger);
|
||||
// wait for the work to be done in cpu side
|
||||
@@ -278,23 +279,24 @@ __global__ void allgather4(int rank, int worldSize, int nRanksPerNode, size_t ne
|
||||
nBlocksForLocalAllGather);
|
||||
}
|
||||
|
||||
class AllGatherChannelService : public mscclpp::BaseProxyService {
|
||||
class AllGatherProxyService : public mscclpp::BaseProxyService {
|
||||
public:
|
||||
AllGatherChannelService(mscclpp::Communicator& communicator, int worldSize, int rank, int cudaDevice);
|
||||
AllGatherProxyService(int worldSize, int rank, int cudaDevice);
|
||||
void startProxy() override { proxy_.start(); }
|
||||
void stopProxy() override { proxy_.stop(); }
|
||||
void setSendBytes(size_t sendBytes) { this->sendBytes_ = sendBytes; }
|
||||
void addRemoteMemory(mscclpp::RegisteredMemory memory) { remoteMemories_.push_back(memory); }
|
||||
void setLocalMemory(mscclpp::RegisteredMemory memory) { localMemory_ = memory; }
|
||||
mscclpp::SemaphoreId addSemaphore(std::shared_ptr<mscclpp::Connection> connection) {
|
||||
semaphores_.push_back(std::make_shared<mscclpp::Host2DeviceSemaphore>(communicator_, connection));
|
||||
mscclpp::SemaphoreId buildAndAddSemaphore(mscclpp::Communicator& communicator,
|
||||
std::shared_ptr<mscclpp::Connection> connection) {
|
||||
semaphores_.push_back(std::make_shared<mscclpp::Host2DeviceSemaphore>(communicator, connection));
|
||||
return semaphores_.size() - 1;
|
||||
}
|
||||
std::vector<DeviceHandle<mscclpp::ProxyChannel>> deviceChannels() {
|
||||
std::vector<DeviceHandle<mscclpp::ProxyChannel>> proxyChannels() {
|
||||
std::vector<DeviceHandle<mscclpp::ProxyChannel>> result;
|
||||
for (auto& semaphore : semaphores_) {
|
||||
result.push_back(
|
||||
mscclpp::deviceHandle(mscclpp::ProxyChannel(0, semaphore->deviceHandle(), proxy_.fifo().deviceFifo())));
|
||||
mscclpp::deviceHandle(mscclpp::ProxyChannel(0, semaphore->deviceHandle(), proxy_.fifo().deviceHandle())));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@@ -306,7 +308,6 @@ class AllGatherChannelService : public mscclpp::BaseProxyService {
|
||||
size_t sendBytes_;
|
||||
|
||||
mscclpp::Proxy proxy_;
|
||||
mscclpp::Communicator& communicator_;
|
||||
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> semaphores_;
|
||||
std::vector<mscclpp::RegisteredMemory> remoteMemories_;
|
||||
mscclpp::RegisteredMemory localMemory_;
|
||||
@@ -314,10 +315,8 @@ class AllGatherChannelService : public mscclpp::BaseProxyService {
|
||||
mscclpp::ProxyHandlerResult handleTrigger(mscclpp::ProxyTrigger triggerRaw);
|
||||
};
|
||||
|
||||
AllGatherChannelService::AllGatherChannelService(mscclpp::Communicator& communicator, int worldSize, int rank,
|
||||
int cudaDevice)
|
||||
: communicator_(communicator),
|
||||
worldSize_(worldSize),
|
||||
AllGatherProxyService::AllGatherProxyService(int worldSize, int rank, int cudaDevice)
|
||||
: worldSize_(worldSize),
|
||||
sendBytes_(0),
|
||||
rank_(rank),
|
||||
cudaDevice_(cudaDevice),
|
||||
@@ -327,7 +326,7 @@ AllGatherChannelService::AllGatherChannelService(mscclpp::Communicator& communic
|
||||
numaBind(deviceNumaNode);
|
||||
}) {}
|
||||
|
||||
mscclpp::ProxyHandlerResult AllGatherChannelService::handleTrigger(mscclpp::ProxyTrigger triggerRaw) {
|
||||
mscclpp::ProxyHandlerResult AllGatherProxyService::handleTrigger(mscclpp::ProxyTrigger triggerRaw) {
|
||||
size_t offset = rank_ * sendBytes_;
|
||||
if (triggerRaw.fst != MAGIC) {
|
||||
// this is not a valid trigger
|
||||
@@ -432,7 +431,7 @@ void AllGatherTestColl::setupCollTest(size_t size) {
|
||||
paramCount_ = base;
|
||||
expectedCount_ = recvCount_;
|
||||
if (isUsingHostOffload(kernelNum_)) {
|
||||
auto service = std::dynamic_pointer_cast<AllGatherChannelService>(chanService_);
|
||||
auto service = std::dynamic_pointer_cast<AllGatherProxyService>(chanService_);
|
||||
service->setSendBytes(sendCount_ * typeSize_);
|
||||
}
|
||||
mscclpp::DeviceSyncer syncer = {};
|
||||
@@ -459,7 +458,7 @@ class AllGatherTestEngine : public BaseTestEngine {
|
||||
std::vector<void*> getSendBuff() override;
|
||||
void* getRecvBuff() override;
|
||||
void* getScratchBuff() override;
|
||||
std::shared_ptr<mscclpp::BaseProxyService> createChannelService() override;
|
||||
std::shared_ptr<mscclpp::BaseProxyService> createProxyService() override;
|
||||
|
||||
private:
|
||||
void* getExpectedBuff() override;
|
||||
@@ -492,31 +491,31 @@ void AllGatherTestEngine::setupConnections() {
|
||||
CUDATHROW(cudaMemcpyToSymbol(constSmChans, smChannelHandles.data(),
|
||||
sizeof(DeviceHandle<mscclpp::SmChannel>) * smChannelHandles.size()));
|
||||
} else {
|
||||
auto service = std::dynamic_pointer_cast<AllGatherChannelService>(chanService_);
|
||||
auto service = std::dynamic_pointer_cast<AllGatherProxyService>(chanService_);
|
||||
setupMeshConnections(devProxyChannels, sendBuff_.get(), args_.maxBytes, nullptr, 0,
|
||||
[&](std::vector<std::shared_ptr<mscclpp::Connection>> conns,
|
||||
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>>& remoteMemories,
|
||||
const mscclpp::RegisteredMemory& localMemory) {
|
||||
std::vector<mscclpp::SemaphoreId> semaphoreIds;
|
||||
for (size_t i = 0; i < conns.size(); ++i) {
|
||||
service->addSemaphore(conns[i]);
|
||||
service->buildAndAddSemaphore(*comm_, conns[i]);
|
||||
service->addRemoteMemory(remoteMemories[i].get());
|
||||
}
|
||||
service->setLocalMemory(localMemory);
|
||||
comm_->setup();
|
||||
});
|
||||
auto proxyChannels = service->deviceChannels();
|
||||
auto proxyChannels = service->proxyChannels();
|
||||
assert(proxyChannels.size() < sizeof(constRawProxyChan) / sizeof(DeviceHandle<mscclpp::ProxyChannel>));
|
||||
CUDATHROW(cudaMemcpyToSymbol(constRawProxyChan, proxyChannels.data(),
|
||||
sizeof(DeviceHandle<mscclpp::ProxyChannel>) * proxyChannels.size()));
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<mscclpp::BaseProxyService> AllGatherTestEngine::createChannelService() {
|
||||
std::shared_ptr<mscclpp::BaseProxyService> AllGatherTestEngine::createProxyService() {
|
||||
if (isUsingHostOffload(args_.kernelNum)) {
|
||||
return std::make_shared<AllGatherChannelService>(*comm_, args_.totalRanks, args_.rank, args_.gpuNum);
|
||||
return std::make_shared<AllGatherProxyService>(args_.totalRanks, args_.rank, args_.gpuNum);
|
||||
} else {
|
||||
return std::make_shared<mscclpp::ProxyService>(*comm_);
|
||||
return std::make_shared<mscclpp::ProxyService>();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -109,7 +109,7 @@ __device__ void localReduceScatter(int* buff, int* scratch, int rank, int nRanks
|
||||
int prePeerRecvId = (preRemoteRecvFromRank < rank) ? preRemoteRecvFromRank : preRemoteRecvFromRank - 1;
|
||||
|
||||
// overlap communication and computation
|
||||
mscclpp::SimpleProxyChannel& preDevFstRecvChan = constDevFstRoundChans[prePeerRecvId];
|
||||
DeviceHandle<mscclpp::SimpleProxyChannel>& preDevFstRecvChan = constDevFstRoundChans[prePeerRecvId];
|
||||
if (isComm) {
|
||||
preDevFstRecvChan.wait();
|
||||
devFstSendChan.putWithSignal(dstOffset, srcOffset, nelems * sizeof(int));
|
||||
@@ -563,7 +563,8 @@ __global__ void allreduce0(int* buff, int* scratch, int rank, int worldSize, siz
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void allreduce1(int* buff, int* scratch, int rank, int worldSize, size_t nelems, size_t scratchDataCount) {
|
||||
__global__ void __launch_bounds__(1024)
|
||||
allreduce1(int* buff, int* scratch, int rank, int worldSize, size_t nelems, size_t scratchDataCount) {
|
||||
int isComm = (threadIdx.x == 0) && (blockIdx.x == 0);
|
||||
int remoteSendRank = (rank + 1) % worldSize;
|
||||
int remoteRecvRank = (rank + worldSize - 1) % worldSize;
|
||||
@@ -686,7 +687,7 @@ __global__ void allreduce2(int* buff, void* scratch, void* putPktBuf, void* getP
|
||||
|
||||
// Channel to a remote peer that has the same local rank as me
|
||||
int localRank = rank % nRanksPerNode;
|
||||
mscclpp::SimpleProxyChannel proxyChan = constDevFstRoundChans[localRank];
|
||||
DeviceHandle<mscclpp::SimpleProxyChannel> proxyChan = constDevFstRoundChans[localRank];
|
||||
|
||||
// Flag for packets. Initially 1
|
||||
uint32_t flag = (uint32_t)globalFlag;
|
||||
@@ -779,8 +780,8 @@ __global__ void allreduce2(int* buff, void* scratch, void* putPktBuf, void* getP
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void allreduce3(int* buff, int* scratch, void* result, int rank, int nRanksPerNode, int worldSize,
|
||||
size_t nelems) {
|
||||
__global__ void __launch_bounds__(1024)
|
||||
allreduce3(int* buff, int* scratch, void* result, int rank, int nRanksPerNode, int worldSize, size_t nelems) {
|
||||
reduceScatter(buff, scratch, rank, nRanksPerNode, worldSize, nelems);
|
||||
if (threadIdx.x == 0 && blockIdx.x == 0) {
|
||||
allGather(rank, worldSize, nRanksPerNode, nelems / worldSize);
|
||||
|
||||
@@ -16,7 +16,7 @@ void* localSendBuff;
|
||||
__device__ void localAlltoall(int rank, int nRanksPerNode, size_t nElements) {
|
||||
int remoteRank = (blockIdx.x < rank) ? blockIdx.x : blockIdx.x + 1;
|
||||
for (int i = 1; i < nRanksPerNode; i++) {
|
||||
mscclpp::SimpleProxyChannel proxyChan = constProxyChans[blockIdx.x];
|
||||
DeviceHandle<mscclpp::SimpleProxyChannel> proxyChan = constProxyChans[blockIdx.x];
|
||||
if (threadIdx.x == 0 && remoteRank % nRanksPerNode == (rank + i) % nRanksPerNode) {
|
||||
proxyChan.putWithSignalAndFlush(rank * nElements * sizeof(int), remoteRank * nElements * sizeof(int),
|
||||
nElements * sizeof(int));
|
||||
|
||||
@@ -16,9 +16,17 @@ def load_perf_file(perf_fine: str) -> dict:
|
||||
"time": data["time"],
|
||||
}
|
||||
if "target" in data:
|
||||
res[(data["name"], data["kernel"], data["ranks"], data["ranksPerNode"], data["size"])]["target"] = data[
|
||||
res[
|
||||
(
|
||||
data["name"],
|
||||
data["kernel"],
|
||||
data["ranks"],
|
||||
data["ranksPerNode"],
|
||||
data["size"],
|
||||
)
|
||||
][
|
||||
"target"
|
||||
]
|
||||
] = data["target"]
|
||||
return res
|
||||
|
||||
|
||||
|
||||
@@ -335,7 +335,7 @@ void BaseTestEngine::bootstrap() {
|
||||
}
|
||||
|
||||
void BaseTestEngine::setupTest() {
|
||||
this->chanService_ = this->createChannelService();
|
||||
this->chanService_ = this->createProxyService();
|
||||
this->setupConnections();
|
||||
this->chanService_->startProxy();
|
||||
this->coll_->setChanService(this->chanService_);
|
||||
@@ -357,8 +357,8 @@ size_t BaseTestEngine::checkData() {
|
||||
return nErrors;
|
||||
}
|
||||
|
||||
std::shared_ptr<mscclpp::BaseProxyService> BaseTestEngine::createChannelService() {
|
||||
return std::make_shared<mscclpp::ProxyService>(*comm_);
|
||||
std::shared_ptr<mscclpp::BaseProxyService> BaseTestEngine::createProxyService() {
|
||||
return std::make_shared<mscclpp::ProxyService>();
|
||||
}
|
||||
|
||||
void BaseTestEngine::setupMeshConnectionsInternal(
|
||||
@@ -416,8 +416,8 @@ void BaseTestEngine::setupMeshConnections(std::vector<DeviceHandle<mscclpp::Simp
|
||||
auto service = std::dynamic_pointer_cast<mscclpp::ProxyService>(chanService_);
|
||||
for (size_t i = 0; i < connections.size(); ++i) {
|
||||
proxyChannels.push_back(mscclpp::deviceHandle(mscclpp::SimpleProxyChannel(
|
||||
service->deviceChannel(service->addSemaphore(connections[i])), service->addMemory(remoteRegMemories[i].get()),
|
||||
service->addMemory(inputBufRegMem))));
|
||||
service->proxyChannel(service->buildAndAddSemaphore(*comm_, connections[i])),
|
||||
service->addMemory(remoteRegMemories[i].get()), service->addMemory(inputBufRegMem))));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -498,7 +498,7 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::SmChannel>& smCha
|
||||
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
|
||||
smSemaphores.emplace(cid, std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(*comm_, connections[cid]));
|
||||
} else {
|
||||
connIdToSemId[cid] = service->addSemaphore(connections[cid]);
|
||||
connIdToSemId[cid] = service->buildAndAddSemaphore(*comm_, connections[cid]);
|
||||
}
|
||||
}
|
||||
comm_->setup();
|
||||
@@ -513,7 +513,7 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::SmChannel>& smCha
|
||||
throw std::runtime_error("IB transport requires putPacketBuff and getPacketBuff");
|
||||
}
|
||||
proxyChannels.emplace_back(mscclpp::deviceHandle(mscclpp::SimpleProxyChannel(
|
||||
service->deviceChannel(connIdToSemId[cid]), service->addMemory(remoteRegMemories[cid].get()),
|
||||
service->proxyChannel(connIdToSemId[cid]), service->addMemory(remoteRegMemories[cid].get()),
|
||||
service->addMemory(putPacketBufRegMem))));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,7 +97,7 @@ class BaseTestEngine {
|
||||
|
||||
private:
|
||||
virtual void setupConnections() = 0;
|
||||
virtual std::shared_ptr<mscclpp::BaseProxyService> createChannelService();
|
||||
virtual std::shared_ptr<mscclpp::BaseProxyService> createProxyService();
|
||||
virtual void* getExpectedBuff() = 0;
|
||||
|
||||
double benchTime();
|
||||
|
||||
@@ -5,40 +5,40 @@
|
||||
|
||||
#include <mscclpp/cuda_utils.hpp>
|
||||
#include <mscclpp/fifo.hpp>
|
||||
#include <mscclpp/numa.hpp>
|
||||
#include <mscclpp/utils.hpp>
|
||||
|
||||
#include "numa.hpp"
|
||||
#define ITER 10000 // should be larger than the FIFO size for proper testing
|
||||
|
||||
#define FLUSH_PERIOD (MSCCLPP_PROXY_FIFO_SIZE) // should not exceed MSCCLPP_PROXY_FIFO_SIZE
|
||||
#define ITER 10000 // should be larger than MSCCLPP_PROXY_FIFO_SIZE for proper testing
|
||||
|
||||
__constant__ mscclpp::DeviceProxyFifo gFifoTestDeviceProxyFifo;
|
||||
__constant__ mscclpp::FifoDeviceHandle gFifoTestFifoDeviceHandle;
|
||||
__global__ void kernelFifoTest() {
|
||||
if (threadIdx.x + blockIdx.x * blockDim.x != 0) return;
|
||||
|
||||
mscclpp::DeviceProxyFifo& fifo = gFifoTestDeviceProxyFifo;
|
||||
mscclpp::FifoDeviceHandle& fifo = gFifoTestFifoDeviceHandle;
|
||||
mscclpp::ProxyTrigger trigger;
|
||||
for (uint64_t i = 1; i < ITER + 1; ++i) {
|
||||
trigger.fst = i;
|
||||
trigger.snd = i;
|
||||
uint64_t curFifoHead = fifo.push(trigger);
|
||||
if (i % FLUSH_PERIOD == 0) {
|
||||
if (i % fifo.size == 0) {
|
||||
fifo.sync(curFifoHead);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(FifoTest, HostProxyFifo) {
|
||||
ASSERT_LE(FLUSH_PERIOD, MSCCLPP_PROXY_FIFO_SIZE);
|
||||
|
||||
TEST(FifoTest, Fifo) {
|
||||
int cudaNum;
|
||||
MSCCLPP_CUDATHROW(cudaGetDevice(&cudaNum));
|
||||
int numaNode = mscclpp::getDeviceNumaNode(cudaNum);
|
||||
mscclpp::numaBind(numaNode);
|
||||
|
||||
mscclpp::HostProxyFifo hostFifo;
|
||||
mscclpp::DeviceProxyFifo devFifo = hostFifo.deviceFifo();
|
||||
MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gFifoTestDeviceProxyFifo, &devFifo, sizeof(devFifo)));
|
||||
mscclpp::Fifo hostFifo;
|
||||
if (hostFifo.size() >= ITER) {
|
||||
FAIL() << "ITER is too small for proper testing.";
|
||||
}
|
||||
|
||||
mscclpp::FifoDeviceHandle devFifo = hostFifo.deviceHandle();
|
||||
MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gFifoTestFifoDeviceHandle, &devFifo, sizeof(devFifo)));
|
||||
|
||||
kernelFifoTest<<<1, 1>>>();
|
||||
MSCCLPP_CUDATHROW(cudaGetLastError());
|
||||
@@ -51,17 +51,19 @@ TEST(FifoTest, HostProxyFifo) {
|
||||
uint64_t flushCnt = 0;
|
||||
mscclpp::Timer timer(3);
|
||||
for (uint64_t i = 0; i < ITER; ++i) {
|
||||
while (trigger.fst == 0) {
|
||||
hostFifo.poll(&trigger);
|
||||
while (trigger.fst == 0 || trigger.snd == 0) {
|
||||
trigger = hostFifo.poll();
|
||||
|
||||
if (spin++ > 1000000) {
|
||||
FAIL() << "Polling is stuck.";
|
||||
}
|
||||
}
|
||||
// see `src/proxy.cc` for the reason of this line
|
||||
trigger.snd ^= ((uint64_t)1 << (uint64_t)63);
|
||||
ASSERT_TRUE(trigger.fst == (i + 1));
|
||||
ASSERT_TRUE(trigger.snd == (i + 1));
|
||||
hostFifo.pop();
|
||||
if ((++flushCnt % FLUSH_PERIOD) == 0) {
|
||||
if ((++flushCnt % hostFifo.size()) == 0) {
|
||||
hostFifo.flushTail();
|
||||
}
|
||||
trigger.fst = 0;
|
||||
@@ -70,7 +72,7 @@ TEST(FifoTest, HostProxyFifo) {
|
||||
hostFifo.flushTail(true);
|
||||
|
||||
std::stringstream ss;
|
||||
ss << "FifoTest.HostProxyFifo: " << (float)timer.elapsed() / ITER << " us/iter\n";
|
||||
ss << "FifoTest.Fifo: " << (float)timer.elapsed() / ITER << " us/iter\n";
|
||||
std::cout << ss.str();
|
||||
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
@@ -4,8 +4,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <mscclpp/cuda_utils.hpp>
|
||||
|
||||
#include "numa.hpp"
|
||||
#include <mscclpp/numa.hpp>
|
||||
|
||||
TEST(NumaTest, Basic) {
|
||||
int num;
|
||||
|
||||
@@ -2,10 +2,8 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import json
|
||||
|
||||
from queue import Queue
|
||||
import os
|
||||
|
||||
|
||||
def parse_npkit_event_header(npkit_event_header_path):
|
||||
@@ -118,7 +116,10 @@ def parse_gpu_event_file(npkit_dump_dir, npkit_event_def, rank, buf_idx, gpu_clo
|
||||
)
|
||||
event_type_to_seq[event_type] += 1
|
||||
else:
|
||||
gpu_events[-1]["args"] = {"size": parsed_gpu_event["size"], "rsvd": parsed_gpu_event["rsvd"]}
|
||||
gpu_events[-1]["args"] = {
|
||||
"size": parsed_gpu_event["size"],
|
||||
"rsvd": parsed_gpu_event["rsvd"],
|
||||
}
|
||||
delta_time = gpu_events[-1]["ts"] - gpu_events[-2]["ts"]
|
||||
gpu_events[-1]["args"]["bw (GB/s)"] = gpu_events[-1]["args"]["size"] / delta_time / 1e3
|
||||
raw_content_idx += raw_event_size
|
||||
@@ -238,7 +239,12 @@ def convert_npkit_dump_to_trace(npkit_dump_dir, output_dir, npkit_event_def):
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--npkit_dump_dir", type=str, required=True, help="NPKit dump directory.")
|
||||
parser.add_argument("--npkit_event_header_path", type=str, required=True, help="Path to npkit_event.h.")
|
||||
parser.add_argument(
|
||||
"--npkit_event_header_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to npkit_event.h.",
|
||||
)
|
||||
parser.add_argument("--output_dir", type=str, required=True, help="Path to output directory.")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user