mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-13 09:46:00 +00:00
New packet format & optimizations (#256)
Co-authored-by: Binyang Li <binyli@microsoft.com>
This commit is contained in:
@@ -83,11 +83,11 @@ if(USE_CUDA)
|
||||
else()
|
||||
set(CMAKE_HIP_STANDARD 17)
|
||||
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Wall -Wextra")
|
||||
project(mscclpp LANGUAGES CXX HIP)
|
||||
project(mscclpp LANGUAGES CXX)
|
||||
|
||||
set(CMAKE_HIP_ARCHITECTURES gfx90a gfx941 gfx942)
|
||||
|
||||
set(GPU_LIBRARIES hip::host)
|
||||
set(GPU_LIBRARIES hip::device)
|
||||
set(GPU_INCLUDE_DIRS ${hip_INCLUDE_DIRS})
|
||||
endif()
|
||||
|
||||
|
||||
@@ -27,10 +27,24 @@ CMake 3.25 or later is required.
|
||||
```bash
|
||||
$ git clone https://github.com/microsoft/mscclpp.git
|
||||
$ mkdir -p mscclpp/build && cd mscclpp/build
|
||||
```
|
||||
|
||||
For NVIDIA platforms, build MSCCL++ as follows.
|
||||
|
||||
```bash
|
||||
# For NVIDIA platforms
|
||||
$ cmake -DCMAKE_BUILD_TYPE=Release ..
|
||||
$ make -j
|
||||
```
|
||||
|
||||
For AMD platforms, use HIPCC instead of the default C++ compiler. Replace `/path/to/hipcc` in the command below with your HIPCC path.
|
||||
|
||||
```bash
|
||||
# For AMD platforms
|
||||
$ CXX=/path/to/hipcc cmake -DCMAKE_BUILD_TYPE=Release ..
|
||||
$ make -j
|
||||
```
|
||||
|
||||
## Install from Source (Libraries and Headers)
|
||||
|
||||
```bash
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#define MSCCLPP_PACKET_DEVICE_HPP_
|
||||
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
|
||||
#include "device.hpp"
|
||||
|
||||
@@ -14,9 +15,8 @@
|
||||
#endif // defined(MSCCLPP_DEVICE_COMPILE)
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
/// LL (low latency) protocol packet.
|
||||
union alignas(16) LLPacket {
|
||||
union alignas(16) LL16Packet {
|
||||
// Assume data is written with an atomicity of 8 bytes (IB/RDMA).
|
||||
struct {
|
||||
uint32_t data1;
|
||||
@@ -28,7 +28,7 @@ union alignas(16) LLPacket {
|
||||
#if defined(MSCCLPP_DEVICE_COMPILE)
|
||||
ulonglong2 raw_;
|
||||
|
||||
MSCCLPP_DEVICE_INLINE LLPacket() {}
|
||||
MSCCLPP_DEVICE_INLINE LL16Packet() {}
|
||||
|
||||
/// Write 8 bytes of data to the packet.
|
||||
/// @param val1 The first 4-byte data to write.
|
||||
@@ -88,34 +88,202 @@ union alignas(16) LLPacket {
|
||||
#endif // defined(MSCCLPP_DEVICE_COMPILE)
|
||||
};
|
||||
|
||||
union alignas(8) LL8Packet {
|
||||
// Assume data is written with an atomicity of 8 bytes (IB/RDMA).
|
||||
struct {
|
||||
uint32_t data;
|
||||
uint32_t flag;
|
||||
};
|
||||
uint64_t raw_;
|
||||
#if defined(MSCCLPP_DEVICE_COMPILE)
|
||||
/// Read from the origin and write to the target buffer.
|
||||
MSCCLPP_DEVICE_INLINE void putPackets(void* targetPtr, uint64_t targetOffset, const void* originPtr,
|
||||
uint64_t originOffset, uint64_t originBytes, uint32_t threadId,
|
||||
uint32_t numThreads, uint32_t flag) {
|
||||
|
||||
MSCCLPP_DEVICE_INLINE LL8Packet() {}
|
||||
|
||||
MSCCLPP_DEVICE_INLINE void write(uint32_t val, uint32_t flag) {
|
||||
#if defined(MSCCLPP_DEVICE_CUDA)
|
||||
asm volatile("st.volatile.global.v2.u32 [%0], {%1,%2};" ::"l"(&raw_), "r"(val), "r"(flag));
|
||||
#else // !defined(MSCCLPP_DEVICE_CUDA)
|
||||
uint2 reg = make_uint2(val, flag);
|
||||
uint64_t* p = reinterpret_cast<uint64_t*>(&reg);
|
||||
atomicStore(&(raw_), *p, memoryOrderRelaxed);
|
||||
#endif
|
||||
}
|
||||
|
||||
MSCCLPP_DEVICE_INLINE bool readOnce(uint32_t flag, uint32_t& data) const {
|
||||
#if defined(MSCCLPP_DEVICE_CUDA)
|
||||
uint32_t f;
|
||||
asm volatile("ld.volatile.global.v2.u32 {%0,%1}, [%2];" : "=r"(data), "=r"(f) : "l"(&raw_));
|
||||
return (f != flag);
|
||||
#else // !defined(MSCCLPP_DEVICE_CUDA)
|
||||
uint64_t reg;
|
||||
reg = atomicLoad(&(raw_), memoryOrderRelaxed);
|
||||
uint2* ptr = reinterpret_cast<uint2*>(&reg);
|
||||
data = ptr->x;
|
||||
return (ptr->y != flag);
|
||||
#endif
|
||||
}
|
||||
|
||||
MSCCLPP_DEVICE_INLINE uint32_t read(uint32_t flag, int64_t maxSpinCount = 1000000) const {
|
||||
uint32_t data;
|
||||
POLL_MAYBE_JAILBREAK(readOnce(flag, data), maxSpinCount);
|
||||
return data;
|
||||
}
|
||||
|
||||
/// Clear the packet.
|
||||
MSCCLPP_DEVICE_INLINE void clear() { raw_ = 0; }
|
||||
#endif // defined(MSCCLPP_DEVICE_COMPILE)
|
||||
};
|
||||
|
||||
using LLPacket = LL16Packet;
|
||||
|
||||
#if defined(MSCCLPP_DEVICE_COMPILE)
|
||||
/// Read data from the origin and write LL16Packets to the target buffer.
|
||||
///
|
||||
/// @param targetPtr The target buffer.
|
||||
/// @param targetOffset The offset in the target buffer.
|
||||
/// @param originPtr The origin buffer.
|
||||
/// @param originOffset The offset in the origin buffer.
|
||||
/// @param originBytes The number of bytes to write to the target buffer.
|
||||
/// @param threadId The thread ID. The thread ID should be less than @p numThreads.
|
||||
/// @param numThreads The number of threads that call this function.
|
||||
/// @param flag The flag to write.
|
||||
///
|
||||
MSCCLPP_DEVICE_INLINE void putLL16Packets(void* targetPtr, uint64_t targetOffset, const void* originPtr,
|
||||
uint64_t originOffset, uint64_t originBytes, uint32_t threadId,
|
||||
uint32_t numThreads, uint32_t flag) {
|
||||
// Offsets should be aligned to 8 bytes & size should be a multiple of 8 bytes
|
||||
const uint32_t* originBase = (const uint32_t*)((const char*)originPtr + originOffset);
|
||||
LLPacket* targetBase = (LLPacket*)((char*)targetPtr + targetOffset);
|
||||
LL16Packet* targetBase = (LL16Packet*)((char*)targetPtr + targetOffset);
|
||||
size_t nElem = originBytes / sizeof(uint64_t);
|
||||
for (size_t i = threadId; i < nElem; i += numThreads) {
|
||||
LLPacket* pkt = &targetBase[i];
|
||||
LL16Packet* pkt = &targetBase[i];
|
||||
pkt->write(originBase[2 * i], originBase[2 * i + 1], flag);
|
||||
}
|
||||
}
|
||||
|
||||
/// Read from the target buffer and write to the origin.
|
||||
MSCCLPP_DEVICE_INLINE void getPackets(const void* targetPtr, uint64_t targetOffset, void* originPtr,
|
||||
uint64_t originOffset, uint64_t originBytes, uint32_t threadId,
|
||||
uint32_t numThreads, uint32_t flag) {
|
||||
/// Read LL16Packets from the target buffer and write retrieved data to the origin.
|
||||
///
|
||||
/// @param targetPtr The target buffer.
|
||||
/// @param targetOffset The offset in the target buffer.
|
||||
/// @param originPtr The origin buffer.
|
||||
/// @param originOffset The offset in the origin buffer.
|
||||
/// @param originBytes The number of bytes to write to the origin buffer.
|
||||
/// @param threadId The thread ID. The thread ID should be less than @p numThreads.
|
||||
/// @param numThreads The number of threads that call this function.
|
||||
/// @param flag The flag to read.
|
||||
///
|
||||
MSCCLPP_DEVICE_INLINE void getLL16Packets(const void* targetPtr, uint64_t targetOffset, void* originPtr,
|
||||
uint64_t originOffset, uint64_t originBytes, uint32_t threadId,
|
||||
uint32_t numThreads, uint32_t flag) {
|
||||
// Offsets should be aligned to 8 bytes & size should be a multiple of 8 bytes
|
||||
const LLPacket* targetBase = (const LLPacket*)((const char*)targetPtr + targetOffset);
|
||||
const LL16Packet* targetBase = (const LL16Packet*)((const char*)targetPtr + targetOffset);
|
||||
uint2* originBase = (uint2*)((char*)originPtr + originOffset);
|
||||
size_t nElem = originBytes / sizeof(uint2);
|
||||
for (size_t i = threadId; i < nElem; i += numThreads) {
|
||||
const LLPacket* pkt = &targetBase[i];
|
||||
const LL16Packet* pkt = &targetBase[i];
|
||||
originBase[i] = pkt->read(flag);
|
||||
}
|
||||
}
|
||||
|
||||
/// Read data from the origin and write LL8Packets to the target buffer.
|
||||
///
|
||||
/// @param targetPtr The target buffer.
|
||||
/// @param targetOffset The offset in the target buffer.
|
||||
/// @param originPtr The origin buffer.
|
||||
/// @param originOffset The offset in the origin buffer.
|
||||
/// @param originBytes The number of bytes to write to the target buffer.
|
||||
/// @param threadId The thread ID. The thread ID should be less than @p numThreads.
|
||||
/// @param numThreads The number of threads that call this function.
|
||||
/// @param flag The flag to write.
|
||||
///
|
||||
MSCCLPP_DEVICE_INLINE void putLL8Packets(void* targetPtr, uint64_t targetOffset, const void* originPtr,
|
||||
uint64_t originOffset, uint64_t originBytes, uint32_t threadId,
|
||||
uint32_t numThreads, uint32_t flag) {
|
||||
// Offsets should be aligned to 4 bytes & size should be a multiple of 4 bytes
|
||||
const uint32_t* originBase = (const uint32_t*)((const char*)originPtr + originOffset);
|
||||
LL8Packet* targetBase = (LL8Packet*)((char*)targetPtr + targetOffset);
|
||||
size_t nElem = originBytes / sizeof(uint32_t);
|
||||
for (size_t i = threadId; i < nElem; i += numThreads) {
|
||||
LL8Packet* pkt = &targetBase[i];
|
||||
pkt->write(originBase[i], flag);
|
||||
}
|
||||
}
|
||||
|
||||
/// Read LL8Packets from the target buffer and write retrieved data to the origin.
|
||||
///
|
||||
/// @param targetPtr The target buffer.
|
||||
/// @param targetOffset The offset in the target buffer.
|
||||
/// @param originPtr The origin buffer.
|
||||
/// @param originOffset The offset in the origin buffer.
|
||||
/// @param originBytes The number of bytes to write to the target buffer.
|
||||
/// @param threadId The thread ID. The thread ID should be less than @p numThreads.
|
||||
/// @param numThreads The number of threads that call this function.
|
||||
/// @param flag The flag to write.
|
||||
///
|
||||
MSCCLPP_DEVICE_INLINE void getLL8Packets(const void* targetPtr, uint64_t targetOffset, void* originPtr,
|
||||
uint64_t originOffset, uint64_t originBytes, uint32_t threadId,
|
||||
uint32_t numThreads, uint32_t flag) {
|
||||
// Offsets should be aligned to 4 bytes & size should be a multiple of 4 bytes
|
||||
const LL8Packet* targetBase = (const LL8Packet*)((const char*)targetPtr + targetOffset);
|
||||
uint32_t* originBase = (uint32_t*)((char*)originPtr + originOffset);
|
||||
size_t nElem = originBytes / sizeof(uint32_t);
|
||||
for (size_t i = threadId; i < nElem; i += numThreads) {
|
||||
const LL8Packet* pkt = &targetBase[i];
|
||||
originBase[i] = pkt->read(flag);
|
||||
}
|
||||
}
|
||||
|
||||
/// Read data from the origin and write packets to the target buffer.
|
||||
///
|
||||
/// @param targetPtr The target buffer.
|
||||
/// @param targetOffset The offset in the target buffer.
|
||||
/// @param originPtr The origin buffer.
|
||||
/// @param originOffset The offset in the origin buffer.
|
||||
/// @param originBytes The number of bytes to write to the target buffer.
|
||||
/// @param threadId The thread ID. The thread ID should be less than @p numThreads.
|
||||
/// @param numThreads The number of threads that call this function.
|
||||
/// @param flag The flag to write.
|
||||
/// @tparam PacketType The packet type. It should be either @ref LL16Packet or @ref LL8Packet.
|
||||
///
|
||||
template <typename PacketType = LL16Packet>
|
||||
MSCCLPP_DEVICE_INLINE void putPackets(void* targetPtr, uint64_t targetOffset, const void* originPtr,
|
||||
uint64_t originOffset, uint64_t originBytes, uint32_t threadId,
|
||||
uint32_t numThreads, uint32_t flag) {
|
||||
if constexpr (std::is_same<PacketType, LL16Packet>::value) {
|
||||
putLL16Packets(targetPtr, targetOffset, originPtr, originOffset, originBytes, threadId, numThreads, flag);
|
||||
} else if constexpr (std::is_same<PacketType, LL8Packet>::value) {
|
||||
putLL8Packets(targetPtr, targetOffset, originPtr, originOffset, originBytes, threadId, numThreads, flag);
|
||||
} else {
|
||||
static_assert(std::is_same<PacketType, LL16Packet>::value || std::is_same<PacketType, LL8Packet>::value,
|
||||
"Unsupported packet type");
|
||||
}
|
||||
}
|
||||
|
||||
/// Read packets from the target buffer and write retrieved data to the origin.
|
||||
///
|
||||
/// @param targetPtr The target buffer.
|
||||
/// @param targetOffset The offset in the target buffer.
|
||||
/// @param originPtr The origin buffer.
|
||||
/// @param originOffset The offset in the origin buffer.
|
||||
/// @param originBytes The number of bytes to read from the origin buffer.
|
||||
/// @param threadId The thread ID. The thread ID should be less than @p numThreads.
|
||||
/// @param numThreads The number of threads that call this function.
|
||||
/// @param flag The flag to read.
|
||||
/// @tparam PacketType The packet type. It should be either @ref LL16Packet or @ref LL8Packet.
|
||||
///
|
||||
template <typename PacketType = LL16Packet>
|
||||
MSCCLPP_DEVICE_INLINE void getPackets(const void* targetPtr, uint64_t targetOffset, void* originPtr,
|
||||
uint64_t originOffset, uint64_t originBytes, uint32_t threadId,
|
||||
uint32_t numThreads, uint32_t flag) {
|
||||
if constexpr (std::is_same<PacketType, LL16Packet>::value) {
|
||||
getLL16Packets(targetPtr, targetOffset, originPtr, originOffset, originBytes, threadId, numThreads, flag);
|
||||
} else if constexpr (std::is_same<PacketType, LL8Packet>::value) {
|
||||
getLL8Packets(targetPtr, targetOffset, originPtr, originOffset, originBytes, threadId, numThreads, flag);
|
||||
} else {
|
||||
static_assert(std::is_same<PacketType, LL16Packet>::value || std::is_same<PacketType, LL8Packet>::value,
|
||||
"Unsupported packet type");
|
||||
}
|
||||
}
|
||||
#endif // defined(MSCCLPP_DEVICE_COMPILE)
|
||||
|
||||
}; // namespace mscclpp
|
||||
|
||||
@@ -10,6 +10,9 @@
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#if defined(NDEBUG)
|
||||
#define __assert_fail(__assertion, __file, __line, __function) ;
|
||||
#else // !defined(NDEBUG)
|
||||
#if defined(MSCCLPP_DEVICE_HIP)
|
||||
extern "C" __device__ void __assert_fail(const char *__assertion, const char *__file, unsigned int __line,
|
||||
const char *__function);
|
||||
@@ -17,6 +20,7 @@ extern "C" __device__ void __assert_fail(const char *__assertion, const char *__
|
||||
extern "C" __device__ void __assert_fail(const char *__assertion, const char *__file, unsigned int __line,
|
||||
const char *__function) __THROW;
|
||||
#endif // !defined(MSCCLPP_DEVICE_HIP)
|
||||
#endif // NDEBUG
|
||||
|
||||
// If a spin is stuck, print a warning and keep spinning.
|
||||
#define POLL_MAYBE_JAILBREAK(__cond, __max_spin_cnt) \
|
||||
|
||||
@@ -211,10 +211,12 @@ struct SmChannelDeviceHandle {
|
||||
/// @param threadId The index of the current thread among all threads running this function. This is different from
|
||||
/// the `threadIdx` in CUDA.
|
||||
/// @param numThreads The total number of threads that run this function.
|
||||
/// @tparam PacketType The packet type. It should be either @ref LL16Packet or @ref LL8Packet.
|
||||
///
|
||||
template <typename PacketType = LL16Packet>
|
||||
MSCCLPP_DEVICE_INLINE void putPackets(uint64_t targetOffset, uint64_t originOffset, uint64_t originBytes,
|
||||
uint32_t threadId, uint32_t numThreads, uint32_t flag) {
|
||||
mscclpp::putPackets(dst_, targetOffset, src_, originOffset, originBytes, threadId, numThreads, flag);
|
||||
mscclpp::putPackets<PacketType>(dst_, targetOffset, src_, originOffset, originBytes, threadId, numThreads, flag);
|
||||
}
|
||||
|
||||
/// Retrieve data from @ref LLPacket in the local packet buffer (target) and write it on the local data (origin).
|
||||
@@ -227,10 +229,13 @@ struct SmChannelDeviceHandle {
|
||||
/// @param threadId The index of the current thread among all threads running this function. This is different from
|
||||
/// the `threadIdx` in CUDA.
|
||||
/// @param numThreads The total number of threads that run this function.
|
||||
/// @tparam PacketType The packet type. It should be either @ref LL16Packet or @ref LL8Packet.
|
||||
///
|
||||
template <typename PacketType = LL16Packet>
|
||||
MSCCLPP_DEVICE_INLINE void getPackets(uint64_t targetOffset, uint64_t originOffset, uint64_t originBytes,
|
||||
uint32_t threadId, uint32_t numThreads, uint32_t flag) {
|
||||
mscclpp::getPackets(getPacketBuffer_, targetOffset, src_, originOffset, originBytes, threadId, numThreads, flag);
|
||||
mscclpp::getPackets<PacketType>(getPacketBuffer_, targetOffset, src_, originOffset, originBytes, threadId,
|
||||
numThreads, flag);
|
||||
}
|
||||
|
||||
/// Signal the remote semaphore.
|
||||
|
||||
@@ -10,7 +10,7 @@ set(TEST_INC_INTERNAL PRIVATE ${PROJECT_SOURCE_DIR}/src/include)
|
||||
|
||||
if(USE_ROCM)
|
||||
file(GLOB_RECURSE CU_SOURCES CONFIGURE_DEPENDS *.cu)
|
||||
set_source_files_properties(${CU_SOURCES} PROPERTIES LANGUAGE HIP)
|
||||
set_source_files_properties(${CU_SOURCES} PROPERTIES LANGUAGE CXX)
|
||||
endif()
|
||||
|
||||
function(add_test_executable name sources)
|
||||
|
||||
@@ -149,6 +149,8 @@ class SmChannelOneToOneTest : public CommunicatorTestBase {
|
||||
|
||||
void setupMeshConnections(std::vector<mscclpp::SmChannel>& smChannels, void* inputBuff, size_t inputBuffBytes,
|
||||
void* outputBuff = nullptr, size_t outputBuffBytes = 0);
|
||||
using PacketPingPongKernelWrapper = std::function<void(int*, int, int, int*, int)>;
|
||||
void packetPingPongTest(const std::string testName, PacketPingPongKernelWrapper kernelWrapper);
|
||||
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>> smSemaphores;
|
||||
};
|
||||
|
||||
@@ -70,6 +70,61 @@ void SmChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::SmChannel>
|
||||
|
||||
__constant__ DeviceHandle<mscclpp::SmChannel> gChannelOneToOneTestConstSmChans;
|
||||
|
||||
void SmChannelOneToOneTest::packetPingPongTest(const std::string testName, PacketPingPongKernelWrapper kernelWrapper) {
|
||||
if (gEnv->rank >= numRanksToUse) return;
|
||||
|
||||
const int nElem = 4 * 1024 * 1024;
|
||||
const int defaultNTries = 1000;
|
||||
|
||||
std::vector<mscclpp::SmChannel> smChannels;
|
||||
std::shared_ptr<int> buff = mscclpp::allocExtSharedCuda<int>(nElem);
|
||||
std::shared_ptr<int> intermBuff = mscclpp::allocExtSharedCuda<int>(nElem * 2);
|
||||
setupMeshConnections(smChannels, buff.get(), nElem * sizeof(int), intermBuff.get(), nElem * 2 * sizeof(int));
|
||||
std::vector<DeviceHandle<mscclpp::SmChannel>> deviceHandles(smChannels.size());
|
||||
std::transform(smChannels.begin(), smChannels.end(), deviceHandles.begin(),
|
||||
[](const mscclpp::SmChannel& smChan) { return mscclpp::deviceHandle(smChan); });
|
||||
|
||||
ASSERT_EQ(smChannels.size(), 1);
|
||||
MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstSmChans, deviceHandles.data(),
|
||||
sizeof(DeviceHandle<mscclpp::SmChannel>)));
|
||||
|
||||
std::shared_ptr<int> ret = mscclpp::makeSharedCudaHost<int>(0);
|
||||
|
||||
// The least nelem is 2 for packet ping pong
|
||||
kernelWrapper(buff.get(), gEnv->rank, 2, ret.get(), defaultNTries);
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
*ret = 0;
|
||||
|
||||
kernelWrapper(buff.get(), gEnv->rank, 1024, ret.get(), defaultNTries);
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
*ret = 0;
|
||||
|
||||
kernelWrapper(buff.get(), gEnv->rank, 1024 * 1024, ret.get(), defaultNTries);
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
*ret = 0;
|
||||
|
||||
kernelWrapper(buff.get(), gEnv->rank, 4 * 1024 * 1024, ret.get(), defaultNTries);
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
*ret = 0;
|
||||
|
||||
int nTries = 1000000;
|
||||
communicator->bootstrap()->barrier();
|
||||
mscclpp::Timer timer;
|
||||
kernelWrapper(buff.get(), gEnv->rank, 1024, ret.get(), nTries);
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
communicator->bootstrap()->barrier();
|
||||
|
||||
if (gEnv->rank == 0) {
|
||||
std::cout << testName << ": " << std::setprecision(4) << (float)timer.elapsed() / (float)(nTries) << " us/iter\n";
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void kernelSmPutPingPong(int* buff, int rank, int nElem, int* ret) {
|
||||
DeviceHandle<mscclpp::SmChannel>& smChan = gChannelOneToOneTestConstSmChans;
|
||||
volatile int* sendBuff = (volatile int*)buff;
|
||||
@@ -238,17 +293,53 @@ TEST_F(SmChannelOneToOneTest, GetPingPong) {
|
||||
EXPECT_EQ(*ret, 0);
|
||||
}
|
||||
|
||||
__global__ void kernelSmPacketPingPong(int* buff, int rank, int nElem, int* ret) {
|
||||
__global__ void kernelSmLL8PacketPingPong(int* buff, int rank, int nElem, int* ret, int nTries) {
|
||||
if (rank > 1) return;
|
||||
|
||||
DeviceHandle<mscclpp::SmChannel>& smChan = gChannelOneToOneTestConstSmChans;
|
||||
volatile int* sendBuff = (volatile int*)buff;
|
||||
int nTries = 1000;
|
||||
int putOffset = (rank == 0) ? 0 : 10000000;
|
||||
int getOffset = (rank == 0) ? 10000000 : 0;
|
||||
for (int i = 0; i < nTries; i++) {
|
||||
uint64_t flag = (uint64_t)i + 1;
|
||||
|
||||
// rank=0: 0, 1, 0, 1, ...
|
||||
// rank=1: 1, 0, 1, 0, ...
|
||||
if ((rank ^ (i & 1)) == 0) {
|
||||
// If each thread writes 8 bytes at once, we don't need a barrier before putPackets().
|
||||
for (int j = threadIdx.x; j < nElem; j += blockDim.x) {
|
||||
sendBuff[j] = putOffset + i + j;
|
||||
// sendBuff[2 * j + 1] = putOffset + i + 2 * j + 1;
|
||||
}
|
||||
// __syncthreads();
|
||||
smChan.putPackets<mscclpp::LL8Packet>(0, 0, nElem * sizeof(int), threadIdx.x, blockDim.x, flag);
|
||||
} else {
|
||||
smChan.getPackets<mscclpp::LL8Packet>(0, 0, nElem * sizeof(int), threadIdx.x, blockDim.x, flag);
|
||||
// If each thread reads 8 bytes at once, we don't need a barrier after getPackets().
|
||||
// __syncthreads();
|
||||
for (int j = threadIdx.x; j < nElem; j += blockDim.x) {
|
||||
if (sendBuff[j] != getOffset + i + j) {
|
||||
// printf("ERROR: rank = %d, sendBuff[%d] = %d, expected %d. Skipping following errors\n", rank, 2 * j,
|
||||
// sendBuff[2 * j], getOffset + i + 2 * j);
|
||||
*ret = 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Make sure all threads are done in this iteration
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void kernelSmLL16PacketPingPong(int* buff, int rank, int nElem, int* ret, int nTries) {
|
||||
if (rank > 1) return;
|
||||
|
||||
DeviceHandle<mscclpp::SmChannel>& smChan = gChannelOneToOneTestConstSmChans;
|
||||
volatile int* sendBuff = (volatile int*)buff;
|
||||
int putOffset = (rank == 0) ? 0 : 10000000;
|
||||
int getOffset = (rank == 0) ? 10000000 : 0;
|
||||
for (int i = 0; i < nTries; i++) {
|
||||
uint64_t flag = (uint64_t)i + 1;
|
||||
// rank=0: 0, 1, 0, 1, ...
|
||||
// rank=1: 1, 0, 1, 0, ...
|
||||
if ((rank ^ (i & 1)) == 0) {
|
||||
@@ -258,9 +349,9 @@ __global__ void kernelSmPacketPingPong(int* buff, int rank, int nElem, int* ret)
|
||||
sendBuff[2 * j + 1] = putOffset + i + 2 * j + 1;
|
||||
}
|
||||
// __syncthreads();
|
||||
smChan.putPackets(0, 0, nElem * sizeof(int), threadIdx.x, blockDim.x, flag);
|
||||
smChan.putPackets<mscclpp::LL16Packet>(0, 0, nElem * sizeof(int), threadIdx.x, blockDim.x, flag);
|
||||
} else {
|
||||
smChan.getPackets(0, 0, nElem * sizeof(int), threadIdx.x, blockDim.x, flag);
|
||||
smChan.getPackets<mscclpp::LL16Packet>(0, 0, nElem * sizeof(int), threadIdx.x, blockDim.x, flag);
|
||||
// If each thread reads 8 bytes at once, we don't need a barrier after getPackets().
|
||||
// __syncthreads();
|
||||
for (int j = threadIdx.x; j < nElem / 2; j += blockDim.x) {
|
||||
@@ -283,46 +374,16 @@ __global__ void kernelSmPacketPingPong(int* buff, int rank, int nElem, int* ret)
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(SmChannelOneToOneTest, PacketPingPong) {
|
||||
if (gEnv->rank >= numRanksToUse) return;
|
||||
|
||||
const int nElem = 4 * 1024 * 1024;
|
||||
|
||||
std::vector<mscclpp::SmChannel> smChannels;
|
||||
std::shared_ptr<int> buff = mscclpp::allocExtSharedCuda<int>(nElem);
|
||||
std::shared_ptr<int> intermBuff = mscclpp::allocExtSharedCuda<int>(nElem * 2);
|
||||
setupMeshConnections(smChannels, buff.get(), nElem * sizeof(int), intermBuff.get(), nElem * 2 * sizeof(int));
|
||||
std::vector<DeviceHandle<mscclpp::SmChannel>> deviceHandles(smChannels.size());
|
||||
std::transform(smChannels.begin(), smChannels.end(), deviceHandles.begin(),
|
||||
[](const mscclpp::SmChannel& smChan) { return mscclpp::deviceHandle(smChan); });
|
||||
|
||||
ASSERT_EQ(smChannels.size(), 1);
|
||||
MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstSmChans, deviceHandles.data(),
|
||||
sizeof(DeviceHandle<mscclpp::SmChannel>)));
|
||||
|
||||
std::shared_ptr<int> ret = mscclpp::makeSharedCudaHost<int>(0);
|
||||
|
||||
// The least nelem is 2 for packet ping pong
|
||||
kernelSmPacketPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 2, ret.get());
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
*ret = 0;
|
||||
|
||||
kernelSmPacketPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024, ret.get());
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
*ret = 0;
|
||||
|
||||
kernelSmPacketPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024 * 1024, ret.get());
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
*ret = 0;
|
||||
|
||||
kernelSmPacketPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 4 * 1024 * 1024, ret.get());
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
// Ping-pong test over LL8Packet: launches kernelSmLL8PacketPingPong with 1 block of 1024 threads.
TEST_F(SmChannelOneToOneTest, LL8PacketPingPong) {
  packetPingPongTest("smLL8PacketPingPong", [](int* buff, int rank, int nElem, int* ret, int nTries) {
    kernelSmLL8PacketPingPong<<<1, 1024>>>(buff, rank, nElem, ret, nTries);
  });
}
|
||||
|
||||
// Ping-pong test over LL16Packet: launches kernelSmLL16PacketPingPong with 1 block of 1024 threads.
TEST_F(SmChannelOneToOneTest, LL16PacketPingPong) {
  packetPingPongTest("smLL16PacketPingPong", [](int* buff, int rank, int nElem, int* ret, int nTries) {
    kernelSmLL16PacketPingPong<<<1, 1024>>>(buff, rank, nElem, ret, nTries);
  });
}
|
||||
|
||||
@@ -6,7 +6,7 @@ FetchContent_MakeAvailable(json)
|
||||
|
||||
function(add_mscclpp_test_executable name sources)
|
||||
if(USE_ROCM)
|
||||
set_source_files_properties(${sources} PROPERTIES LANGUAGE HIP)
|
||||
set_source_files_properties(${sources} PROPERTIES LANGUAGE CXX)
|
||||
endif()
|
||||
add_executable(${name} ${sources} common.cc)
|
||||
target_link_libraries(${name} ${TEST_LIBS_COMMON} MPI::MPI_CXX nlohmann_json::nlohmann_json)
|
||||
|
||||
@@ -23,7 +23,9 @@ using DeviceHandle = mscclpp::DeviceHandle<T>;
|
||||
__constant__ DeviceHandle<mscclpp::SimpleProxyChannel> constProxyChans[16];
|
||||
__constant__ DeviceHandle<mscclpp::ProxyChannel> constRawProxyChan[16];
|
||||
|
||||
__constant__ DeviceHandle<mscclpp::SmChannel> constSmChans[8];
|
||||
__constant__ DeviceHandle<mscclpp::SmChannel> constSmChans[512];
|
||||
__constant__ DeviceHandle<mscclpp::SmChannel> constSmOutOfPlaceChans[16];
|
||||
__device__ uint64_t globalFlag;
|
||||
|
||||
__global__ void allgather0(int rank, size_t nelemsPerGPU) {
|
||||
int warpId = threadIdx.x / WARP_SIZE;
|
||||
@@ -288,6 +290,215 @@ __global__ void allgather4(int rank, int worldSize, int nRanksPerNode, size_t ne
|
||||
nBlocksForLocalAllGather);
|
||||
}
|
||||
|
||||
__global__ void __launch_bounds__(1024, 1)
|
||||
allgather5(size_t rank, [[maybe_unused]] size_t worldSize, size_t nRanksPerNode, size_t nelemsPerGPU) {
|
||||
const size_t nBlock = gridDim.x;
|
||||
if (blockIdx.x >= nBlock) return;
|
||||
|
||||
const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
const size_t lid = tid % WARP_SIZE;
|
||||
const size_t wid = tid / WARP_SIZE;
|
||||
|
||||
const size_t nThread = blockDim.x * nBlock;
|
||||
const size_t nWarp = nThread / WARP_SIZE;
|
||||
const size_t nPeer = nRanksPerNode - 1;
|
||||
const size_t chanOffset = nPeer * blockIdx.x;
|
||||
auto smChans = constSmChans + chanOffset;
|
||||
|
||||
if (wid < nPeer && lid == 0) {
|
||||
smChans[wid].relaxedSignal();
|
||||
smChans[wid].wait();
|
||||
}
|
||||
__syncthreads();
|
||||
const size_t bytesPerGPU = nelemsPerGPU * sizeof(int);
|
||||
const size_t bytes = bytesPerGPU * nPeer;
|
||||
size_t unitBytesPerThread;
|
||||
if (bytes >= nThread * 64) {
|
||||
unitBytesPerThread = 64;
|
||||
} else {
|
||||
unitBytesPerThread = 16;
|
||||
}
|
||||
const size_t unitBytesPerWarp = unitBytesPerThread * WARP_SIZE;
|
||||
const size_t unitBytes = unitBytesPerWarp * nWarp;
|
||||
const size_t nLoop = bytes / unitBytes;
|
||||
|
||||
if (nLoop > 0) {
|
||||
// First loop unrolling
|
||||
const size_t peerIdx = wid % nPeer;
|
||||
const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
|
||||
const size_t offset = bytesPerGPU * remoteRankLocalIndex + (wid / nPeer) * unitBytesPerWarp;
|
||||
smChans[peerIdx].get<16, false>(offset, unitBytesPerWarp, lid, WARP_SIZE);
|
||||
}
|
||||
|
||||
for (size_t i = 1; i < nLoop; ++i) {
|
||||
const size_t gWid = wid + i * nWarp;
|
||||
const size_t peerIdx = gWid % nPeer;
|
||||
const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
|
||||
const size_t offset = bytesPerGPU * remoteRankLocalIndex + (gWid / nPeer) * unitBytesPerWarp;
|
||||
smChans[peerIdx].get<16, false>(offset, unitBytesPerWarp, lid, WARP_SIZE);
|
||||
}
|
||||
|
||||
if (bytes % unitBytes > 0) {
|
||||
const size_t gWid = wid + nLoop * nWarp;
|
||||
const size_t peerIdx = gWid % nPeer;
|
||||
const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
|
||||
const size_t offsetWithinRank = (gWid / nPeer) * unitBytesPerWarp;
|
||||
const size_t offset = bytesPerGPU * remoteRankLocalIndex + offsetWithinRank;
|
||||
const size_t remainBytes = (offsetWithinRank + unitBytesPerWarp > bytesPerGPU)
|
||||
? ((bytesPerGPU > offsetWithinRank) ? (bytesPerGPU - offsetWithinRank) : 0)
|
||||
: unitBytesPerWarp;
|
||||
if (remainBytes > 0) {
|
||||
smChans[peerIdx].get<16, true>(offset, remainBytes, lid, WARP_SIZE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// All-gather test kernel 6: warp-granular `put` over the per-block slice of `constSmChans`.
///
/// The outgoing payload is cut into warp-sized units. Unit `g` (a grid-wide warp index plus a
/// multiple of the total warp count) is sent on channel `g % nPeer` at byte offset
/// `bytesPerGPU * rank + (g / nPeer) * unitBytesPerWarp`.
///
/// @param rank           Index used to locate this rank's slot (`bytesPerGPU * rank`).
/// @param worldSize      Unused by this kernel.
/// @param nRanksPerNode  Ranks per node; `nRanksPerNode - 1` channels are used per block.
/// @param nelemsPerGPU   Number of `int` elements contributed by each rank.
__global__ void __launch_bounds__(1024, 1)
    allgather6(size_t rank, [[maybe_unused]] size_t worldSize, size_t nRanksPerNode, size_t nelemsPerGPU) {
  const size_t nBlock = gridDim.x;
  // Guard retained from the original; blockIdx.x is below gridDim.x in a normal launch.
  if (blockIdx.x >= nBlock) return;

  // Grid-wide thread index, split into lane-in-warp and grid-wide warp index.
  const size_t globalTid = threadIdx.x + blockIdx.x * blockDim.x;
  const size_t lane = globalTid % WARP_SIZE;
  const size_t warp = globalTid / WARP_SIZE;

  const size_t totalThreads = blockDim.x * nBlock;
  const size_t totalWarps = totalThreads / WARP_SIZE;
  const size_t nPeer = nRanksPerNode - 1;
  // Each block indexes its own run of nPeer consecutive channel handles.
  auto blockChans = constSmChans + nPeer * blockIdx.x;

  // Lane 0 of every warp whose grid-wide index is below nPeer signals and then waits on the
  // channel with the same index, before the block-wide barrier.
  if (lane == 0 && warp < nPeer) {
    blockChans[warp].relaxedSignal();
    blockChans[warp].wait();
  }
  __syncthreads();

  const size_t bytesPerGPU = nelemsPerGPU * sizeof(int);
  const size_t totalBytes = bytesPerGPU * nPeer;
  // 64 bytes per thread when the payload covers every thread at that width, otherwise 16.
  const size_t unitBytesPerThread = (totalBytes >= totalThreads * 64) ? 64 : 16;
  const size_t unitBytesPerWarp = unitBytesPerThread * WARP_SIZE;
  const size_t unitBytes = unitBytesPerWarp * totalWarps;
  const size_t nLoop = totalBytes / unitBytes;
  // Byte offset of this rank's slot.
  const size_t rankBase = bytesPerGPU * rank;

  // Iteration 0 is peeled out of the loop below, as in the original.
  if (nLoop > 0) {
    const size_t peerIdx = warp % nPeer;
    const size_t offset = rankBase + (warp / nPeer) * unitBytesPerWarp;
    blockChans[peerIdx].put<16, false>(offset, unitBytesPerWarp, lane, WARP_SIZE);
  }

  // Remaining full rounds: each round advances the unit index by the total warp count.
  for (size_t round = 1; round < nLoop; ++round) {
    const size_t unitIdx = warp + round * totalWarps;
    const size_t peerIdx = unitIdx % nPeer;
    const size_t offset = rankBase + (unitIdx / nPeer) * unitBytesPerWarp;
    blockChans[peerIdx].put<16, false>(offset, unitBytesPerWarp, lane, WARP_SIZE);
  }

  // Nothing left over once the full rounds cover the payload exactly.
  if (totalBytes % unitBytes == 0) return;

  // Tail round: a unit may be partial or lie entirely past the end of the slot.
  const size_t unitIdx = warp + nLoop * totalWarps;
  const size_t peerIdx = unitIdx % nPeer;
  const size_t offsetWithinRank = (unitIdx / nPeer) * unitBytesPerWarp;
  const size_t offset = rankBase + offsetWithinRank;

  // Clamp the unit to what is left of the slot (0 when the unit starts at or beyond the end).
  size_t tailBytes = unitBytesPerWarp;
  if (offsetWithinRank + unitBytesPerWarp > bytesPerGPU) {
    tailBytes = (bytesPerGPU > offsetWithinRank) ? (bytesPerGPU - offsetWithinRank) : 0;
  }
  if (tailBytes > 0) {
    // NOTE(review): the second template argument flips to `true` only here; presumably it makes
    // `put` handle sizes that are not a multiple of the 16-byte unit -- confirm against SmChannel.
    blockChans[peerIdx].put<16, true>(offset, tailBytes, lane, WARP_SIZE);
  }
}
|
||||
|
||||
/// All-gather test kernel 7: packet-based exchange over `constSmOutOfPlaceChans`.
///
/// For every warp-sized unit the kernel issues `putPackets` for this rank's data and
/// `getPackets` for the matching unit of a peer's slot, both tagged with the current
/// `globalFlag` value. Call order is: first-round put, first-round get, remaining puts,
/// remaining gets, tail put, tail get. Thread 0 of block 0 bumps `globalFlag` at the end.
///
/// @param rank           Index used to locate this rank's slot and to map a channel index to a slot.
/// @param worldSize      Unused by this kernel.
/// @param nRanksPerNode  Ranks per node; `nRanksPerNode - 1` channels are used.
/// @param nelemsPerGPU   Number of `int` elements contributed by each rank.
__global__ void __launch_bounds__(1024, 1)
    allgather7(size_t rank, [[maybe_unused]] size_t worldSize, size_t nRanksPerNode, size_t nelemsPerGPU) {
  const size_t nBlock = gridDim.x;
  // Guard retained from the original; blockIdx.x is below gridDim.x in a normal launch.
  if (blockIdx.x >= nBlock) return;

  // Grid-wide thread index, split into lane-in-warp and grid-wide warp index.
  const size_t globalTid = threadIdx.x + blockIdx.x * blockDim.x;
  const size_t lane = globalTid % WARP_SIZE;
  const size_t warp = globalTid / WARP_SIZE;

  const size_t totalThreads = blockDim.x * nBlock;
  const size_t totalWarps = totalThreads / WARP_SIZE;
  const size_t nPeer = nRanksPerNode - 1;
  auto chans = constSmOutOfPlaceChans;

  const uint32_t flag = static_cast<uint32_t>(globalFlag);
  const size_t bytesPerGPU = nelemsPerGPU * sizeof(int);
  const size_t totalBytes = bytesPerGPU * nPeer;
  const size_t unitBytesPerThread = 8;
  const size_t unitBytesPerWarp = unitBytesPerThread * WARP_SIZE;
  const size_t unitBytes = unitBytesPerWarp * totalWarps;
  const size_t nLoop = totalBytes / unitBytes;

  // Double buffering: the scratch region alternates with the parity of the flag.
  const size_t scratchOffset = (flag & 1) ? 0 : bytesPerGPU * nRanksPerNode * 2;
  // Byte offset of this rank's slot.
  const size_t rankBase = bytesPerGPU * rank;

  // Round 0, peeled out of the loops below as in the original: put first, then get.
  // NOTE(review): scratch offsets are twice the data offsets, presumably because the packet
  // format doubles the footprint of the payload -- confirm against the packet definition.
  if (nLoop > 0) {
    const size_t peerIdx = warp % nPeer;
    const size_t unitOffset = (warp / nPeer) * unitBytesPerWarp;

    const size_t sendOffset = rankBase + unitOffset;
    chans[peerIdx].putPackets(scratchOffset + sendOffset * 2, sendOffset, unitBytesPerWarp, lane, WARP_SIZE, flag);

    // Channel index -> slot index: indices at or above `rank` are shifted up by one.
    const size_t peerSlot = (peerIdx < rank) ? peerIdx : peerIdx + 1;
    const size_t recvOffset = bytesPerGPU * peerSlot + unitOffset;
    chans[peerIdx].getPackets(scratchOffset + recvOffset * 2, recvOffset, unitBytesPerWarp, lane, WARP_SIZE, flag);
  }

  // Remaining full rounds: all puts are issued before any of the gets.
  for (size_t round = 1; round < nLoop; ++round) {
    const size_t unitIdx = warp + round * totalWarps;
    const size_t peerIdx = unitIdx % nPeer;
    const size_t sendOffset = rankBase + (unitIdx / nPeer) * unitBytesPerWarp;
    chans[peerIdx].putPackets(scratchOffset + sendOffset * 2, sendOffset, unitBytesPerWarp, lane, WARP_SIZE, flag);
  }

  for (size_t round = 1; round < nLoop; ++round) {
    const size_t unitIdx = warp + round * totalWarps;
    const size_t peerIdx = unitIdx % nPeer;
    const size_t peerSlot = (peerIdx < rank) ? peerIdx : peerIdx + 1;
    const size_t recvOffset = bytesPerGPU * peerSlot + (unitIdx / nPeer) * unitBytesPerWarp;
    chans[peerIdx].getPackets(scratchOffset + recvOffset * 2, recvOffset, unitBytesPerWarp, lane, WARP_SIZE, flag);
  }

  // Tail round: a unit may be partial or lie entirely past the end of the slot.
  if (totalBytes % unitBytes > 0) {
    const size_t unitIdx = warp + nLoop * totalWarps;
    const size_t peerIdx = unitIdx % nPeer;
    const size_t offsetWithinRank = (unitIdx / nPeer) * unitBytesPerWarp;

    // Clamp the unit to what is left of the slot (0 when the unit starts at or beyond the end).
    size_t tailBytes = unitBytesPerWarp;
    if (offsetWithinRank + unitBytesPerWarp > bytesPerGPU) {
      tailBytes = (bytesPerGPU > offsetWithinRank) ? (bytesPerGPU - offsetWithinRank) : 0;
    }

    if (tailBytes > 0) {
      const size_t sendOffset = rankBase + offsetWithinRank;
      chans[peerIdx].putPackets(scratchOffset + sendOffset * 2, sendOffset, tailBytes, lane, WARP_SIZE, flag);

      const size_t peerSlot = (peerIdx < rank) ? peerIdx : peerIdx + 1;
      const size_t recvOffset = bytesPerGPU * peerSlot + offsetWithinRank;
      chans[peerIdx].getPackets(scratchOffset + recvOffset * 2, recvOffset, tailBytes, lane, WARP_SIZE, flag);
    }
  }

  // A single thread advances the flag, which also flips the scratch half on the next launch.
  if (threadIdx.x == 0 && blockIdx.x == 0) {
    globalFlag += 1;
  }
}
|
||||
|
||||
class AllGatherProxyService : public mscclpp::BaseProxyService {
|
||||
public:
|
||||
AllGatherProxyService(int worldSize, int rank, int cudaDevice);
|
||||
@@ -387,6 +598,15 @@ void AllGatherTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
|
||||
if (kernelNum == 4) {
|
||||
nBlocks = 21;
|
||||
nThreads = 1024;
|
||||
} else if (kernelNum == 5) {
|
||||
nBlocks = 24;
|
||||
nThreads = 1024;
|
||||
} else if (kernelNum == 6) {
|
||||
nBlocks = 24;
|
||||
nThreads = 1024;
|
||||
} else if (kernelNum == 7) {
|
||||
nBlocks = 4;
|
||||
nThreads = 896;
|
||||
} else {
|
||||
nBlocks = 1;
|
||||
nThreads = WARP_SIZE * (worldSize - 1);
|
||||
@@ -401,6 +621,12 @@ void AllGatherTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
|
||||
allgather3<<<nBlocks, nThreads, 0, stream>>>();
|
||||
} else if (kernelNum == 4) {
|
||||
allgather4<<<nBlocks, nThreads, 0, stream>>>(rank, worldSize, nRanksPerNode, paramCount_);
|
||||
} else if (kernelNum == 5) {
|
||||
allgather5<<<nBlocks, nThreads, 0, stream>>>(rank, worldSize, nRanksPerNode, paramCount_);
|
||||
} else if (kernelNum == 6) {
|
||||
allgather6<<<nBlocks, nThreads, 0, stream>>>(rank, worldSize, nRanksPerNode, paramCount_);
|
||||
} else if (kernelNum == 7) {
|
||||
allgather7<<<nBlocks, nThreads, 0, stream>>>(rank, worldSize, nRanksPerNode, paramCount_);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -453,7 +679,10 @@ std::vector<KernelRestriction> AllGatherTestColl::getKernelRestrictions() {
|
||||
{1, "allgather1", false, 1, 4 * worldSize_},
|
||||
{2, "allgather2", true, 3, 4 * worldSize_},
|
||||
{3, "allgather3", true, 1, 4 * worldSize_},
|
||||
{4, "allgather4", true, 3, 16 * worldSize_ /*use ulong2 to transfer data*/}};
|
||||
{4, "allgather4", true, 3, 16 * worldSize_ /*use ulong2 to transfer data*/},
|
||||
{5, "allgather5", false, 1, 16 * worldSize_ /*use ulong2 to transfer data*/},
|
||||
{6, "allgather6", false, 1, 16 * worldSize_ /*use ulong2 to transfer data*/},
|
||||
{7, "allgather7", false, 1, 16 * worldSize_ /*use ulong2 to transfer data*/}};
|
||||
}
|
||||
|
||||
class AllGatherTestEngine : public BaseTestEngine {
|
||||
@@ -474,7 +703,9 @@ class AllGatherTestEngine : public BaseTestEngine {
|
||||
|
||||
std::shared_ptr<int> sendBuff_;
|
||||
std::shared_ptr<int[]> expectedBuff_;
|
||||
std::shared_ptr<mscclpp::LLPacket> scratchPacketBuff_;
|
||||
std::vector<mscclpp::SmChannel> smChannels_;
|
||||
std::vector<mscclpp::SmChannel> smOutOfPlaceChannels_;
|
||||
};
|
||||
|
||||
// Constructs the engine by forwarding the test arguments and the name "allgather" to the base class.
AllGatherTestEngine::AllGatherTestEngine(const TestArgs& args)
    : BaseTestEngine(args, "allgather") {
}
|
||||
@@ -482,6 +713,12 @@ AllGatherTestEngine::AllGatherTestEngine(const TestArgs& args) : BaseTestEngine(
|
||||
void AllGatherTestEngine::allocateBuffer() {
|
||||
sendBuff_ = mscclpp::allocExtSharedCuda<int>(args_.maxBytes / sizeof(int));
|
||||
expectedBuff_ = std::shared_ptr<int[]>(new int[args_.maxBytes / sizeof(int)]);
|
||||
if (args_.kernelNum == 7) {
|
||||
const size_t nPacket = (args_.maxBytes + sizeof(uint64_t) - 1) / sizeof(uint64_t);
|
||||
// 2x for double-buffering, scratchBuff used to store original data and reduced results
|
||||
const size_t scratchBuffNelem = nPacket * 2 /*original data & reduced result */ * 2 /* double buffering*/;
|
||||
scratchPacketBuff_ = mscclpp::allocExtSharedCuda<mscclpp::LLPacket>(scratchBuffNelem);
|
||||
}
|
||||
}
|
||||
|
||||
void AllGatherTestEngine::setupConnections() {
|
||||
@@ -494,7 +731,7 @@ void AllGatherTestEngine::setupConnections() {
|
||||
CUDATHROW(cudaMemcpyToSymbol(constProxyChans, devProxyChannels.data(),
|
||||
sizeof(DeviceHandle<mscclpp::SimpleProxyChannel>) * devProxyChannels.size()));
|
||||
|
||||
setupMeshConnections(smChannels_, sendBuff_.get(), args_.maxBytes);
|
||||
setupMeshConnections(smChannels_, sendBuff_.get(), args_.maxBytes, nullptr, 0, ChannelSemantic::PUT, 64);
|
||||
std::vector<DeviceHandle<mscclpp::SmChannel>> smChannelHandles(smChannels_.size());
|
||||
if (smChannels_.size() > sizeof(constSmChans) / sizeof(DeviceHandle<mscclpp::SmChannel>)) {
|
||||
std::runtime_error("unexpected error");
|
||||
@@ -503,6 +740,21 @@ void AllGatherTestEngine::setupConnections() {
|
||||
[](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); });
|
||||
CUDATHROW(cudaMemcpyToSymbol(constSmChans, smChannelHandles.data(),
|
||||
sizeof(DeviceHandle<mscclpp::SmChannel>) * smChannelHandles.size()));
|
||||
|
||||
if (args_.kernelNum == 7) {
|
||||
const size_t nPacket = (args_.maxBytes + sizeof(uint64_t) - 1) / sizeof(uint64_t);
|
||||
const size_t scratchPacketBuffBytes = nPacket * 2 * 2 * sizeof(mscclpp::LLPacket);
|
||||
setupMeshConnections(smOutOfPlaceChannels_, sendBuff_.get(), args_.maxBytes, scratchPacketBuff_.get(),
|
||||
scratchPacketBuffBytes);
|
||||
std::vector<DeviceHandle<mscclpp::SmChannel>> smOutOfPlaceChannelHandles(smOutOfPlaceChannels_.size());
|
||||
if (smOutOfPlaceChannels_.size() > sizeof(constSmOutOfPlaceChans) / sizeof(DeviceHandle<mscclpp::SmChannel>)) {
|
||||
std::runtime_error("unexpected error");
|
||||
}
|
||||
std::transform(smOutOfPlaceChannels_.begin(), smOutOfPlaceChannels_.end(), smOutOfPlaceChannelHandles.begin(),
|
||||
[](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); });
|
||||
CUDATHROW(cudaMemcpyToSymbol(constSmOutOfPlaceChans, smOutOfPlaceChannelHandles.data(),
|
||||
sizeof(DeviceHandle<mscclpp::SmChannel>) * smOutOfPlaceChannelHandles.size()));
|
||||
}
|
||||
} else {
|
||||
auto service = std::dynamic_pointer_cast<AllGatherProxyService>(chanService_);
|
||||
setupMeshConnections(devProxyChannels, sendBuff_.get(), args_.maxBytes, nullptr, 0,
|
||||
|
||||
@@ -970,9 +970,8 @@ __global__ void allreduce5(int* buff, int rank, int nRanksPerNode, int worldSize
|
||||
__global__ void allreduce6(int* buff, int* scratch, void* resultBuff, int rank, int nRanksPerNode, int worldSize,
|
||||
size_t nelems) {
|
||||
// This version of allreduce only works for single nodes
|
||||
if (worldSize != nRanksPerNode) return;
|
||||
const int nPeers = nRanksPerNode - 1;
|
||||
const int nPkts = nelems / 2;
|
||||
const size_t nPkts = nelems / 2;
|
||||
const int nelemsPerRank = nelems / worldSize;
|
||||
const int nPktsPerRank = nelemsPerRank / 2;
|
||||
// flag for packets. Initially 1
|
||||
@@ -982,7 +981,6 @@ __global__ void allreduce6(int* buff, int* scratch, void* resultBuff, int rank,
|
||||
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
|
||||
const int peerIdx = blockIdx.x / nBlocksPerPeer;
|
||||
const int remoteRank = peerIdx < rank ? peerIdx : peerIdx + 1;
|
||||
DeviceHandle<mscclpp::SmChannel> smChan = constSmOutOfPlaceChans[peerIdx];
|
||||
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
|
||||
// double buffering
|
||||
size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LLPacket);
|
||||
@@ -995,7 +993,8 @@ __global__ void allreduce6(int* buff, int* scratch, void* resultBuff, int rank,
|
||||
uint2* dst = (uint2*)((char*)resultBuff + rank * nelemsPerRank * sizeof(int));
|
||||
|
||||
// step 1: write to scratch buffer
|
||||
smChan.putPackets(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer, flag);
|
||||
constSmOutOfPlaceChans[peerIdx].putPackets(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), tid,
|
||||
blockDim.x * nBlocksPerPeer, flag);
|
||||
// step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer
|
||||
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) {
|
||||
uint2 data = make_uint2(0, 0);
|
||||
@@ -1008,11 +1007,16 @@ __global__ void allreduce6(int* buff, int* scratch, void* resultBuff, int rank,
|
||||
}
|
||||
data.x += src[idx].x;
|
||||
data.y += src[idx].y;
|
||||
dst[idx].x = data.x;
|
||||
dst[idx].y = data.y;
|
||||
dst[idx] = data;
|
||||
|
||||
mscclpp::LLPacket packet;
|
||||
packet.data1 = data.x;
|
||||
packet.flag1 = flag;
|
||||
packet.data2 = data.y;
|
||||
packet.flag2 = flag;
|
||||
size_t offset = scratchResultOffset / sizeof(mscclpp::LLPacket) + (idx + rank * nPktsPerRank);
|
||||
for (int index = 0; index < nPeers; index++) {
|
||||
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)((char*)constSmOutOfPlaceChans[index].dst_ + scratchResultOffset);
|
||||
dstPkt[idx + rank * nPktsPerRank].write(data.x, data.y, flag);
|
||||
constSmOutOfPlaceChans[index].write(offset, packet);
|
||||
}
|
||||
}
|
||||
// step 3: get data result from scratch buffer
|
||||
@@ -1029,6 +1033,67 @@ __global__ void allreduce6(int* buff, int* scratch, void* resultBuff, int rank,
|
||||
}
|
||||
}
|
||||
|
||||
/// All-reduce test kernel 7, using `mscclpp::LL8Packet` over `constSmOutOfPlaceChans`.
/// As noted in the original, this version of allreduce only works for single nodes.
///
/// Steps:
///   1. `putPackets` one rank-sized chunk of `buff` through this block group's channel.
///   2. For each element of this rank's chunk, read one packet per peer from `scratch`, add the
///      local element, store the sum in `resultBuff`, and write it as a packet on every channel.
///   3. Read the reduced packets for one peer's chunk from `scratch` into `resultBuff`.
/// Thread 0 of block 0 increments `globalFlag` at the end.
///
/// @param buff           Input elements (read as 32-bit words).
/// @param scratch        Packet scratch buffer; the region used depends on the flag parity.
/// @param resultBuff     Output buffer for the reduced elements.
/// @param rank           Index of the calling rank.
/// @param nRanksPerNode  Ranks per node; `nRanksPerNode - 1` peers are used.
/// @param worldSize      Divisor used to derive the per-rank element count.
/// @param nelems         Total number of elements.
__global__ void allreduce7(int* buff, int* scratch, void* resultBuff, int rank, int nRanksPerNode, int worldSize,
                           size_t nelems) {
  using Packet = mscclpp::LL8Packet;

  const int nPeers = nRanksPerNode - 1;
  // One packet per element.
  const size_t nPkts = nelems;
  const int nelemsPerRank = nelems / worldSize;
  const int nPktsPerRank = nelemsPerRank;
  // Packet flag for this launch (the original notes it is initially 1).
  const uint32_t flag = static_cast<uint32_t>(globalFlag);

  // Blocks are split into equal groups, one group per peer.
  const int blocksPerPeer = gridDim.x / nPeers;
  const int blockInGroup = blockIdx.x % blocksPerPeer;
  const int myPeer = blockIdx.x / blocksPerPeer;
  // Peer index -> rank: indices at or above `rank` are shifted up by one.
  const int myPeerRank = myPeer < rank ? myPeer : myPeer + 1;
  // Thread index within the block group.
  const int groupTid = threadIdx.x + blockInGroup * blockDim.x;

  // Double buffering: both the incoming and the result regions alternate with the flag parity.
  const size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(Packet);
  Packet* inboxPkts = reinterpret_cast<Packet*>(reinterpret_cast<char*>(scratch) + scratchBaseOffset);
  const size_t scratchOffset = scratchBaseOffset + rank * nPktsPerRank * sizeof(Packet);
  const size_t scratchResultOffset = ((flag & 1) ? 2 : 3) * nPkts * sizeof(Packet);

  // Byte offsets of the peer's chunk and of this rank's own chunk.
  const size_t srcOffset = myPeerRank * nelemsPerRank * sizeof(int);
  const size_t ownChunkBytes = rank * nelemsPerRank * sizeof(int);
  uint32_t* src = reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(buff) + ownChunkBytes);
  uint32_t* dst = reinterpret_cast<uint32_t*>(static_cast<char*>(resultBuff) + ownChunkBytes);

  // step 1: write to scratch buffer
  constSmOutOfPlaceChans[myPeer].putPackets<Packet>(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), groupTid,
                                                    blockDim.x * blocksPerPeer, flag);

  // step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer
  for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) {
    uint32_t sum = 0;
    for (int p = 0; p < nPeers; ++p) {
      const int peerRank = p < rank ? p : p + 1;
      Packet* peerPkts = inboxPkts + peerRank * nPktsPerRank;
      const uint32_t val = peerPkts[idx].read(flag);
      sum += val;
    }
    sum += src[idx];
    dst[idx] = sum;

    // Publish the reduced value to every peer at the slot reserved for this rank's chunk.
    Packet outPkt;
    outPkt.data = sum;
    outPkt.flag = flag;
    size_t pktIndex = scratchResultOffset / sizeof(Packet) + (idx + rank * nPktsPerRank);
    for (int p = 0; p < nPeers; ++p) {
      constSmOutOfPlaceChans[p].write(pktIndex, outPkt);
    }
  }

  // step 3: get data result from scratch buffer
  Packet* resultPkts = reinterpret_cast<Packet*>(reinterpret_cast<char*>(scratch) + scratchResultOffset);
  const int dstOffset = myPeerRank * nPktsPerRank;
  uint32_t* result = reinterpret_cast<uint32_t*>(static_cast<char*>(resultBuff) + srcOffset);
  for (int idx = groupTid; idx < nPktsPerRank; idx += blockDim.x * blocksPerPeer) {
    result[idx] = resultPkts[idx + dstOffset].read(flag);
  }

  // A single thread advances the flag, which also flips the scratch regions on the next launch.
  if (threadIdx.x == 0 && blockIdx.x == 0) {
    globalFlag += 1;
  }
}
|
||||
|
||||
class AllReduceTestColl : public BaseTestColl {
|
||||
public:
|
||||
AllReduceTestColl() = default;
|
||||
@@ -1072,6 +1137,10 @@ void AllReduceTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
|
||||
nBlocks = 21;
|
||||
tmpBuff = scratchPacketBuff;
|
||||
nThreadsPerBlock = 512;
|
||||
} else if (kernelNum == 7) {
|
||||
nBlocks = 28;
|
||||
tmpBuff = scratchPacketBuff;
|
||||
nThreadsPerBlock = 1024;
|
||||
} else {
|
||||
nBlocks = std::max(args.nRanksPerNode - 1, 1) * BLOCKS_PER_PEER;
|
||||
tmpBuff = scratchPacketBuff;
|
||||
@@ -1097,6 +1166,9 @@ void AllReduceTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
|
||||
else if (kernelNum == 6) {
|
||||
allreduce6<<<nBlocks, nThreadsPerBlock, 0, stream>>>((int*)inputBuff, (int*)tmpBuff, resultBuff, rank,
|
||||
args.nRanksPerNode, worldSize, paramCount_);
|
||||
} else if (kernelNum == 7) {
|
||||
allreduce7<<<nBlocks, nThreadsPerBlock, 0, stream>>>((int*)inputBuff, (int*)tmpBuff, resultBuff, rank,
|
||||
args.nRanksPerNode, worldSize, paramCount_);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1150,7 +1222,8 @@ std::vector<KernelRestriction> AllReduceTestColl::getKernelRestrictions() {
|
||||
16 * worldSize_ /*use ulong2 to transfer data*/,
|
||||
},
|
||||
{5, "allreduce5", false, 1, 4 * worldSize_},
|
||||
{6, "allreduce6", false, 1, 4 * worldSize_}};
|
||||
{6, "allreduce6", false, 1, 4 * worldSize_},
|
||||
{7, "allreduce7", false, 1, 4 * worldSize_}};
|
||||
}
|
||||
|
||||
class AllReduceTestEngine : public BaseTestEngine {
|
||||
@@ -1180,16 +1253,20 @@ class AllReduceTestEngine : public BaseTestEngine {
|
||||
std::shared_ptr<int[]> expectedBuff_;
|
||||
std::vector<mscclpp::SmChannel> smOutOfPlaceChannels_;
|
||||
std::vector<mscclpp::SmChannel> smInPlaceChannels_;
|
||||
std::vector<mscclpp::SmChannel> smOutputPlaceGetChannels_;
|
||||
std::vector<mscclpp::SmChannel> smOutOfPlaceGetChannels_;
|
||||
};
|
||||
|
||||
// Constructs the engine under the name "allreduce" and caches whether the selected kernel
// runs in place.
AllReduceTestEngine::AllReduceTestEngine(const TestArgs& args)
    : BaseTestEngine(args, "allreduce") {
  // Assigned in the body, after the base class has been constructed from `args`.
  inPlace_ = isInPlace();
}
|
||||
|
||||
bool AllReduceTestEngine::isUsePacket() const { return (args_.kernelNum == 2 || args_.kernelNum == 6); }
|
||||
bool AllReduceTestEngine::isUsePacket() const {
|
||||
return (args_.kernelNum == 2 || args_.kernelNum == 6 || args_.kernelNum == 7);
|
||||
}
|
||||
|
||||
bool AllReduceTestEngine::isInPlace() const { return (args_.kernelNum != 2 && args_.kernelNum != 6); }
|
||||
bool AllReduceTestEngine::isInPlace() const {
|
||||
return (args_.kernelNum != 2 && args_.kernelNum != 6 && args_.kernelNum != 7);
|
||||
}
|
||||
|
||||
void AllReduceTestEngine::allocateBuffer() {
|
||||
inputBuff_ = mscclpp::allocExtSharedCuda<int>(args_.maxBytes / sizeof(int));
|
||||
@@ -1211,7 +1288,7 @@ void AllReduceTestEngine::allocateBuffer() {
|
||||
getPacketBuff_ = mscclpp::allocExtSharedCuda<mscclpp::LLPacket>(packetBuffNelem);
|
||||
putPacketBuff = putPacketBuff_.get();
|
||||
getPacketBuff = getPacketBuff_.get();
|
||||
} else if (args_.kernelNum == 6) {
|
||||
} else if (args_.kernelNum == 6 || args_.kernelNum == 7) {
|
||||
const size_t nPacket = (args_.maxBytes + sizeof(uint64_t) - 1) / sizeof(uint64_t);
|
||||
// 2x for double-buffering, scratchBuff used to store original data and reduced results
|
||||
const size_t scratchBuffNelem = nPacket * 2 /*original data & reduced result */ * 2 /* double buffering*/;
|
||||
@@ -1232,7 +1309,7 @@ void AllReduceTestEngine::setupConnections() {
|
||||
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannels;
|
||||
|
||||
const size_t nPacket = (args_.maxBytes + sizeof(uint64_t) - 1) / sizeof(uint64_t);
|
||||
if (args_.kernelNum == 6) {
|
||||
if (args_.kernelNum == 6 || args_.kernelNum == 7) {
|
||||
const size_t scratchPacketBuffBytes = nPacket * 2 * 2 * sizeof(mscclpp::LLPacket);
|
||||
setupMeshConnections(smOutOfPlaceChannels_, inputBuff_.get(), args_.maxBytes, scratchPacketBuff_.get(),
|
||||
scratchPacketBuffBytes);
|
||||
@@ -1301,14 +1378,14 @@ void AllReduceTestEngine::setupConnections() {
|
||||
CUDATHROW(cudaMemcpyToSymbol(constSmInPlaceChans, smChannelDeviceHandles.data(),
|
||||
sizeof(DeviceHandle<mscclpp::SmChannel>) * smChannelDeviceHandles.size()));
|
||||
|
||||
setupMeshConnections(smOutputPlaceGetChannels_, inputBuff_.get(), args_.maxBytes, scratchBuff_.get(),
|
||||
args_.maxBytes, ChannelSemantic::GET);
|
||||
if (smOutputPlaceGetChannels_.size() >
|
||||
setupMeshConnections(smOutOfPlaceGetChannels_, inputBuff_.get(), args_.maxBytes, scratchBuff_.get(), args_.maxBytes,
|
||||
ChannelSemantic::GET);
|
||||
if (smOutOfPlaceGetChannels_.size() >
|
||||
sizeof(constSmOutOfPlaceGetChans) / sizeof(DeviceHandle<mscclpp::SmChannel>)) {
|
||||
std::runtime_error("unexpected error");
|
||||
}
|
||||
smChannelDeviceHandles.resize(smOutputPlaceGetChannels_.size());
|
||||
getChannelDeviceHandle(smOutputPlaceGetChannels_, smChannelDeviceHandles);
|
||||
smChannelDeviceHandles.resize(smOutOfPlaceGetChannels_.size());
|
||||
getChannelDeviceHandle(smOutOfPlaceGetChannels_, smChannelDeviceHandles);
|
||||
CUDATHROW(cudaMemcpyToSymbol(constSmOutOfPlaceGetChans, smChannelDeviceHandles.data(),
|
||||
sizeof(DeviceHandle<mscclpp::SmChannel>) * smChannelDeviceHandles.size()));
|
||||
}
|
||||
|
||||
@@ -428,7 +428,7 @@ void BaseTestEngine::setupMeshConnections(std::vector<DeviceHandle<mscclpp::Simp
|
||||
|
||||
void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::SmChannel>& smChannels, void* inputBuff,
|
||||
size_t inputBuffBytes, void* outputBuff, size_t outputBuffBytes,
|
||||
ChannelSemantic semantic) {
|
||||
ChannelSemantic semantic, size_t nChannelPerConnection) {
|
||||
const mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc | IBs[args_.gpuNum];
|
||||
mscclpp::RegisteredMemory inputBufRegMem = comm_->registerMemory(inputBuff, inputBuffBytes, allTransports);
|
||||
mscclpp::RegisteredMemory getPacketBufRegMem;
|
||||
@@ -443,19 +443,23 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::SmChannel>& smCha
|
||||
(outputBuff && semantic == ChannelSemantic::PUT) ? outputBufRegMem : inputBufRegMem;
|
||||
setupMeshConnectionsInternal(connections, localRegMemory, remoteRegMemories);
|
||||
|
||||
std::unordered_map<size_t, std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>> smSemaphores;
|
||||
std::unordered_map<size_t, std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>>> smSemaphores;
|
||||
for (size_t cid = 0; cid < connections.size(); ++cid) {
|
||||
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
|
||||
smSemaphores.emplace(cid, std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(*comm_, connections[cid]));
|
||||
for (size_t i = 0; i < nChannelPerConnection; ++i) {
|
||||
smSemaphores[cid].emplace_back(std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(*comm_, connections[cid]));
|
||||
}
|
||||
}
|
||||
}
|
||||
comm_->setup();
|
||||
|
||||
for (size_t cid = 0; cid < connections.size(); ++cid) {
|
||||
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
|
||||
smChannels.emplace_back(smSemaphores[cid], remoteRegMemories[cid].get(),
|
||||
(outputBuff && semantic == ChannelSemantic::GET) ? outputBuff : inputBufRegMem.data(),
|
||||
nullptr);
|
||||
for (size_t i = 0; i < nChannelPerConnection; ++i) {
|
||||
for (size_t cid = 0; cid < connections.size(); ++cid) {
|
||||
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
|
||||
smChannels.emplace_back(smSemaphores[cid][i], remoteRegMemories[cid].get(),
|
||||
(outputBuff && semantic == ChannelSemantic::GET) ? outputBuff : inputBufRegMem.data(),
|
||||
outputBuff);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,7 +118,7 @@ class BaseTestEngine {
|
||||
SetupChannelFunc setupChannel = nullptr);
|
||||
void setupMeshConnections(std::vector<mscclpp::SmChannel>& smChannels, void* inputBuff, size_t inputBuffBytes,
|
||||
void* outputBuff = nullptr, size_t outputBuffBytes = 0,
|
||||
ChannelSemantic semantic = ChannelSemantic::PUT);
|
||||
ChannelSemantic semantic = ChannelSemantic::PUT, size_t nChannelPerConnection = 1);
|
||||
void setupMeshConnections(std::vector<mscclpp::SmChannel>& smChannels,
|
||||
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>>& proxyChannels, void* inputBuff,
|
||||
size_t inputBuffBytes, void* putPacketBuff = nullptr, size_t putPacketBuffBytes = 0,
|
||||
|
||||
Reference in New Issue
Block a user