diff --git a/CMakeLists.txt b/CMakeLists.txt index 66ed4b94..302febab 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -83,11 +83,11 @@ if(USE_CUDA) else() set(CMAKE_HIP_STANDARD 17) set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Wall -Wextra") - project(mscclpp LANGUAGES CXX HIP) + project(mscclpp LANGUAGES CXX) set(CMAKE_HIP_ARCHITECTURES gfx90a gfx941 gfx942) - set(GPU_LIBRARIES hip::host) + set(GPU_LIBRARIES hip::device) set(GPU_INCLUDE_DIRS ${hip_INCLUDE_DIRS}) endif() diff --git a/docs/quickstart.md b/docs/quickstart.md index f2b12d18..af1bbe5f 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -27,10 +27,24 @@ CMake 3.25 or later is required. ```bash $ git clone https://github.com/microsoft/mscclpp.git $ mkdir -p mscclpp/build && cd mscclpp/build +``` + +For NVIDIA platforms, build MSCCL++ as follows. + +```bash +# For NVIDIA platforms $ cmake -DCMAKE_BUILD_TYPE=Release .. $ make -j ``` +For AMD platforms, use HIPCC instead of the default C++ compiler. Replace `/path/to/hipcc` in the command below with your HIPCC path. + +```bash +# For AMD platforms +$ CXX=/path/to/hipcc cmake -DCMAKE_BUILD_TYPE=Release .. +$ make -j +``` + ## Install from Source (Libraries and Headers) ```bash diff --git a/include/mscclpp/packet_device.hpp b/include/mscclpp/packet_device.hpp index bf213993..11f63b53 100644 --- a/include/mscclpp/packet_device.hpp +++ b/include/mscclpp/packet_device.hpp @@ -5,6 +5,7 @@ #define MSCCLPP_PACKET_DEVICE_HPP_ #include +#include #include "device.hpp" @@ -14,9 +15,8 @@ #endif // defined(MSCCLPP_DEVICE_COMPILE) namespace mscclpp { - /// LL (low latency) protocol packet. -union alignas(16) LLPacket { +union alignas(16) LL16Packet { // Assume data is written with an atomicity of 8 bytes (IB/RDMA). struct { uint32_t data1; @@ -28,7 +28,7 @@ union alignas(16) LLPacket { #if defined(MSCCLPP_DEVICE_COMPILE) ulonglong2 raw_; - MSCCLPP_DEVICE_INLINE LLPacket() {} + MSCCLPP_DEVICE_INLINE LL16Packet() {} /// Write 8 bytes of data to the packet. 
/// @param val1 The first 4-byte data to write. @@ -88,34 +88,202 @@ union alignas(16) LLPacket { #endif // defined(MSCCLPP_DEVICE_COMPILE) }; +union alignas(8) LL8Packet { + // Assume data is written with an atomicity of 8 bytes (IB/RDMA). + struct { + uint32_t data; + uint32_t flag; + }; + uint64_t raw_; #if defined(MSCCLPP_DEVICE_COMPILE) -/// Read from the origin and write to the target buffer. -MSCCLPP_DEVICE_INLINE void putPackets(void* targetPtr, uint64_t targetOffset, const void* originPtr, - uint64_t originOffset, uint64_t originBytes, uint32_t threadId, - uint32_t numThreads, uint32_t flag) { + + MSCCLPP_DEVICE_INLINE LL8Packet() {} + + MSCCLPP_DEVICE_INLINE void write(uint32_t val, uint32_t flag) { +#if defined(MSCCLPP_DEVICE_CUDA) + asm volatile("st.volatile.global.v2.u32 [%0], {%1,%2};" ::"l"(&raw_), "r"(val), "r"(flag)); +#else // !defined(MSCCLPP_DEVICE_CUDA) + uint2 reg = make_uint2(val, flag); + uint64_t* p = reinterpret_cast(®); + atomicStore(&(raw_), *p, memoryOrderRelaxed); +#endif + } + + MSCCLPP_DEVICE_INLINE bool readOnce(uint32_t flag, uint32_t& data) const { +#if defined(MSCCLPP_DEVICE_CUDA) + uint32_t f; + asm volatile("ld.volatile.global.v2.u32 {%0,%1}, [%2];" : "=r"(data), "=r"(f) : "l"(&raw_)); + return (f != flag); +#else // !defined(MSCCLPP_DEVICE_CUDA) + uint64_t reg; + reg = atomicLoad(&(raw_), memoryOrderRelaxed); + uint2* ptr = reinterpret_cast(®); + data = ptr->x; + return (ptr->y != flag); +#endif + } + + MSCCLPP_DEVICE_INLINE uint32_t read(uint32_t flag, int64_t maxSpinCount = 1000000) const { + uint32_t data; + POLL_MAYBE_JAILBREAK(readOnce(flag, data), maxSpinCount); + return data; + } + + /// Clear the packet. + MSCCLPP_DEVICE_INLINE void clear() { raw_ = 0; } +#endif // defined(MSCCLPP_DEVICE_COMPILE) +}; + +using LLPacket = LL16Packet; + +#if defined(MSCCLPP_DEVICE_COMPILE) +/// Read data from the origin and write LL16Packets to the target buffer. +/// +/// @param targetPtr The target buffer. 
+/// @param targetOffset The offset in the target buffer. +/// @param originPtr The origin buffer. +/// @param originOffset The offset in the origin buffer. +/// @param originBytes The number of bytes to write to the target buffer. +/// @param threadId The thread ID. The thread ID should be less than @p numThreads. +/// @param numThreads The number of threads that call this function. +/// @param flag The flag to write. +/// +MSCCLPP_DEVICE_INLINE void putLL16Packets(void* targetPtr, uint64_t targetOffset, const void* originPtr, + uint64_t originOffset, uint64_t originBytes, uint32_t threadId, + uint32_t numThreads, uint32_t flag) { // Offsets should be aligned to 8 bytes & size should be a multiple of 8 bytes const uint32_t* originBase = (const uint32_t*)((const char*)originPtr + originOffset); - LLPacket* targetBase = (LLPacket*)((char*)targetPtr + targetOffset); + LL16Packet* targetBase = (LL16Packet*)((char*)targetPtr + targetOffset); size_t nElem = originBytes / sizeof(uint64_t); for (size_t i = threadId; i < nElem; i += numThreads) { - LLPacket* pkt = &targetBase[i]; + LL16Packet* pkt = &targetBase[i]; pkt->write(originBase[2 * i], originBase[2 * i + 1], flag); } } -/// Read from the target buffer and write to the origin. -MSCCLPP_DEVICE_INLINE void getPackets(const void* targetPtr, uint64_t targetOffset, void* originPtr, - uint64_t originOffset, uint64_t originBytes, uint32_t threadId, - uint32_t numThreads, uint32_t flag) { +/// Read LL16Packets from the target buffer and write retrieved data to the origin. +/// +/// @param targetPtr The target buffer. +/// @param targetOffset The offset in the target buffer. +/// @param originPtr The origin buffer. +/// @param originOffset The offset in the origin buffer. +/// @param originBytes The number of bytes to write to the target buffer. +/// @param threadId The thread ID. The thread ID should be less than @p numThreads. +/// @param numThreads The number of threads that call this function. 
+/// @param flag The flag to write. +/// +MSCCLPP_DEVICE_INLINE void getLL16Packets(const void* targetPtr, uint64_t targetOffset, void* originPtr, + uint64_t originOffset, uint64_t originBytes, uint32_t threadId, + uint32_t numThreads, uint32_t flag) { // Offsets should be aligned to 8 bytes & size should be a multiple of 8 bytes - const LLPacket* targetBase = (const LLPacket*)((const char*)targetPtr + targetOffset); + const LL16Packet* targetBase = (const LL16Packet*)((const char*)targetPtr + targetOffset); uint2* originBase = (uint2*)((char*)originPtr + originOffset); size_t nElem = originBytes / sizeof(uint2); for (size_t i = threadId; i < nElem; i += numThreads) { - const LLPacket* pkt = &targetBase[i]; + const LL16Packet* pkt = &targetBase[i]; originBase[i] = pkt->read(flag); } } + +/// Read data from the origin and write LL8Packets to the target buffer. +/// +/// @param targetPtr The target buffer. +/// @param targetOffset The offset in the target buffer. +/// @param originPtr The origin buffer. +/// @param originOffset The offset in the origin buffer. +/// @param originBytes The number of bytes to write to the target buffer. +/// @param threadId The thread ID. The thread ID should be less than @p numThreads. +/// @param numThreads The number of threads that call this function. +/// @param flag The flag to write. 
+/// +MSCCLPP_DEVICE_INLINE void putLL8Packets(void* targetPtr, uint64_t targetOffset, const void* originPtr, + uint64_t originOffset, uint64_t originBytes, uint32_t threadId, + uint32_t numThreads, uint32_t flag) { + // Offsets should be aligned to 4 bytes & size should be a multiple of 4 bytes + const uint32_t* originBase = (const uint32_t*)((const char*)originPtr + originOffset); + LL8Packet* targetBase = (LL8Packet*)((char*)targetPtr + targetOffset); + size_t nElem = originBytes / sizeof(uint32_t); + for (size_t i = threadId; i < nElem; i += numThreads) { + LL8Packet* pkt = &targetBase[i]; + pkt->write(originBase[i], flag); + } +} + +/// Read LL8Packets from the target buffer and write retrieved data to the origin. +/// +/// @param targetPtr The target buffer. +/// @param targetOffset The offset in the target buffer. +/// @param originPtr The origin buffer. +/// @param originOffset The offset in the origin buffer. +/// @param originBytes The number of bytes to write to the target buffer. +/// @param threadId The thread ID. The thread ID should be less than @p numThreads. +/// @param numThreads The number of threads that call this function. +/// @param flag The flag to write. +/// +MSCCLPP_DEVICE_INLINE void getLL8Packets(const void* targetPtr, uint64_t targetOffset, void* originPtr, + uint64_t originOffset, uint64_t originBytes, uint32_t threadId, + uint32_t numThreads, uint32_t flag) { + // Offsets should be aligned to 4 bytes & size should be a multiple of 4 bytes + const LL8Packet* targetBase = (const LL8Packet*)((const char*)targetPtr + targetOffset); + uint32_t* originBase = (uint32_t*)((char*)originPtr + originOffset); + size_t nElem = originBytes / sizeof(uint32_t); + for (size_t i = threadId; i < nElem; i += numThreads) { + const LL8Packet* pkt = &targetBase[i]; + originBase[i] = pkt->read(flag); + } +} + +/// Read data from the origin and write packets to the target buffer. +/// +/// @param targetPtr The target buffer. 
+/// @param targetOffset The offset in the target buffer. +/// @param originPtr The origin buffer. +/// @param originOffset The offset in the origin buffer. +/// @param originBytes The number of bytes to write to the target buffer. +/// @param threadId The thread ID. The thread ID should be less than @p numThreads. +/// @param numThreads The number of threads that call this function. +/// @param flag The flag to write. +/// @tparam PacketType The packet type. It should be either @ref LL16Packet or @ref LL8Packet. +/// +template +MSCCLPP_DEVICE_INLINE void putPackets(void* targetPtr, uint64_t targetOffset, const void* originPtr, + uint64_t originOffset, uint64_t originBytes, uint32_t threadId, + uint32_t numThreads, uint32_t flag) { + if constexpr (std::is_same::value) { + putLL16Packets(targetPtr, targetOffset, originPtr, originOffset, originBytes, threadId, numThreads, flag); + } else if constexpr (std::is_same::value) { + putLL8Packets(targetPtr, targetOffset, originPtr, originOffset, originBytes, threadId, numThreads, flag); + } else { + static_assert(std::is_same::value || std::is_same::value, + "Unsupported packet type"); + } +} + +/// Read packets from the target buffer and write retrieved data to the origin. +/// +/// @param targetPtr The target buffer. +/// @param targetOffset The offset in the target buffer. +/// @param originPtr The origin buffer. +/// @param originOffset The offset in the origin buffer. +/// @param originBytes The number of bytes to read from the origin buffer. +/// @param threadId The thread ID. The thread ID should be less than @p numThreads. +/// @param numThreads The number of threads that call this function. +/// @param flag The flag to read. +/// @tparam PacketType The packet type. It should be either @ref LL16Packet or @ref LL8Packet. 
+/// +template +MSCCLPP_DEVICE_INLINE void getPackets(const void* targetPtr, uint64_t targetOffset, void* originPtr, + uint64_t originOffset, uint64_t originBytes, uint32_t threadId, + uint32_t numThreads, uint32_t flag) { + if constexpr (std::is_same::value) { + getLL16Packets(targetPtr, targetOffset, originPtr, originOffset, originBytes, threadId, numThreads, flag); + } else if constexpr (std::is_same::value) { + getLL8Packets(targetPtr, targetOffset, originPtr, originOffset, originBytes, threadId, numThreads, flag); + } else { + static_assert(std::is_same::value || std::is_same::value, + "Unsupported packet type"); + } +} #endif // defined(MSCCLPP_DEVICE_COMPILE) }; // namespace mscclpp diff --git a/include/mscclpp/poll_device.hpp b/include/mscclpp/poll_device.hpp index 0cdb6b01..9ad116f8 100644 --- a/include/mscclpp/poll_device.hpp +++ b/include/mscclpp/poll_device.hpp @@ -10,6 +10,9 @@ #include +#if defined(NDEBUG) +#define __assert_fail(__assertion, __file, __line, __function) ; +#else // !defined(NDEBUG) #if defined(MSCCLPP_DEVICE_HIP) extern "C" __device__ void __assert_fail(const char *__assertion, const char *__file, unsigned int __line, const char *__function); @@ -17,6 +20,7 @@ extern "C" __device__ void __assert_fail(const char *__assertion, const char *__ extern "C" __device__ void __assert_fail(const char *__assertion, const char *__file, unsigned int __line, const char *__function) __THROW; #endif // !defined(MSCCLPP_DEVICE_HIP) +#endif // NDEBUG // If a spin is stuck, print a warning and keep spinning. #define POLL_MAYBE_JAILBREAK(__cond, __max_spin_cnt) \ diff --git a/include/mscclpp/sm_channel_device.hpp b/include/mscclpp/sm_channel_device.hpp index 29993f8e..e49a431b 100644 --- a/include/mscclpp/sm_channel_device.hpp +++ b/include/mscclpp/sm_channel_device.hpp @@ -211,10 +211,12 @@ struct SmChannelDeviceHandle { /// @param threadId The index of the current thread among all threads running this function. 
This is different from /// the `threadIdx` in CUDA. /// @param numThreads The total number of threads that run this function. + /// @tparam PacketType The packet type. It should be either @ref LL16Packet or @ref LL8Packet. /// + template MSCCLPP_DEVICE_INLINE void putPackets(uint64_t targetOffset, uint64_t originOffset, uint64_t originBytes, uint32_t threadId, uint32_t numThreads, uint32_t flag) { - mscclpp::putPackets(dst_, targetOffset, src_, originOffset, originBytes, threadId, numThreads, flag); + mscclpp::putPackets(dst_, targetOffset, src_, originOffset, originBytes, threadId, numThreads, flag); } /// Retrieve data from @ref LLPacket in the local packet buffer (target) and write it on the local data (origin). @@ -227,10 +229,13 @@ struct SmChannelDeviceHandle { /// @param threadId The index of the current thread among all threads running this function. This is different from /// the `threadIdx` in CUDA. /// @param numThreads The total number of threads that run this function. + /// @tparam PacketType The packet type. It should be either @ref LL16Packet or @ref LL8Packet. /// + template MSCCLPP_DEVICE_INLINE void getPackets(uint64_t targetOffset, uint64_t originOffset, uint64_t originBytes, uint32_t threadId, uint32_t numThreads, uint32_t flag) { - mscclpp::getPackets(getPacketBuffer_, targetOffset, src_, originOffset, originBytes, threadId, numThreads, flag); + mscclpp::getPackets(getPacketBuffer_, targetOffset, src_, originOffset, originBytes, threadId, + numThreads, flag); } /// Signal the remote semaphore. 
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index ef85cde5..0268af1c 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -10,7 +10,7 @@ set(TEST_INC_INTERNAL PRIVATE ${PROJECT_SOURCE_DIR}/src/include) if(USE_ROCM) file(GLOB_RECURSE CU_SOURCES CONFIGURE_DEPENDS *.cu) - set_source_files_properties(${CU_SOURCES} PROPERTIES LANGUAGE HIP) + set_source_files_properties(${CU_SOURCES} PROPERTIES LANGUAGE CXX) endif() function(add_test_executable name sources) diff --git a/test/mp_unit/mp_unit_tests.hpp b/test/mp_unit/mp_unit_tests.hpp index d839b640..e934dee4 100644 --- a/test/mp_unit/mp_unit_tests.hpp +++ b/test/mp_unit/mp_unit_tests.hpp @@ -149,6 +149,8 @@ class SmChannelOneToOneTest : public CommunicatorTestBase { void setupMeshConnections(std::vector& smChannels, void* inputBuff, size_t inputBuffBytes, void* outputBuff = nullptr, size_t outputBuffBytes = 0); + using PacketPingPongKernelWrapper = std::function; + void packetPingPongTest(const std::string testName, PacketPingPongKernelWrapper kernelWrapper); std::unordered_map> smSemaphores; }; diff --git a/test/mp_unit/sm_channel_tests.cu b/test/mp_unit/sm_channel_tests.cu index ea524105..45c5fa64 100644 --- a/test/mp_unit/sm_channel_tests.cu +++ b/test/mp_unit/sm_channel_tests.cu @@ -70,6 +70,61 @@ void SmChannelOneToOneTest::setupMeshConnections(std::vector __constant__ DeviceHandle gChannelOneToOneTestConstSmChans; +void SmChannelOneToOneTest::packetPingPongTest(const std::string testName, PacketPingPongKernelWrapper kernelWrapper) { + if (gEnv->rank >= numRanksToUse) return; + + const int nElem = 4 * 1024 * 1024; + const int defaultNTries = 1000; + + std::vector smChannels; + std::shared_ptr buff = mscclpp::allocExtSharedCuda(nElem); + std::shared_ptr intermBuff = mscclpp::allocExtSharedCuda(nElem * 2); + setupMeshConnections(smChannels, buff.get(), nElem * sizeof(int), intermBuff.get(), nElem * 2 * sizeof(int)); + std::vector> deviceHandles(smChannels.size()); + 
std::transform(smChannels.begin(), smChannels.end(), deviceHandles.begin(), + [](const mscclpp::SmChannel& smChan) { return mscclpp::deviceHandle(smChan); }); + + ASSERT_EQ(smChannels.size(), 1); + MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstSmChans, deviceHandles.data(), + sizeof(DeviceHandle))); + + std::shared_ptr ret = mscclpp::makeSharedCudaHost(0); + + // The least nelem is 2 for packet ping pong + kernelWrapper(buff.get(), gEnv->rank, 2, ret.get(), defaultNTries); + MSCCLPP_CUDATHROW(cudaDeviceSynchronize()); + *ret = 0; + + kernelWrapper(buff.get(), gEnv->rank, 1024, ret.get(), defaultNTries); + MSCCLPP_CUDATHROW(cudaDeviceSynchronize()); + + EXPECT_EQ(*ret, 0); + *ret = 0; + + kernelWrapper(buff.get(), gEnv->rank, 1024 * 1024, ret.get(), defaultNTries); + MSCCLPP_CUDATHROW(cudaDeviceSynchronize()); + + EXPECT_EQ(*ret, 0); + *ret = 0; + + kernelWrapper(buff.get(), gEnv->rank, 4 * 1024 * 1024, ret.get(), defaultNTries); + MSCCLPP_CUDATHROW(cudaDeviceSynchronize()); + + EXPECT_EQ(*ret, 0); + *ret = 0; + + int nTries = 1000000; + communicator->bootstrap()->barrier(); + mscclpp::Timer timer; + kernelWrapper(buff.get(), gEnv->rank, 1024, ret.get(), nTries); + MSCCLPP_CUDATHROW(cudaDeviceSynchronize()); + communicator->bootstrap()->barrier(); + + if (gEnv->rank == 0) { + std::cout << testName << ": " << std::setprecision(4) << (float)timer.elapsed() / (float)(nTries) << " us/iter\n"; + } +} + __global__ void kernelSmPutPingPong(int* buff, int rank, int nElem, int* ret) { DeviceHandle& smChan = gChannelOneToOneTestConstSmChans; volatile int* sendBuff = (volatile int*)buff; @@ -238,17 +293,53 @@ TEST_F(SmChannelOneToOneTest, GetPingPong) { EXPECT_EQ(*ret, 0); } -__global__ void kernelSmPacketPingPong(int* buff, int rank, int nElem, int* ret) { +__global__ void kernelSmLL8PacketPingPong(int* buff, int rank, int nElem, int* ret, int nTries) { if (rank > 1) return; DeviceHandle& smChan = gChannelOneToOneTestConstSmChans; volatile int* sendBuff = 
(volatile int*)buff; - int nTries = 1000; int putOffset = (rank == 0) ? 0 : 10000000; int getOffset = (rank == 0) ? 10000000 : 0; for (int i = 0; i < nTries; i++) { uint64_t flag = (uint64_t)i + 1; + // rank=0: 0, 1, 0, 1, ... + // rank=1: 1, 0, 1, 0, ... + if ((rank ^ (i & 1)) == 0) { + // If each thread writes 8 bytes at once, we don't need a barrier before putPackets(). + for (int j = threadIdx.x; j < nElem; j += blockDim.x) { + sendBuff[j] = putOffset + i + j; + // sendBuff[2 * j + 1] = putOffset + i + 2 * j + 1; + } + // __syncthreads(); + smChan.putPackets(0, 0, nElem * sizeof(int), threadIdx.x, blockDim.x, flag); + } else { + smChan.getPackets(0, 0, nElem * sizeof(int), threadIdx.x, blockDim.x, flag); + // If each thread reads 8 bytes at once, we don't need a barrier after getPackets(). + // __syncthreads(); + for (int j = threadIdx.x; j < nElem; j += blockDim.x) { + if (sendBuff[j] != getOffset + i + j) { + // printf("ERROR: rank = %d, sendBuff[%d] = %d, expected %d. Skipping following errors\n", rank, 2 * j, + // sendBuff[2 * j], getOffset + i + 2 * j); + *ret = 1; + break; + } + } + } + // Make sure all threads are done in this iteration + __syncthreads(); + } +} + +__global__ void kernelSmLL16PacketPingPong(int* buff, int rank, int nElem, int* ret, int nTries) { + if (rank > 1) return; + + DeviceHandle& smChan = gChannelOneToOneTestConstSmChans; + volatile int* sendBuff = (volatile int*)buff; + int putOffset = (rank == 0) ? 0 : 10000000; + int getOffset = (rank == 0) ? 10000000 : 0; + for (int i = 0; i < nTries; i++) { + uint64_t flag = (uint64_t)i + 1; // rank=0: 0, 1, 0, 1, ... // rank=1: 1, 0, 1, 0, ... 
if ((rank ^ (i & 1)) == 0) { @@ -258,9 +349,9 @@ __global__ void kernelSmPacketPingPong(int* buff, int rank, int nElem, int* ret) sendBuff[2 * j + 1] = putOffset + i + 2 * j + 1; } // __syncthreads(); - smChan.putPackets(0, 0, nElem * sizeof(int), threadIdx.x, blockDim.x, flag); + smChan.putPackets(0, 0, nElem * sizeof(int), threadIdx.x, blockDim.x, flag); } else { - smChan.getPackets(0, 0, nElem * sizeof(int), threadIdx.x, blockDim.x, flag); + smChan.getPackets(0, 0, nElem * sizeof(int), threadIdx.x, blockDim.x, flag); // If each thread reads 8 bytes at once, we don't need a barrier after getPackets(). // __syncthreads(); for (int j = threadIdx.x; j < nElem / 2; j += blockDim.x) { @@ -283,46 +374,16 @@ __global__ void kernelSmPacketPingPong(int* buff, int rank, int nElem, int* ret) } } -TEST_F(SmChannelOneToOneTest, PacketPingPong) { - if (gEnv->rank >= numRanksToUse) return; - - const int nElem = 4 * 1024 * 1024; - - std::vector smChannels; - std::shared_ptr buff = mscclpp::allocExtSharedCuda(nElem); - std::shared_ptr intermBuff = mscclpp::allocExtSharedCuda(nElem * 2); - setupMeshConnections(smChannels, buff.get(), nElem * sizeof(int), intermBuff.get(), nElem * 2 * sizeof(int)); - std::vector> deviceHandles(smChannels.size()); - std::transform(smChannels.begin(), smChannels.end(), deviceHandles.begin(), - [](const mscclpp::SmChannel& smChan) { return mscclpp::deviceHandle(smChan); }); - - ASSERT_EQ(smChannels.size(), 1); - MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstSmChans, deviceHandles.data(), - sizeof(DeviceHandle))); - - std::shared_ptr ret = mscclpp::makeSharedCudaHost(0); - - // The least nelem is 2 for packet ping pong - kernelSmPacketPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 2, ret.get()); - MSCCLPP_CUDATHROW(cudaDeviceSynchronize()); - - EXPECT_EQ(*ret, 0); - *ret = 0; - - kernelSmPacketPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024, ret.get()); - MSCCLPP_CUDATHROW(cudaDeviceSynchronize()); - - EXPECT_EQ(*ret, 0); - *ret = 0; 
- - kernelSmPacketPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024 * 1024, ret.get()); - MSCCLPP_CUDATHROW(cudaDeviceSynchronize()); - - EXPECT_EQ(*ret, 0); - *ret = 0; - - kernelSmPacketPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 4 * 1024 * 1024, ret.get()); - MSCCLPP_CUDATHROW(cudaDeviceSynchronize()); - - EXPECT_EQ(*ret, 0); +TEST_F(SmChannelOneToOneTest, LL8PacketPingPong) { + auto kernelSmLL8PacketPingPongWrapper = [](int* buff, int rank, int nElem, int* ret, int nTries) { + kernelSmLL8PacketPingPong<<<1, 1024>>>(buff, rank, nElem, ret, nTries); + }; + packetPingPongTest("smLL8PacketPingPong", kernelSmLL8PacketPingPongWrapper); +} + +TEST_F(SmChannelOneToOneTest, LL16PacketPingPong) { + auto kernelSmLL16PacketPingPongWrapper = [](int* buff, int rank, int nElem, int* ret, int nTries) { + kernelSmLL16PacketPingPong<<<1, 1024>>>(buff, rank, nElem, ret, nTries); + }; + packetPingPongTest("smLL16PacketPingPong", kernelSmLL16PacketPingPongWrapper); } diff --git a/test/mscclpp-test/CMakeLists.txt b/test/mscclpp-test/CMakeLists.txt index cbbdfea6..e2ec8c2e 100644 --- a/test/mscclpp-test/CMakeLists.txt +++ b/test/mscclpp-test/CMakeLists.txt @@ -6,7 +6,7 @@ FetchContent_MakeAvailable(json) function(add_mscclpp_test_executable name sources) if(USE_ROCM) - set_source_files_properties(${sources} PROPERTIES LANGUAGE HIP) + set_source_files_properties(${sources} PROPERTIES LANGUAGE CXX) endif() add_executable(${name} ${sources} common.cc) target_link_libraries(${name} ${TEST_LIBS_COMMON} MPI::MPI_CXX nlohmann_json::nlohmann_json) diff --git a/test/mscclpp-test/allgather_test.cu b/test/mscclpp-test/allgather_test.cu index 4b2eff78..5c101bbd 100644 --- a/test/mscclpp-test/allgather_test.cu +++ b/test/mscclpp-test/allgather_test.cu @@ -23,7 +23,9 @@ using DeviceHandle = mscclpp::DeviceHandle; __constant__ DeviceHandle constProxyChans[16]; __constant__ DeviceHandle constRawProxyChan[16]; -__constant__ DeviceHandle constSmChans[8]; +__constant__ DeviceHandle 
constSmChans[512]; +__constant__ DeviceHandle constSmOutOfPlaceChans[16]; +__device__ uint64_t globalFlag; __global__ void allgather0(int rank, size_t nelemsPerGPU) { int warpId = threadIdx.x / WARP_SIZE; @@ -288,6 +290,215 @@ __global__ void allgather4(int rank, int worldSize, int nRanksPerNode, size_t ne nBlocksForLocalAllGather); } +__global__ void __launch_bounds__(1024, 1) + allgather5(size_t rank, [[maybe_unused]] size_t worldSize, size_t nRanksPerNode, size_t nelemsPerGPU) { + const size_t nBlock = gridDim.x; + if (blockIdx.x >= nBlock) return; + + const size_t tid = threadIdx.x + blockIdx.x * blockDim.x; + const size_t lid = tid % WARP_SIZE; + const size_t wid = tid / WARP_SIZE; + + const size_t nThread = blockDim.x * nBlock; + const size_t nWarp = nThread / WARP_SIZE; + const size_t nPeer = nRanksPerNode - 1; + const size_t chanOffset = nPeer * blockIdx.x; + auto smChans = constSmChans + chanOffset; + + if (wid < nPeer && lid == 0) { + smChans[wid].relaxedSignal(); + smChans[wid].wait(); + } + __syncthreads(); + const size_t bytesPerGPU = nelemsPerGPU * sizeof(int); + const size_t bytes = bytesPerGPU * nPeer; + size_t unitBytesPerThread; + if (bytes >= nThread * 64) { + unitBytesPerThread = 64; + } else { + unitBytesPerThread = 16; + } + const size_t unitBytesPerWarp = unitBytesPerThread * WARP_SIZE; + const size_t unitBytes = unitBytesPerWarp * nWarp; + const size_t nLoop = bytes / unitBytes; + + if (nLoop > 0) { + // First loop unrolling + const size_t peerIdx = wid % nPeer; + const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1); + const size_t offset = bytesPerGPU * remoteRankLocalIndex + (wid / nPeer) * unitBytesPerWarp; + smChans[peerIdx].get<16, false>(offset, unitBytesPerWarp, lid, WARP_SIZE); + } + + for (size_t i = 1; i < nLoop; ++i) { + const size_t gWid = wid + i * nWarp; + const size_t peerIdx = gWid % nPeer; + const size_t remoteRankLocalIndex = (peerIdx < rank ? 
peerIdx : peerIdx + 1); + const size_t offset = bytesPerGPU * remoteRankLocalIndex + (gWid / nPeer) * unitBytesPerWarp; + smChans[peerIdx].get<16, false>(offset, unitBytesPerWarp, lid, WARP_SIZE); + } + + if (bytes % unitBytes > 0) { + const size_t gWid = wid + nLoop * nWarp; + const size_t peerIdx = gWid % nPeer; + const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1); + const size_t offsetWithinRank = (gWid / nPeer) * unitBytesPerWarp; + const size_t offset = bytesPerGPU * remoteRankLocalIndex + offsetWithinRank; + const size_t remainBytes = (offsetWithinRank + unitBytesPerWarp > bytesPerGPU) + ? ((bytesPerGPU > offsetWithinRank) ? (bytesPerGPU - offsetWithinRank) : 0) + : unitBytesPerWarp; + if (remainBytes > 0) { + smChans[peerIdx].get<16, true>(offset, remainBytes, lid, WARP_SIZE); + } + } +} + +__global__ void __launch_bounds__(1024, 1) + allgather6(size_t rank, [[maybe_unused]] size_t worldSize, size_t nRanksPerNode, size_t nelemsPerGPU) { + const size_t nBlock = gridDim.x; + if (blockIdx.x >= nBlock) return; + + const size_t tid = threadIdx.x + blockIdx.x * blockDim.x; + const size_t lid = tid % WARP_SIZE; + const size_t wid = tid / WARP_SIZE; + + const size_t nThread = blockDim.x * nBlock; + const size_t nWarp = nThread / WARP_SIZE; + const size_t nPeer = nRanksPerNode - 1; + const size_t chanOffset = nPeer * blockIdx.x; + auto smChans = constSmChans + chanOffset; + + if (wid < nPeer && lid == 0) { + smChans[wid].relaxedSignal(); + smChans[wid].wait(); + } + __syncthreads(); + const size_t bytesPerGPU = nelemsPerGPU * sizeof(int); + const size_t bytes = bytesPerGPU * nPeer; + size_t unitBytesPerThread; + if (bytes >= nThread * 64) { + unitBytesPerThread = 64; + } else { + unitBytesPerThread = 16; + } + const size_t unitBytesPerWarp = unitBytesPerThread * WARP_SIZE; + const size_t unitBytes = unitBytesPerWarp * nWarp; + const size_t nLoop = bytes / unitBytes; + + if (nLoop > 0) { + // First loop unrolling + const size_t peerIdx = wid 
% nPeer; + const size_t offset = bytesPerGPU * rank + (wid / nPeer) * unitBytesPerWarp; + smChans[peerIdx].put<16, false>(offset, unitBytesPerWarp, lid, WARP_SIZE); + } + + for (size_t i = 1; i < nLoop; ++i) { + const size_t gWid = wid + i * nWarp; + const size_t peerIdx = gWid % nPeer; + const size_t offset = bytesPerGPU * rank + (gWid / nPeer) * unitBytesPerWarp; + smChans[peerIdx].put<16, false>(offset, unitBytesPerWarp, lid, WARP_SIZE); + } + + if (bytes % unitBytes > 0) { + const size_t gWid = wid + nLoop * nWarp; + const size_t peerIdx = gWid % nPeer; + const size_t offsetWithinRank = (gWid / nPeer) * unitBytesPerWarp; + const size_t offset = bytesPerGPU * rank + offsetWithinRank; + const size_t remainBytes = (offsetWithinRank + unitBytesPerWarp > bytesPerGPU) + ? ((bytesPerGPU > offsetWithinRank) ? (bytesPerGPU - offsetWithinRank) : 0) + : unitBytesPerWarp; + if (remainBytes > 0) { + smChans[peerIdx].put<16, true>(offset, remainBytes, lid, WARP_SIZE); + } + } +} + +__global__ void __launch_bounds__(1024, 1) + allgather7(size_t rank, [[maybe_unused]] size_t worldSize, size_t nRanksPerNode, size_t nelemsPerGPU) { + const size_t nBlock = gridDim.x; + if (blockIdx.x >= nBlock) return; + + const size_t tid = threadIdx.x + blockIdx.x * blockDim.x; + const size_t lid = tid % WARP_SIZE; + const size_t wid = tid / WARP_SIZE; + + const size_t nThread = blockDim.x * nBlock; + const size_t nWarp = nThread / WARP_SIZE; + const size_t nPeer = nRanksPerNode - 1; + auto smChans = constSmOutOfPlaceChans; + + const uint32_t flag = (uint32_t)globalFlag; + const size_t bytesPerGPU = nelemsPerGPU * sizeof(int); + const size_t bytes = bytesPerGPU * nPeer; + size_t unitBytesPerThread = 8; + const size_t unitBytesPerWarp = unitBytesPerThread * WARP_SIZE; + const size_t unitBytes = unitBytesPerWarp * nWarp; + const size_t nLoop = bytes / unitBytes; + + // double buffering + const size_t scratchOffset = (flag & 1) ? 
0 : bytesPerGPU * nRanksPerNode * 2; + + if (nLoop > 0) { + // First loop unrolling + const size_t peerIdx = wid % nPeer; + const size_t offset = bytesPerGPU * rank + (wid / nPeer) * unitBytesPerWarp; + smChans[peerIdx].putPackets(scratchOffset + offset * 2, offset, unitBytesPerWarp, lid, WARP_SIZE, flag); + } + + if (nLoop > 0) { + // First loop unrolling + const size_t peerIdx = wid % nPeer; + const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1); + const size_t offset = bytesPerGPU * remoteRankLocalIndex + (wid / nPeer) * unitBytesPerWarp; + smChans[peerIdx].getPackets(scratchOffset + offset * 2, offset, unitBytesPerWarp, lid, WARP_SIZE, flag); + } + + for (size_t i = 1; i < nLoop; ++i) { + const size_t gWid = wid + i * nWarp; + const size_t peerIdx = gWid % nPeer; + const size_t offset = bytesPerGPU * rank + (gWid / nPeer) * unitBytesPerWarp; + smChans[peerIdx].putPackets(scratchOffset + offset * 2, offset, unitBytesPerWarp, lid, WARP_SIZE, flag); + } + + for (size_t i = 1; i < nLoop; ++i) { + const size_t gWid = wid + i * nWarp; + const size_t peerIdx = gWid % nPeer; + const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1); + const size_t offset = bytesPerGPU * remoteRankLocalIndex + (gWid / nPeer) * unitBytesPerWarp; + smChans[peerIdx].getPackets(scratchOffset + offset * 2, offset, unitBytesPerWarp, lid, WARP_SIZE, flag); + } + + if (bytes % unitBytes > 0) { + const size_t gWid = wid + nLoop * nWarp; + const size_t peerIdx = gWid % nPeer; + const size_t offsetWithinRank = (gWid / nPeer) * unitBytesPerWarp; + const size_t offset = bytesPerGPU * rank + offsetWithinRank; + const size_t remainBytes = (offsetWithinRank + unitBytesPerWarp > bytesPerGPU) + ? ((bytesPerGPU > offsetWithinRank) ? 
(bytesPerGPU - offsetWithinRank) : 0) + : unitBytesPerWarp; + if (remainBytes > 0) { + smChans[peerIdx].putPackets(scratchOffset + offset * 2, offset, remainBytes, lid, WARP_SIZE, flag); + } + } + if (bytes % unitBytes > 0) { + const size_t gWid = wid + nLoop * nWarp; + const size_t peerIdx = gWid % nPeer; + const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1); + const size_t offsetWithinRank = (gWid / nPeer) * unitBytesPerWarp; + const size_t offset = bytesPerGPU * remoteRankLocalIndex + offsetWithinRank; + const size_t remainBytes = (offsetWithinRank + unitBytesPerWarp > bytesPerGPU) + ? ((bytesPerGPU > offsetWithinRank) ? (bytesPerGPU - offsetWithinRank) : 0) + : unitBytesPerWarp; + if (remainBytes > 0) { + smChans[peerIdx].getPackets(scratchOffset + offset * 2, offset, remainBytes, lid, WARP_SIZE, flag); + } + } + + if (threadIdx.x == 0 && blockIdx.x == 0) { + globalFlag += 1; + } +} + class AllGatherProxyService : public mscclpp::BaseProxyService { public: AllGatherProxyService(int worldSize, int rank, int cudaDevice); @@ -387,6 +598,15 @@ void AllGatherTestColl::runColl(const TestArgs& args, cudaStream_t stream) { if (kernelNum == 4) { nBlocks = 21; nThreads = 1024; + } else if (kernelNum == 5) { + nBlocks = 24; + nThreads = 1024; + } else if (kernelNum == 6) { + nBlocks = 24; + nThreads = 1024; + } else if (kernelNum == 7) { + nBlocks = 4; + nThreads = 896; } else { nBlocks = 1; nThreads = WARP_SIZE * (worldSize - 1); @@ -401,6 +621,12 @@ void AllGatherTestColl::runColl(const TestArgs& args, cudaStream_t stream) { allgather3<<>>(); } else if (kernelNum == 4) { allgather4<<>>(rank, worldSize, nRanksPerNode, paramCount_); + } else if (kernelNum == 5) { + allgather5<<>>(rank, worldSize, nRanksPerNode, paramCount_); + } else if (kernelNum == 6) { + allgather6<<>>(rank, worldSize, nRanksPerNode, paramCount_); + } else if (kernelNum == 7) { + allgather7<<>>(rank, worldSize, nRanksPerNode, paramCount_); } } @@ -453,7 +679,10 @@ std::vector 
AllGatherTestColl::getKernelRestrictions() { {1, "allgather1", false, 1, 4 * worldSize_}, {2, "allgather2", true, 3, 4 * worldSize_}, {3, "allgather3", true, 1, 4 * worldSize_}, - {4, "allgather4", true, 3, 16 * worldSize_ /*use ulong2 to transfer data*/}}; + {4, "allgather4", true, 3, 16 * worldSize_ /*use ulong2 to transfer data*/}, + {5, "allgather5", false, 1, 16 * worldSize_ /*use ulong2 to transfer data*/}, + {6, "allgather6", false, 1, 16 * worldSize_ /*use ulong2 to transfer data*/}, + {7, "allgather7", false, 1, 16 * worldSize_ /*use ulong2 to transfer data*/}}; } class AllGatherTestEngine : public BaseTestEngine { @@ -474,7 +703,9 @@ class AllGatherTestEngine : public BaseTestEngine { std::shared_ptr sendBuff_; std::shared_ptr expectedBuff_; + std::shared_ptr scratchPacketBuff_; std::vector smChannels_; + std::vector smOutOfPlaceChannels_; }; AllGatherTestEngine::AllGatherTestEngine(const TestArgs& args) : BaseTestEngine(args, "allgather") {} @@ -482,6 +713,12 @@ AllGatherTestEngine::AllGatherTestEngine(const TestArgs& args) : BaseTestEngine( void AllGatherTestEngine::allocateBuffer() { sendBuff_ = mscclpp::allocExtSharedCuda(args_.maxBytes / sizeof(int)); expectedBuff_ = std::shared_ptr(new int[args_.maxBytes / sizeof(int)]); + if (args_.kernelNum == 7) { + const size_t nPacket = (args_.maxBytes + sizeof(uint64_t) - 1) / sizeof(uint64_t); + // 2x for double-buffering, scratchBuff used to store original data and reduced results + const size_t scratchBuffNelem = nPacket * 2 /*original data & reduced result */ * 2 /* double buffering*/; + scratchPacketBuff_ = mscclpp::allocExtSharedCuda(scratchBuffNelem); + } } void AllGatherTestEngine::setupConnections() { @@ -494,7 +731,7 @@ void AllGatherTestEngine::setupConnections() { CUDATHROW(cudaMemcpyToSymbol(constProxyChans, devProxyChannels.data(), sizeof(DeviceHandle) * devProxyChannels.size())); - setupMeshConnections(smChannels_, sendBuff_.get(), args_.maxBytes); + setupMeshConnections(smChannels_, 
sendBuff_.get(), args_.maxBytes, nullptr, 0, ChannelSemantic::PUT, 64); std::vector> smChannelHandles(smChannels_.size()); if (smChannels_.size() > sizeof(constSmChans) / sizeof(DeviceHandle)) { std::runtime_error("unexpected error"); @@ -503,6 +740,21 @@ void AllGatherTestEngine::setupConnections() { [](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); }); CUDATHROW(cudaMemcpyToSymbol(constSmChans, smChannelHandles.data(), sizeof(DeviceHandle) * smChannelHandles.size())); + + if (args_.kernelNum == 7) { + const size_t nPacket = (args_.maxBytes + sizeof(uint64_t) - 1) / sizeof(uint64_t); + const size_t scratchPacketBuffBytes = nPacket * 2 * 2 * sizeof(mscclpp::LLPacket); + setupMeshConnections(smOutOfPlaceChannels_, sendBuff_.get(), args_.maxBytes, scratchPacketBuff_.get(), + scratchPacketBuffBytes); + std::vector> smOutOfPlaceChannelHandles(smOutOfPlaceChannels_.size()); + if (smOutOfPlaceChannels_.size() > sizeof(constSmOutOfPlaceChans) / sizeof(DeviceHandle)) { + std::runtime_error("unexpected error"); + } + std::transform(smOutOfPlaceChannels_.begin(), smOutOfPlaceChannels_.end(), smOutOfPlaceChannelHandles.begin(), + [](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); }); + CUDATHROW(cudaMemcpyToSymbol(constSmOutOfPlaceChans, smOutOfPlaceChannelHandles.data(), + sizeof(DeviceHandle) * smOutOfPlaceChannelHandles.size())); + } } else { auto service = std::dynamic_pointer_cast(chanService_); setupMeshConnections(devProxyChannels, sendBuff_.get(), args_.maxBytes, nullptr, 0, diff --git a/test/mscclpp-test/allreduce_test.cu b/test/mscclpp-test/allreduce_test.cu index 2748681b..cbedcefd 100644 --- a/test/mscclpp-test/allreduce_test.cu +++ b/test/mscclpp-test/allreduce_test.cu @@ -970,9 +970,8 @@ __global__ void allreduce5(int* buff, int rank, int nRanksPerNode, int worldSize __global__ void allreduce6(int* buff, int* scratch, void* resultBuff, int rank, int nRanksPerNode, int worldSize, size_t nelems) { // 
This version of allreduce only works for single nodes - if (worldSize != nRanksPerNode) return; const int nPeers = nRanksPerNode - 1; - const int nPkts = nelems / 2; + const size_t nPkts = nelems / 2; const int nelemsPerRank = nelems / worldSize; const int nPktsPerRank = nelemsPerRank / 2; // flag for packets. Initially 1 @@ -982,7 +981,6 @@ __global__ void allreduce6(int* buff, int* scratch, void* resultBuff, int rank, const int localBlockIdx = blockIdx.x % nBlocksPerPeer; const int peerIdx = blockIdx.x / nBlocksPerPeer; const int remoteRank = peerIdx < rank ? peerIdx : peerIdx + 1; - DeviceHandle smChan = constSmOutOfPlaceChans[peerIdx]; const int tid = threadIdx.x + localBlockIdx * blockDim.x; // double buffering size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LLPacket); @@ -995,7 +993,8 @@ __global__ void allreduce6(int* buff, int* scratch, void* resultBuff, int rank, uint2* dst = (uint2*)((char*)resultBuff + rank * nelemsPerRank * sizeof(int)); // step 1: write to scratch buffer - smChan.putPackets(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer, flag); + constSmOutOfPlaceChans[peerIdx].putPackets(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), tid, + blockDim.x * nBlocksPerPeer, flag); // step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) { uint2 data = make_uint2(0, 0); @@ -1008,11 +1007,16 @@ __global__ void allreduce6(int* buff, int* scratch, void* resultBuff, int rank, } data.x += src[idx].x; data.y += src[idx].y; - dst[idx].x = data.x; - dst[idx].y = data.y; + dst[idx] = data; + + mscclpp::LLPacket packet; + packet.data1 = data.x; + packet.flag1 = flag; + packet.data2 = data.y; + packet.flag2 = flag; + size_t offset = scratchResultOffset / sizeof(mscclpp::LLPacket) + (idx + rank * nPktsPerRank); for (int index = 0; index < nPeers; index++) { - 
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)((char*)constSmOutOfPlaceChans[index].dst_ + scratchResultOffset); - dstPkt[idx + rank * nPktsPerRank].write(data.x, data.y, flag); + constSmOutOfPlaceChans[index].write(offset, packet); } } // step 3: get data result from scratch buffer @@ -1029,6 +1033,67 @@ __global__ void allreduce6(int* buff, int* scratch, void* resultBuff, int rank, } } +__global__ void allreduce7(int* buff, int* scratch, void* resultBuff, int rank, int nRanksPerNode, int worldSize, + size_t nelems) { + // This version of allreduce only works for single nodes + const int nPeers = nRanksPerNode - 1; + const size_t nPkts = nelems; + const int nelemsPerRank = nelems / worldSize; + const int nPktsPerRank = nelemsPerRank; + // flag for packets. Initially 1 + const uint32_t flag = (uint32_t)globalFlag; + // thread block & channel info + const int nBlocksPerPeer = gridDim.x / nPeers; + const int localBlockIdx = blockIdx.x % nBlocksPerPeer; + const int peerIdx = blockIdx.x / nBlocksPerPeer; + const int remoteRank = peerIdx < rank ? peerIdx : peerIdx + 1; + const int tid = threadIdx.x + localBlockIdx * blockDim.x; + // double buffering + size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LL8Packet); + void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset); + size_t scratchOffset = scratchBaseOffset + rank * nPktsPerRank * sizeof(mscclpp::LL8Packet); + size_t scratchResultOffset = + (flag & 1) ? 
2 * nPkts * sizeof(mscclpp::LL8Packet) : 3 * nPkts * sizeof(mscclpp::LL8Packet); + size_t srcOffset = remoteRank * nelemsPerRank * sizeof(int); + uint32_t* src = (uint32_t*)((char*)buff + rank * nelemsPerRank * sizeof(int)); + uint32_t* dst = (uint32_t*)((char*)resultBuff + rank * nelemsPerRank * sizeof(int)); + + // step 1: write to scratch buffer + constSmOutOfPlaceChans[peerIdx].putPackets(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), + tid, blockDim.x * nBlocksPerPeer, flag); + // step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer + for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) { + uint32_t data = 0; + for (int index = 0; index < nPeers; index++) { + const int remoteRank = index < rank ? index : index + 1; + mscclpp::LL8Packet* dstPkt = (mscclpp::LL8Packet*)scratchBuff + remoteRank * nPktsPerRank; + uint32_t val = dstPkt[idx].read(flag); + data += val; + } + data += src[idx]; + dst[idx] = data; + + mscclpp::LL8Packet packet; + packet.data = data; + packet.flag = flag; + size_t offset = scratchResultOffset / sizeof(mscclpp::LL8Packet) + (idx + rank * nPktsPerRank); + for (int index = 0; index < nPeers; index++) { + constSmOutOfPlaceChans[index].write(offset, packet); + } + } + // step 3: get data result from scratch buffer + mscclpp::LL8Packet* dstPkt = (mscclpp::LL8Packet*)((char*)scratch + scratchResultOffset); + const int dstOffset = remoteRank * nPktsPerRank; + uint32_t* result = (uint32_t*)((char*)resultBuff + remoteRank * nelemsPerRank * sizeof(int)); + for (int idx = threadIdx.x + localBlockIdx * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * nBlocksPerPeer) { + uint32_t data = dstPkt[idx + dstOffset].read(flag); + result[idx] = data; + } + if (threadIdx.x == 0 && blockIdx.x == 0) { + globalFlag += 1; + } +} + class AllReduceTestColl : public BaseTestColl { public: AllReduceTestColl() = default; @@ -1072,6 +1137,10 @@ void 
AllReduceTestColl::runColl(const TestArgs& args, cudaStream_t stream) { nBlocks = 21; tmpBuff = scratchPacketBuff; nThreadsPerBlock = 512; + } else if (kernelNum == 7) { + nBlocks = 28; + tmpBuff = scratchPacketBuff; + nThreadsPerBlock = 1024; } else { nBlocks = std::max(args.nRanksPerNode - 1, 1) * BLOCKS_PER_PEER; tmpBuff = scratchPacketBuff; @@ -1097,6 +1166,9 @@ void AllReduceTestColl::runColl(const TestArgs& args, cudaStream_t stream) { else if (kernelNum == 6) { allreduce6<<>>((int*)inputBuff, (int*)tmpBuff, resultBuff, rank, args.nRanksPerNode, worldSize, paramCount_); + } else if (kernelNum == 7) { + allreduce7<<>>((int*)inputBuff, (int*)tmpBuff, resultBuff, rank, + args.nRanksPerNode, worldSize, paramCount_); } } @@ -1150,7 +1222,8 @@ std::vector AllReduceTestColl::getKernelRestrictions() { 16 * worldSize_ /*use ulong2 to transfer data*/, }, {5, "allreduce5", false, 1, 4 * worldSize_}, - {6, "allreduce6", false, 1, 4 * worldSize_}}; + {6, "allreduce6", false, 1, 4 * worldSize_}, + {7, "allreduce7", false, 1, 4 * worldSize_}}; } class AllReduceTestEngine : public BaseTestEngine { @@ -1180,16 +1253,20 @@ class AllReduceTestEngine : public BaseTestEngine { std::shared_ptr expectedBuff_; std::vector smOutOfPlaceChannels_; std::vector smInPlaceChannels_; - std::vector smOutputPlaceGetChannels_; + std::vector smOutOfPlaceGetChannels_; }; AllReduceTestEngine::AllReduceTestEngine(const TestArgs& args) : BaseTestEngine(args, "allreduce") { inPlace_ = isInPlace(); } -bool AllReduceTestEngine::isUsePacket() const { return (args_.kernelNum == 2 || args_.kernelNum == 6); } +bool AllReduceTestEngine::isUsePacket() const { + return (args_.kernelNum == 2 || args_.kernelNum == 6 || args_.kernelNum == 7); +} -bool AllReduceTestEngine::isInPlace() const { return (args_.kernelNum != 2 && args_.kernelNum != 6); } +bool AllReduceTestEngine::isInPlace() const { + return (args_.kernelNum != 2 && args_.kernelNum != 6 && args_.kernelNum != 7); +} void 
AllReduceTestEngine::allocateBuffer() { inputBuff_ = mscclpp::allocExtSharedCuda(args_.maxBytes / sizeof(int)); @@ -1211,7 +1288,7 @@ void AllReduceTestEngine::allocateBuffer() { getPacketBuff_ = mscclpp::allocExtSharedCuda(packetBuffNelem); putPacketBuff = putPacketBuff_.get(); getPacketBuff = getPacketBuff_.get(); - } else if (args_.kernelNum == 6) { + } else if (args_.kernelNum == 6 || args_.kernelNum == 7) { const size_t nPacket = (args_.maxBytes + sizeof(uint64_t) - 1) / sizeof(uint64_t); // 2x for double-buffering, scratchBuff used to store original data and reduced results const size_t scratchBuffNelem = nPacket * 2 /*original data & reduced result */ * 2 /* double buffering*/; @@ -1232,7 +1309,7 @@ void AllReduceTestEngine::setupConnections() { std::vector> proxyChannels; const size_t nPacket = (args_.maxBytes + sizeof(uint64_t) - 1) / sizeof(uint64_t); - if (args_.kernelNum == 6) { + if (args_.kernelNum == 6 || args_.kernelNum == 7) { const size_t scratchPacketBuffBytes = nPacket * 2 * 2 * sizeof(mscclpp::LLPacket); setupMeshConnections(smOutOfPlaceChannels_, inputBuff_.get(), args_.maxBytes, scratchPacketBuff_.get(), scratchPacketBuffBytes); @@ -1301,14 +1378,14 @@ void AllReduceTestEngine::setupConnections() { CUDATHROW(cudaMemcpyToSymbol(constSmInPlaceChans, smChannelDeviceHandles.data(), sizeof(DeviceHandle) * smChannelDeviceHandles.size())); - setupMeshConnections(smOutputPlaceGetChannels_, inputBuff_.get(), args_.maxBytes, scratchBuff_.get(), - args_.maxBytes, ChannelSemantic::GET); - if (smOutputPlaceGetChannels_.size() > + setupMeshConnections(smOutOfPlaceGetChannels_, inputBuff_.get(), args_.maxBytes, scratchBuff_.get(), args_.maxBytes, + ChannelSemantic::GET); + if (smOutOfPlaceGetChannels_.size() > sizeof(constSmOutOfPlaceGetChans) / sizeof(DeviceHandle)) { std::runtime_error("unexpected error"); } - smChannelDeviceHandles.resize(smOutputPlaceGetChannels_.size()); - getChannelDeviceHandle(smOutputPlaceGetChannels_, smChannelDeviceHandles); + 
smChannelDeviceHandles.resize(smOutOfPlaceGetChannels_.size()); + getChannelDeviceHandle(smOutOfPlaceGetChannels_, smChannelDeviceHandles); CUDATHROW(cudaMemcpyToSymbol(constSmOutOfPlaceGetChans, smChannelDeviceHandles.data(), sizeof(DeviceHandle) * smChannelDeviceHandles.size())); } diff --git a/test/mscclpp-test/common.cc b/test/mscclpp-test/common.cc index c5653b3f..9c52f9f4 100644 --- a/test/mscclpp-test/common.cc +++ b/test/mscclpp-test/common.cc @@ -428,7 +428,7 @@ void BaseTestEngine::setupMeshConnections(std::vector& smChannels, void* inputBuff, size_t inputBuffBytes, void* outputBuff, size_t outputBuffBytes, - ChannelSemantic semantic) { + ChannelSemantic semantic, size_t nChannelPerConnection) { const mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc | IBs[args_.gpuNum]; mscclpp::RegisteredMemory inputBufRegMem = comm_->registerMemory(inputBuff, inputBuffBytes, allTransports); mscclpp::RegisteredMemory getPacketBufRegMem; @@ -443,19 +443,23 @@ void BaseTestEngine::setupMeshConnections(std::vector& smCha (outputBuff && semantic == ChannelSemantic::PUT) ? outputBufRegMem : inputBufRegMem; setupMeshConnectionsInternal(connections, localRegMemory, remoteRegMemories); - std::unordered_map> smSemaphores; + std::unordered_map>> smSemaphores; for (size_t cid = 0; cid < connections.size(); ++cid) { if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) { - smSemaphores.emplace(cid, std::make_shared(*comm_, connections[cid])); + for (size_t i = 0; i < nChannelPerConnection; ++i) { + smSemaphores[cid].emplace_back(std::make_shared(*comm_, connections[cid])); + } } } comm_->setup(); - for (size_t cid = 0; cid < connections.size(); ++cid) { - if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) { - smChannels.emplace_back(smSemaphores[cid], remoteRegMemories[cid].get(), - (outputBuff && semantic == ChannelSemantic::GET) ? 
outputBuff : inputBufRegMem.data(), - nullptr); + for (size_t i = 0; i < nChannelPerConnection; ++i) { + for (size_t cid = 0; cid < connections.size(); ++cid) { + if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) { + smChannels.emplace_back(smSemaphores[cid][i], remoteRegMemories[cid].get(), + (outputBuff && semantic == ChannelSemantic::GET) ? outputBuff : inputBufRegMem.data(), + outputBuff); + } } } } diff --git a/test/mscclpp-test/common.hpp b/test/mscclpp-test/common.hpp index 665ff911..7e3e8c42 100644 --- a/test/mscclpp-test/common.hpp +++ b/test/mscclpp-test/common.hpp @@ -118,7 +118,7 @@ class BaseTestEngine { SetupChannelFunc setupChannel = nullptr); void setupMeshConnections(std::vector& smChannels, void* inputBuff, size_t inputBuffBytes, void* outputBuff = nullptr, size_t outputBuffBytes = 0, - ChannelSemantic semantic = ChannelSemantic::PUT); + ChannelSemantic semantic = ChannelSemantic::PUT, size_t nChannelPerConnection = 1); void setupMeshConnections(std::vector& smChannels, std::vector>& proxyChannels, void* inputBuff, size_t inputBuffBytes, void* putPacketBuff = nullptr, size_t putPacketBuffBytes = 0,