Revised MemoryChannel interfaces (#508)

* Moved the `MemoryChannel::copy()` method out of the `MemoryChannel` as
a standalone function.
* Renamed `mscclpp::putPackets()` and `mscclpp::getPackets()` to
`mscclpp::copyToPackets()` and `mscclpp::copyFromPackets()` respectively
for consistency.
* Renamed `MemoryChannel::getPackets()` to
`MemoryChannel::unpackPackets()` for clarity. Renamed `getPacketBuffer`
to `packetBuffer`.
* Added the `MemoryChannel::unpackPacket()` method that unpacks one
packet in the buffer.
* Added the `BaseMemoryChannel` class that only contains a semaphore
without memory addresses.
* Removed the `MemoryDevice2DeviceSemaphoreDeviceHandle::signalPacket()`
method that is lacking use cases.
This commit is contained in:
Changho Hwang
2025-04-24 17:02:56 -07:00
committed by GitHub
parent 9df2bdb2bf
commit 710f6686dc
19 changed files with 518 additions and 499 deletions

View File

@@ -53,10 +53,10 @@ __global__ void __launch_bounds__(1024, 1)
char* src = reinterpret_cast<char*>(memChans[peerIdx].src_);
char* buff = reinterpret_cast<char*>(sendbuff);
const size_t offsetWithinRank = (wid / nPeer) * unitBytesPerWarp;
memChans[peerIdx].copy<16, false>(src + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid,
WARP_SIZE);
memChans[peerIdx].copy<16, false>(dst + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid,
WARP_SIZE);
mscclpp::copy<16, false>(src + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid,
WARP_SIZE);
mscclpp::copy<16, false>(dst + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid,
WARP_SIZE);
} else {
memChans[peerIdx].put<16, false>(offset + channelOutOffset, unitBytesPerWarp, lid, WARP_SIZE);
}
@@ -71,10 +71,10 @@ __global__ void __launch_bounds__(1024, 1)
char* src = reinterpret_cast<char*>(memChans[peerIdx].src_);
char* buff = reinterpret_cast<char*>(sendbuff);
const size_t offsetWithinRank = (gWid / nPeer) * unitBytesPerWarp;
memChans[peerIdx].copy<16, false>(src + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid,
WARP_SIZE);
memChans[peerIdx].copy<16, false>(dst + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid,
WARP_SIZE);
mscclpp::copy<16, false>(src + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid,
WARP_SIZE);
mscclpp::copy<16, false>(dst + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid,
WARP_SIZE);
} else {
memChans[peerIdx].put<16, false>(offset + channelOutOffset, unitBytesPerWarp, lid, WARP_SIZE);
}
@@ -93,10 +93,8 @@ __global__ void __launch_bounds__(1024, 1)
char* dst = reinterpret_cast<char*>(memChans[peerIdx].dst_);
char* src = reinterpret_cast<char*>(memChans[peerIdx].src_);
char* buff = reinterpret_cast<char*>(sendbuff);
memChans[peerIdx].copy<16, true>(src + offset + channelOutOffset, buff + offsetWithinRank, remainBytes, lid,
WARP_SIZE);
memChans[peerIdx].copy<16, true>(dst + offset + channelOutOffset, buff + offsetWithinRank, remainBytes, lid,
WARP_SIZE);
mscclpp::copy<16, true>(src + offset + channelOutOffset, buff + offsetWithinRank, remainBytes, lid, WARP_SIZE);
mscclpp::copy<16, true>(dst + offset + channelOutOffset, buff + offsetWithinRank, remainBytes, lid, WARP_SIZE);
} else {
memChans[peerIdx].put<16, true>(offset + channelOutOffset, remainBytes, lid, WARP_SIZE);
}

View File

@@ -56,13 +56,13 @@ __global__ void __launch_bounds__(1024, 1)
char* send_ = reinterpret_cast<char*>(sendbuff);
for (size_t peerIdx = 0; peerIdx < nPeer; peerIdx++) {
char* dst = reinterpret_cast<char*>(memChans[peerIdx].dst_); // Peer's scratchbuff.
memChans[peerIdx].copy<16, false>(dst + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
mscclpp::copy<16, false>(dst + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
__syncthreads();
if (threadIdx.x == peerIdx) memChans[peerIdx].signal();
}
if constexpr (IsOutOfPlace) {
char* recv_ = reinterpret_cast<char*>(recvbuff);
memChans[0].copy<16, false>(recv_ + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
mscclpp::copy<16, false>(recv_ + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
}
} else { // rank != root.
@@ -70,8 +70,7 @@ __global__ void __launch_bounds__(1024, 1)
__syncthreads();
char* recv_ = reinterpret_cast<char*>(recvbuff);
char* scratch_ = reinterpret_cast<char*>(scratchbuff); // My scratchbuff.
memChans[peerRootIdx].copy<16, false>(recv_ + offset, scratch_ + offset, unitBytesPerBlock, threadIdx.x,
blockDim.x);
mscclpp::copy<16, false>(recv_ + offset, scratch_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
}
}
@@ -89,22 +88,21 @@ __global__ void __launch_bounds__(1024, 1)
char* send_ = reinterpret_cast<char*>(sendbuff);
for (size_t peerIdx = 0; peerIdx < nPeer; peerIdx++) {
char* dst = reinterpret_cast<char*>(memChans[peerIdx].dst_); // Peer's scratchbuff.
memChans[peerIdx].copy<16, false>(dst + offset + scratchSub, send_ + offset, unitBytesPerBlock, threadIdx.x,
blockDim.x);
mscclpp::copy<16, false>(dst + offset + scratchSub, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
__syncthreads();
if (threadIdx.x == peerIdx) memChans[peerIdx].signal();
}
if constexpr (IsOutOfPlace) {
char* recv_ = reinterpret_cast<char*>(recvbuff);
memChans[0].copy<16, false>(recv_ + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
mscclpp::copy<16, false>(recv_ + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
}
} else { // rank != root.
if (threadIdx.x == peerRootIdx) memChans[peerRootIdx].wait();
__syncthreads();
char* recv_ = reinterpret_cast<char*>(recvbuff);
char* scratch_ = reinterpret_cast<char*>(scratchbuff); // My scratchbuff.
memChans[peerRootIdx].copy<16, false>(recv_ + offset, scratch_ + offset + scratchSub, unitBytesPerBlock,
threadIdx.x, blockDim.x);
mscclpp::copy<16, false>(recv_ + offset, scratch_ + offset + scratchSub, unitBytesPerBlock, threadIdx.x,
blockDim.x);
}
}
@@ -117,22 +115,20 @@ __global__ void __launch_bounds__(1024, 1)
char* send_ = reinterpret_cast<char*>(sendbuff);
for (size_t peerIdx = 0; peerIdx < nPeer; peerIdx++) {
char* dst = reinterpret_cast<char*>(memChans[peerIdx].dst_); // Peer's scratchbuff.
memChans[peerIdx].copy<16, true>(dst + offset + scratchSub, send_ + offset, remainBytes, threadIdx.x,
blockDim.x);
mscclpp::copy<16, true>(dst + offset + scratchSub, send_ + offset, remainBytes, threadIdx.x, blockDim.x);
__syncthreads();
if (threadIdx.x == peerIdx) memChans[peerIdx].signal();
}
if constexpr (IsOutOfPlace) {
char* recv_ = reinterpret_cast<char*>(recvbuff);
memChans[0].copy<16, true>(recv_ + offset, send_ + offset, remainBytes, threadIdx.x, blockDim.x);
mscclpp::copy<16, true>(recv_ + offset, send_ + offset, remainBytes, threadIdx.x, blockDim.x);
}
} else { // rank != root.
if (threadIdx.x == peerRootIdx) memChans[peerRootIdx].wait();
__syncthreads();
char* recv_ = reinterpret_cast<char*>(recvbuff);
char* scratch_ = reinterpret_cast<char*>(scratchbuff); // My scratchbuff.
memChans[peerRootIdx].copy<16, true>(recv_ + offset, scratch_ + offset + scratchSub, remainBytes, threadIdx.x,
blockDim.x);
mscclpp::copy<16, true>(recv_ + offset, scratch_ + offset + scratchSub, remainBytes, threadIdx.x, blockDim.x);
}
} // remainBytes > 0.
}

View File

@@ -88,8 +88,8 @@ __device__ void gpuKernel(mscclpp::MemoryChannelDeviceHandle* memChans, int flag
// Running on rank 1
__device__ void gpuKernel(mscclpp::MemoryChannelDeviceHandle* memChans, int flag) {
memChans[0].getPackets(/*dstOffset=*/ 0, /*srcOffset=*/ 0, /*size=*/ 1024, /*threadId*/ threadIdx.x, /*numThreads*/ blockDim.x,
/*flag=*/ flag);
memChans[0].unpackPackets(/*dstOffset=*/ 0, /*srcOffset=*/ 0, /*size=*/ 1024, /*threadId*/ threadIdx.x, /*numThreads*/ blockDim.x,
/*flag=*/ flag);
// Data is ready to use
}
```

View File

@@ -0,0 +1,43 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

// Device-side assertion helper: assert_device() aborts a kernel with a message
// in debug builds and compiles to a no-op otherwise.
#ifndef MSCCLPP_ASSERT_DEVICE_HPP_
#define MSCCLPP_ASSERT_DEVICE_HPP_
#include "device.hpp"
#if defined(MSCCLPP_DEVICE_COMPILE)
#include <cstdint>
#if !defined(DEBUG_BUILD)
// Non-debug builds: neutralize __assert_fail() expansions and provide an empty
// assert_device() so call sites compile away with zero runtime cost.
#define __assert_fail(__assertion, __file, __line, __function) ;
namespace mscclpp {
// No-op in non-debug builds; see the DEBUG_BUILD branch below for semantics.
MSCCLPP_DEVICE_INLINE void assert_device(bool cond, const char* msg) {}
} // namespace mscclpp
#else // defined(DEBUG_BUILD)
// Declare the toolchain-provided device-side __assert_fail(). The signature
// differs per backend: HIP declares it __device__ only, while the CUDA/glibc
// declaration is __host__ __device__ and carries __THROW.
#if defined(MSCCLPP_DEVICE_HIP)
extern "C" __device__ void __assert_fail(const char *__assertion, const char *__file, unsigned int __line,
                                         const char *__function);
#else // !defined(MSCCLPP_DEVICE_HIP)
extern "C" __host__ __device__ void __assert_fail(const char *__assertion, const char *__file, unsigned int __line,
                                                  const char *__function) __THROW;
#endif // !defined(MSCCLPP_DEVICE_HIP)
namespace mscclpp {
/// Abort the kernel via __assert_fail() if @p cond is false (debug builds only).
/// @param cond The condition expected to hold.
/// @param msg The assertion message reported on failure.
MSCCLPP_DEVICE_INLINE void assert_device(bool cond, const char *msg) {
  if (!cond) {
    // NOTE: __FILE__/__LINE__ refer to this header, not the caller, because
    // assert_device is a function rather than a macro.
    __assert_fail(msg, __FILE__, __LINE__, __PRETTY_FUNCTION__);
  }
}
} // namespace mscclpp
#endif // !defined(DEBUG_BUILD)
#endif // defined(MSCCLPP_DEVICE_COMPILE)
#endif // MSCCLPP_ASSERT_DEVICE_HPP_

View File

@@ -0,0 +1,187 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#ifndef MSCCLPP_COPY_DEVICE_HPP_
#define MSCCLPP_COPY_DEVICE_HPP_
#include <cstdint>
#include <type_traits>
#include "device.hpp"
#if defined(MSCCLPP_DEVICE_COMPILE)
#include "packet_device.hpp"
#endif // defined(MSCCLPP_DEVICE_COMPILE)
namespace mscclpp {
#if defined(MSCCLPP_DEVICE_COMPILE)
namespace detail {
/// Copy aligned elements from the source memory to the destination memory.
///
/// Intended to be called collectively by multiple threads: thread @p threadId
/// handles elements threadId, threadId + numThreads, threadId + 2*numThreads, ...
///
/// @param dst The destination address.
/// @param src The source address.
/// @param numElems The number of elements to be copied.
/// @param threadId The index of the current thread among all threads running this function. This is different
/// from the `threadIdx` in CUDA.
/// @param numThreads The total number of threads that run this function.
///
template <typename T>
MSCCLPP_DEVICE_INLINE void copy(T* dst, T* src, uint64_t numElems, uint32_t threadId, uint32_t numThreads) {
  for (uint64_t elemIdx = threadId; elemIdx < numElems; elemIdx += numThreads) {
    // Stage each element through a register: load first, then store.
    const T staged = src[elemIdx];
    dst[elemIdx] = staged;
  }
}
}  // namespace detail
/// Helper for copy(): copies @p bytes bytes as wide elements of type T, with
/// int-sized (4-byte) remainder handling at the head and tail of the range.
///
/// @p dst and @p src are assumed to be int-aligned and misaligned from
/// sizeof(T) by the same amount, so the head/tail split computed from @p dst
/// also applies to @p src (the caller's contract: "aligned in the same way").
template <typename T, bool CopyRemainder = true>
MSCCLPP_DEVICE_INLINE void copyHelper(void* dst, void* src, uint64_t bytes, uint32_t threadId, uint32_t numThreads) {
  int* dstInt = reinterpret_cast<int*>(dst);
  int* srcInt = reinterpret_cast<int*>(src);
  const uintptr_t dstPtr = reinterpret_cast<uintptr_t>(dst);
  const uintptr_t srcPtr = reinterpret_cast<uintptr_t>(src);
  const uint64_t numInt = bytes / sizeof(int);
  // Round each pointer up to the next sizeof(T) boundary for the wide copy.
  T* dstElem = reinterpret_cast<T*>((dstPtr + sizeof(T) - 1) / sizeof(T) * sizeof(T));
  T* srcElem = reinterpret_cast<T*>((srcPtr + sizeof(T) - 1) / sizeof(T) * sizeof(T));
  // Number of leading ints before dst reaches T alignment (same for src by the
  // equal-misalignment assumption above).
  uint64_t nFirstInt = (reinterpret_cast<uintptr_t>(dstElem) - dstPtr) / sizeof(int);
  if constexpr (CopyRemainder) {
    // Copy the remainder integers at the beginning.
    detail::copy<int>(dstInt, srcInt, nFirstInt, threadId, numThreads);
  }
  // Copy elements.
  constexpr uint64_t nIntPerElem = sizeof(T) / sizeof(int);
  uint64_t nElem = (numInt - nFirstInt) / nIntPerElem;
  detail::copy<T>(dstElem, srcElem, nElem, threadId, numThreads);
  if constexpr (CopyRemainder && nIntPerElem > 1) {
    // Copy the remainder integers at the end.
    uint64_t nLastInt = (numInt - nFirstInt) % nIntPerElem;
    detail::copy<int>(dstInt + nFirstInt + nElem * nIntPerElem, srcInt + nFirstInt + nElem * nIntPerElem, nLastInt,
                      threadId, numThreads);
  }
}
/// Copy aligned data from the source memory to the destination memory.
///
/// This function is a wrapper of detail::copy(). Unlike detail::copy(), this function can copy remainder
/// bytes when @p CopyRemainder is true.
/// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or 16.
/// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
/// Alignment.
/// @param dst The destination address. Should be aligned to @p Alignment in the same way as @p src.
/// @param src The source address. Should be aligned to @p Alignment in the same way as @p dst.
/// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment.
/// @param threadId The index of the current thread among all threads running this function. This is different from
/// the `threadIdx` in CUDA.
/// @param numThreads The total number of threads that run this function.
///
template <int Alignment = 16, bool CopyRemainder = true>
MSCCLPP_DEVICE_INLINE void copy(void* dst, void* src, uint64_t bytes, uint32_t threadId, uint32_t numThreads) {
  // Dispatch to copyHelper with an element type whose size equals Alignment.
  if constexpr (Alignment == 4) {
    copyHelper<int, CopyRemainder>(dst, src, bytes, threadId, numThreads);
  } else if constexpr (Alignment == 8) {
    copyHelper<long long, CopyRemainder>(dst, src, bytes, threadId, numThreads);
  } else if constexpr (Alignment == 16) {
    copyHelper<longlong2, CopyRemainder>(dst, src, bytes, threadId, numThreads);
  } else {
    // Dependent condition keeps this assert from firing unless this branch is
    // actually instantiated with an unsupported Alignment.
    static_assert(Alignment == 4 || Alignment == 8 || Alignment == 16, "Unsupported alignment");
  }
}
/// Read data from the origin and write packets to the target buffer.
///
/// @tparam PacketType The packet type. It should be either @ref LL16Packet or @ref LL8Packet.
/// @param targetPtr The target buffer, which receives the packets.
/// @param originPtr The origin buffer, which provides the payload data.
/// @param originBytes The number of payload bytes to read from the origin buffer. The target buffer receives
/// more than this because each packet stores flags alongside the payload.
/// @param threadId The thread ID. The thread ID should be less than @p numThreads.
/// @param numThreads The number of threads that call this function.
/// @param flag The flag to write into each packet.
///
template <typename PacketType = LL16Packet>
MSCCLPP_DEVICE_INLINE void copyToPackets(void* targetPtr, const void* originPtr, uint64_t originBytes,
                                         uint32_t threadId, uint32_t numThreads, uint32_t flag);
/// LL16Packet specialization: each packet carries 8 payload bytes (two 32-bit
/// words) plus flags.
template <>
MSCCLPP_DEVICE_INLINE void copyToPackets<LL16Packet>(void* targetPtr, const void* originPtr, uint64_t originBytes,
                                                     uint32_t threadId, uint32_t numThreads, uint32_t flag) {
  // Offsets should be aligned to 8 bytes & size should be a multiple of 8 bytes.
  const size_t numPackets = originBytes / sizeof(uint64_t);
  const uint32_t* words = reinterpret_cast<const uint32_t*>(originPtr);
  LL16Packet* packets = reinterpret_cast<LL16Packet*>(targetPtr);
  for (size_t pktIdx = threadId; pktIdx < numPackets; pktIdx += numThreads) {
    packets[pktIdx].write(words[2 * pktIdx], words[2 * pktIdx + 1], flag);
  }
}
/// LL8Packet specialization: each packet carries 4 payload bytes (one 32-bit
/// word) plus a flag.
template <>
MSCCLPP_DEVICE_INLINE void copyToPackets<LL8Packet>(void* targetPtr, const void* originPtr, uint64_t originBytes,
                                                    uint32_t threadId, uint32_t numThreads, uint32_t flag) {
  // Offsets should be aligned to 4 bytes & size should be a multiple of 4 bytes.
  const size_t numPackets = originBytes / sizeof(uint32_t);
  const uint32_t* words = reinterpret_cast<const uint32_t*>(originPtr);
  LL8Packet* packets = reinterpret_cast<LL8Packet*>(targetPtr);
  for (size_t pktIdx = threadId; pktIdx < numPackets; pktIdx += numThreads) {
    packets[pktIdx].write(words[pktIdx], flag);
  }
}
/// Read packets from the target buffer and write retrieved data to the origin.
///
/// @tparam PacketType The packet type. It should be either @ref LL16Packet or @ref LL8Packet.
/// @param originPtr The origin buffer, which receives the retrieved payload data.
/// @param targetPtr The target buffer, which holds the packets to read.
/// @param originBytes The number of payload bytes to write to the origin buffer.
/// @param threadId The thread ID. The thread ID should be less than @p numThreads.
/// @param numThreads The number of threads that call this function.
/// @param flag The flag each packet is expected to carry.
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
///
template <typename PacketType = LL16Packet>
MSCCLPP_DEVICE_INLINE void copyFromPackets(void* originPtr, const void* targetPtr, uint64_t originBytes,
                                           uint32_t threadId, uint32_t numThreads, uint32_t flag,
                                           int64_t maxSpinCount = -1);
/// LL16Packet specialization: retrieves 8 payload bytes (a uint2) per packet.
template <>
MSCCLPP_DEVICE_INLINE void copyFromPackets<LL16Packet>(void* originPtr, const void* targetPtr, uint64_t originBytes,
                                                       uint32_t threadId, uint32_t numThreads, uint32_t flag,
                                                       int64_t maxSpinCount) {
  // Offsets should be aligned to 8 bytes & size should be a multiple of 8 bytes.
  const size_t numPackets = originBytes / sizeof(uint2);
  const LL16Packet* packets = reinterpret_cast<const LL16Packet*>(targetPtr);
  uint2* out = reinterpret_cast<uint2*>(originPtr);
  for (size_t pktIdx = threadId; pktIdx < numPackets; pktIdx += numThreads) {
    // read() spins on the packet flag (up to maxSpinCount) before returning payload.
    out[pktIdx] = packets[pktIdx].read(flag, maxSpinCount);
  }
}
/// LL8Packet specialization: retrieves 4 payload bytes (a uint32_t) per packet.
template <>
MSCCLPP_DEVICE_INLINE void copyFromPackets<LL8Packet>(void* originPtr, const void* targetPtr, uint64_t originBytes,
                                                      uint32_t threadId, uint32_t numThreads, uint32_t flag,
                                                      int64_t maxSpinCount) {
  // Offsets should be aligned to 4 bytes & size should be a multiple of 4 bytes.
  const size_t numPackets = originBytes / sizeof(uint32_t);
  const LL8Packet* packets = reinterpret_cast<const LL8Packet*>(targetPtr);
  uint32_t* out = reinterpret_cast<uint32_t*>(originPtr);
  for (size_t pktIdx = threadId; pktIdx < numPackets; pktIdx += numThreads) {
    // read() spins on the packet flag (up to maxSpinCount) before returning payload.
    out[pktIdx] = packets[pktIdx].read(flag, maxSpinCount);
  }
}
#endif // defined(MSCCLPP_DEVICE_COMPILE)
} // namespace mscclpp
#endif // MSCCLPP_COPY_DEVICE_HPP_

View File

@@ -12,25 +12,52 @@
namespace mscclpp {
/// Channel for accessing peer memory directly from GPU threads.
struct MemoryChannel {
private:
/// Memory channel without specifying source/destination memory regions.
struct BaseMemoryChannel {
protected:
std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore_;
RegisteredMemory dst_;
void* src_;
void* getPacketBuffer_;
public:
/// Default constructor.
BaseMemoryChannel() = default;
/// Constructor.
/// @param semaphore The semaphore used to synchronize the communication.
BaseMemoryChannel(std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore);
BaseMemoryChannel(const BaseMemoryChannel& other) = default;
BaseMemoryChannel& operator=(BaseMemoryChannel& other) = default;
/// Device-side handle for @ref BaseMemoryChannel.
using DeviceHandle = BaseMemoryChannelDeviceHandle;
/// Returns the device-side handle.
///
/// User should make sure the BaseMemoryChannel is not released when using the returned handle.
///
DeviceHandle deviceHandle() const;
};
/// Channel for accessing peer memory directly from GPU threads.
struct MemoryChannel : public BaseMemoryChannel {
private:
RegisteredMemory dst_;
void* src_;
void* packetBuffer_;
public:
/// Default constructor.
MemoryChannel() = default;
/// Constructor.
/// @param semaphore The semaphore used to synchronize the communication.
/// @param dst Registered memory of the destination.
/// @param src The source memory address.
/// @param getPacketBuffer The optional buffer used for @ref getPackets().
/// @param packetBuffer A buffer used to store packets. @p packetBuffer is optional and if it is nullptr,
/// unpackPacket() and unpackPackets() methods are not available.
MemoryChannel(std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore, RegisteredMemory dst, void* src,
void* getPacketBuffer = nullptr);
void* packetBuffer = nullptr);
/// Device-side handle for @ref MemoryChannel.
using DeviceHandle = MemoryChannelDeviceHandle;

View File

@@ -6,238 +6,21 @@
#include "semaphore_device.hpp"
#if defined(MSCCLPP_DEVICE_COMPILE)
#include "packet_device.hpp"
#include "copy_device.hpp"
#endif // defined(MSCCLPP_DEVICE_COMPILE)
namespace mscclpp {
#if defined(MSCCLPP_DEVICE_COMPILE)
namespace Element {
/// Copy aligned elements from the source memory to the destination memory.
///
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of
/// elements.
///
/// @param dst The destination address.
/// @param src The source address.
/// @param numElems The number of elements to be copied.
/// @param threadId The index of the current thread among all threads running this function. This is different
/// from the `threadIdx` in CUDA.
/// @param numThreads The total number of threads that run this function.
///
template <typename T>
MSCCLPP_DEVICE_INLINE void copy(T* dst, T* src, uint64_t numElems, uint32_t threadId, uint32_t numThreads) {
T reg;
for (size_t i = threadId; i < numElems; i += numThreads) {
// Load to register first.
reg = src[i];
// Then store to destination.
dst[i] = reg;
}
}
} // namespace Element
#endif // defined(MSCCLPP_DEVICE_COMPILE)
/// Device-side handle of a MemoryChannel.
struct MemoryChannelDeviceHandle {
struct BaseMemoryChannelDeviceHandle {
MemoryDevice2DeviceSemaphoreDeviceHandle semaphore_;
void* src_;
void* dst_;
void* getPacketBuffer_;
MSCCLPP_HOST_DEVICE_INLINE BaseMemoryChannelDeviceHandle() = default;
MSCCLPP_HOST_DEVICE_INLINE BaseMemoryChannelDeviceHandle(MemoryDevice2DeviceSemaphoreDeviceHandle semaphore)
: semaphore_(semaphore) {}
#if defined(MSCCLPP_DEVICE_COMPILE)
/// Load a value from the remote memory.
/// @tparam T The type of the value to be loaded.
/// @param index The index of the value to be loaded. The offset in bytes is calculated as index * sizeof(T).
/// @return The value loaded.
template <typename T>
MSCCLPP_DEVICE_INLINE T read(uint64_t index) {
return *(reinterpret_cast<T*>(dst_) + index);
}
/// Write a value to the remote memory.
/// @tparam T The type of the value to be written.
/// @param index The index of the value to be written. The offset in bytes is calculated as index * sizeof(T).
/// @param v The value to be written.
template <typename T>
MSCCLPP_DEVICE_INLINE void write(uint64_t index, const T& v) {
*(reinterpret_cast<T*>(dst_) + index) = v;
}
/// Helper for copy(): performs the byte copy as elements of type T, with int-sized remainder handling.
template <typename T, bool CopyRemainder = true>
MSCCLPP_DEVICE_INLINE void copy_helper(void* dst, void* src, uint64_t bytes, uint32_t threadId, uint32_t numThreads) {
int* dstInt = reinterpret_cast<int*>(dst);
int* srcInt = reinterpret_cast<int*>(src);
const uintptr_t dstPtr = reinterpret_cast<uintptr_t>(dst);
const uintptr_t srcPtr = reinterpret_cast<uintptr_t>(src);
const uint64_t numInt = bytes / sizeof(int);
T* dstElem = reinterpret_cast<T*>((dstPtr + sizeof(T) - 1) / sizeof(T) * sizeof(T));
T* srcElem = reinterpret_cast<T*>((srcPtr + sizeof(T) - 1) / sizeof(T) * sizeof(T));
uint64_t nFirstInt = (reinterpret_cast<uintptr_t>(dstElem) - dstPtr) / sizeof(int);
if (CopyRemainder) {
// Copy the remainder integers at the beginning.
Element::copy<int>(dstInt, srcInt, nFirstInt, threadId, numThreads);
}
// Copy elements.
constexpr uint64_t nIntPerElem = sizeof(T) / sizeof(int);
uint64_t nElem = (numInt - nFirstInt) / nIntPerElem;
Element::copy<T>(dstElem, srcElem, nElem, threadId, numThreads);
if (CopyRemainder && nIntPerElem > 1) {
// Copy the remainder integers at the end.
uint64_t nLastInt = (numInt - nFirstInt) % nIntPerElem;
Element::copy<int>(dstInt + nFirstInt + nElem * nIntPerElem, srcInt + nFirstInt + nElem * nIntPerElem, nLastInt,
threadId, numThreads);
}
}
/// Copy aligned data from the source memory to the destination memory.
///
/// This function is a wrapper of Element<T>::copy(). Unlike Element<T>::copy(), this function can copy remainder
/// bytes when @p CopyRemainder is true. @p Alignment must be 4, 8, or 16.
/// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
/// Alignment.
/// @param dst The destination address. Should be aligned to @p Alignment in the same way as @p src.
/// @param src The source address. Should be aligned to @p Alignment in the same way as @p dst.
/// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment.
/// @param threadId The index of the current thread among all threads running this function. This is different from
/// the `threadIdx` in CUDA.
/// @param numThreads The total number of threads that run this function.
///
template <int Alignment = 16, bool CopyRemainder = true>
MSCCLPP_DEVICE_INLINE void copy(void* dst, void* src, uint64_t bytes, uint32_t threadId, uint32_t numThreads) {
if (Alignment == 4) {
copy_helper<int, CopyRemainder>(dst, src, bytes, threadId, numThreads);
} else if (Alignment == 8) {
copy_helper<long long, CopyRemainder>(dst, src, bytes, threadId, numThreads);
} else if (Alignment == 16) {
copy_helper<longlong2, CopyRemainder>(dst, src, bytes, threadId, numThreads);
} else {
static_assert(Alignment == 4 || Alignment == 8 || Alignment == 16, "Unsupported alignment");
}
}
/// Copy data from the local memory (origin) to the remote memory (target).
///
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
///
/// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16.
/// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
/// Alignment.
/// @param targetOffset The offset in bytes of the remote address. Should be a multiple of @p Alignment.
/// @param originOffset The offset in bytes of the local address. Should be a multiple of @p Alignment.
/// @param originBytes Bytes of the origin to be copied. Should be a multiple of @p Alignment.
/// @param threadId The index of the current thread among all threads running this function. This is different from
/// the `threadIdx` in CUDA.
/// @param numThreads The total number of threads that run this function.
///
template <int Alignment = 16, bool CopyRemainder = true>
MSCCLPP_DEVICE_INLINE void put(uint64_t targetOffset, uint64_t originOffset, uint64_t originBytes, uint32_t threadId,
uint32_t numThreads) {
copy<Alignment, CopyRemainder>((char*)dst_ + targetOffset, (char*)src_ + originOffset, originBytes, threadId,
numThreads);
}
/// Copy data from the remote memory (target) to the local memory (origin).
///
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
///
/// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16.
/// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
/// Alignment.
/// @param targetOffset The offset in bytes of the remote address. Should be a multiple of @p Alignment.
/// @param originOffset The offset in bytes of the local address. Should be a multiple of @p Alignment.
/// @param originBytes Bytes of the origin to be copied. Should be a multiple of @p Alignment.
/// @param threadId The index of the current thread among all threads running this function. This is different from
/// the `threadIdx` in CUDA.
/// @param numThreads The total number of threads that run this function.
///
template <int Alignment = 16, bool CopyRemainder = true>
MSCCLPP_DEVICE_INLINE void get(uint64_t targetOffset, uint64_t originOffset, uint64_t originBytes, uint32_t threadId,
uint32_t numThreads) {
// Note that `dst` and `src` are swapped for `get()`.
copy<Alignment, CopyRemainder>((char*)src_ + originOffset, (char*)dst_ + targetOffset, originBytes, threadId,
numThreads);
}
/// Copy data from the local memory (origin) to the remote memory (target).
///
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
///
/// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16.
/// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
/// Alignment.
/// @param offset The offset in bytes of the local and remote addresses. Should be a multiple of @p Alignment.
/// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment.
/// @param threadId The index of the current thread among all threads running this function. This is different from
/// the `threadIdx` in CUDA.
/// @param numThreads The total number of threads that run this function.
///
template <int Alignment = 16, bool CopyRemainder = true>
MSCCLPP_DEVICE_INLINE void put(uint64_t offset, uint64_t bytes, uint32_t threadId, uint32_t numThreads) {
put<Alignment, CopyRemainder>(offset, offset, bytes, threadId, numThreads);
}
/// Copy data from the remote memory (target) to the local memory (origin).
///
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
///
/// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16.
/// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
/// Alignment.
/// @param offset The offset in bytes of the local and remote addresses. Should be a multiple of @p Alignment.
/// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment.
/// @param threadId The index of the current thread among all threads running this function. This is different from
/// the `threadIdx` in CUDA.
/// @param numThreads The total number of threads that run this function.
///
template <int Alignment = 16, bool CopyRemainder = true>
MSCCLPP_DEVICE_INLINE void get(uint64_t offset, uint64_t bytes, uint32_t threadId, uint32_t numThreads) {
get<Alignment, CopyRemainder>(offset, offset, bytes, threadId, numThreads);
}
/// Construct @ref LLPacket from the data in the local memory (origin) and write it on the remote packet buffer
/// (target).
///
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of packets.
///
/// @param targetOffset The offset in bytes of the remote packet buffer.
/// @param originOffset The offset in bytes of the local data.
/// @param originBytes Bytes of the origin to be copied.
/// @param threadId The index of the current thread among all threads running this function. This is different from
/// the `threadIdx` in CUDA.
/// @param numThreads The total number of threads that run this function.
/// @tparam PacketType The packet type. It should be either @ref LL16Packet or @ref LL8Packet.
///
template <typename PacketType = LL16Packet>
MSCCLPP_DEVICE_INLINE void putPackets(uint64_t targetOffset, uint64_t originOffset, uint64_t originBytes,
uint32_t threadId, uint32_t numThreads, uint32_t flag) {
mscclpp::putPackets<PacketType>(dst_, targetOffset, src_, originOffset, originBytes, threadId, numThreads, flag);
}
/// Retrieve data from @ref LLPacket in the local packet buffer (target) and write it on the local data (origin).
///
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
///
/// @param targetOffset The offset in bytes of the local packet buffer.
/// @param originOffset The offset in bytes of the local data.
/// @param originBytes Bytes of the origin to be copied.
/// @param threadId The index of the current thread among all threads running this function. This is different from
/// the `threadIdx` in CUDA.
/// @param numThreads The total number of threads that run this function.
/// @tparam PacketType The packet type. It should be either @ref LL16Packet or @ref LL8Packet.
///
template <typename PacketType = LL16Packet>
MSCCLPP_DEVICE_INLINE void getPackets(uint64_t targetOffset, uint64_t originOffset, uint64_t originBytes,
uint32_t threadId, uint32_t numThreads, uint32_t flag) {
mscclpp::getPackets<PacketType>(getPacketBuffer_, targetOffset, src_, originOffset, originBytes, threadId,
numThreads, flag);
}
/// Signal the remote semaphore.
///
/// This function guarantees that all the memory operation before this function is completed before the remote
@@ -252,14 +35,6 @@ struct MemoryChannelDeviceHandle {
///
MSCCLPP_DEVICE_INLINE void relaxedSignal() { semaphore_.relaxedSignal(); }
/// Signal the remote semaphore for copied packets.
///
/// Unlike @ref signal(), this function provides no guarantee on the completion of memory operations. This is
/// intended to be used with @ref putPackets() and @ref getPackets() that use flags inside packets to indicate the
/// completion of copies.
///
MSCCLPP_DEVICE_INLINE void signalPacket() { semaphore_.signalPacket(); }
/// Increase the counter of the local semaphore.
MSCCLPP_DEVICE_INLINE void semaphoreIncrement() { semaphore_.semaphoreIncrement(); }
@@ -280,7 +55,177 @@ struct MemoryChannelDeviceHandle {
/// User requires to call proper fencing before using this function.
///
/// @param maxSpinCount The maximum number of spins before asserting. Never assert if negative.
MSCCLPP_DEVICE_INLINE void relaxedWait() { semaphore_.relaxedWait(); }
MSCCLPP_DEVICE_INLINE void relaxedWait(int64_t maxSpinCount = 10000000) { semaphore_.relaxedWait(maxSpinCount); }
#endif // defined(MSCCLPP_DEVICE_COMPILE)
};
/// Device-side handle of a MemoryChannel.
struct MemoryChannelDeviceHandle : public BaseMemoryChannelDeviceHandle {
void* dst_;
void* src_;
void* packetBuffer_;
MSCCLPP_HOST_DEVICE_INLINE MemoryChannelDeviceHandle() = default;
MSCCLPP_HOST_DEVICE_INLINE MemoryChannelDeviceHandle(MemoryDevice2DeviceSemaphoreDeviceHandle semaphore, void* dst,
void* src, void* packetBuffer)
: BaseMemoryChannelDeviceHandle(semaphore), dst_(dst), src_(src), packetBuffer_(packetBuffer) {}
#if defined(MSCCLPP_DEVICE_COMPILE)
/// Load a value from the remote memory.
/// @tparam T The type of the value to be loaded.
/// @param index The index of the value to be loaded. The offset in bytes is calculated as index * sizeof(T).
/// @return The value loaded.
template <typename T>
MSCCLPP_DEVICE_INLINE T read(uint64_t index) {
return *(reinterpret_cast<T*>(dst_) + index);
}
/// Write a value to the remote memory.
/// @tparam T The type of the value to be written.
/// @param index The index of the value to be written. The offset in bytes is calculated as index * sizeof(T).
/// @param v The value to be written.
template <typename T>
MSCCLPP_DEVICE_INLINE void write(uint64_t index, const T& v) {
*(reinterpret_cast<T*>(dst_) + index) = v;
}
/// Copy data from the local memory (origin) to the remote memory (target).
///
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
///
/// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16.
/// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
/// Alignment.
/// @param targetOffset The offset in bytes of the remote address. Should be a multiple of @p Alignment.
/// @param originOffset The offset in bytes of the local address. Should be a multiple of @p Alignment.
/// @param originBytes Bytes of the origin to be copied. Should be a multiple of @p Alignment.
/// @param threadId The index of the current thread among all threads running this function. This is different from
/// the `threadIdx` in CUDA.
/// @param numThreads The total number of threads that run this function.
///
template <int Alignment = 16, bool CopyRemainder = true>
MSCCLPP_DEVICE_INLINE void put(uint64_t targetOffset, uint64_t originOffset, uint64_t originBytes, uint32_t threadId,
uint32_t numThreads) {
copy<Alignment, CopyRemainder>(reinterpret_cast<char*>(dst_) + targetOffset,
reinterpret_cast<char*>(src_) + originOffset, originBytes, threadId, numThreads);
}
/// Wrapper of put() with the same offset for target and origin.
template <int Alignment = 16, bool CopyRemainder = true>
MSCCLPP_DEVICE_INLINE void put(uint64_t offset, uint64_t originBytes, uint32_t threadId, uint32_t numThreads) {
put<Alignment, CopyRemainder>(offset, offset, originBytes, threadId, numThreads);
}
/// Copy data from the remote memory (origin) to the local memory (target).
///
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
///
/// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16.
/// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p
/// Alignment.
/// @param targetOffset The offset in bytes of the local address. Should be a multiple of @p Alignment.
/// @param originOffset The offset in bytes of the remote address. Should be a multiple of @p Alignment.
/// @param originBytes Bytes of the origin to be copied. Should be a multiple of @p Alignment.
/// @param threadId The index of the current thread among all threads running this function. This is different from
/// the `threadIdx` in CUDA.
/// @param numThreads The total number of threads that run this function.
///
template <int Alignment = 16, bool CopyRemainder = true>
MSCCLPP_DEVICE_INLINE void get(uint64_t targetOffset, uint64_t originOffset, uint64_t originBytes, uint32_t threadId,
uint32_t numThreads) {
copy<Alignment, CopyRemainder>(reinterpret_cast<char*>(src_) + targetOffset,
reinterpret_cast<char*>(dst_) + originOffset, originBytes, threadId, numThreads);
}
/// Wrapper of get() with the same offset for target and origin.
template <int Alignment = 16, bool CopyRemainder = true>
MSCCLPP_DEVICE_INLINE void get(uint64_t offset, uint64_t originBytes, uint32_t threadId, uint32_t numThreads) {
get<Alignment, CopyRemainder>(offset, offset, originBytes, threadId, numThreads);
}
/// Copy data from the local memory (origin) to the remote memory (target) using packets.
///
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
///
/// @tparam PacketType The packet type. It should be either @ref LL16Packet or @ref LL8Packet.
/// @param targetOffset The offset in bytes of the remote address.
/// @param originOffset The offset in bytes of the local address.
/// @param originBytes Bytes of the origin to be copied.
/// @param threadId The index of the current thread among all threads running this function. This is different from
/// the `threadIdx` in CUDA.
/// @param numThreads The total number of threads that run this function.
/// @param flag The flag to write.
///
template <typename PacketType = LL16Packet>
MSCCLPP_DEVICE_INLINE void putPackets(uint64_t targetOffset, uint64_t originOffset, uint64_t originBytes,
uint32_t threadId, uint32_t numThreads, uint32_t flag) {
static_assert(std::is_same<PacketType, LL16Packet>::value || std::is_same<PacketType, LL8Packet>::value,
"Unsupported packet type");
copyToPackets<PacketType>(reinterpret_cast<char*>(dst_) + targetOffset,
reinterpret_cast<char*>(src_) + originOffset, originBytes, threadId, numThreads, flag);
}
/// Wrapper of putPackets() with the same offset for target and origin.
template <typename PacketType = LL16Packet>
MSCCLPP_DEVICE_INLINE void putPackets(uint64_t offset, uint64_t originBytes, uint32_t threadId, uint32_t numThreads,
uint32_t flag) {
putPackets<PacketType>(offset, offset, originBytes, threadId, numThreads, flag);
}
/// Retrieve data from a packet in the local packet buffer.
///
/// @tparam PacketType The packet type. It should be either @ref LL16Packet or @ref LL8Packet.
/// @param index The index of the packet to be read. The offset in bytes is calculated as index * sizeof(PacketType).
/// @param flag The flag to read.
/// @param maxSpinCount The maximum number of spins before asserting. Never assert if negative.
/// @return The value read from the packet. The type of the value depends on the packet type.
///
template <typename PacketType = LL16Packet>
MSCCLPP_DEVICE_INLINE auto unpackPacket(uint64_t index, uint32_t flag, int64_t maxSpinCount = -1) {
assert_device(packetBuffer_ != nullptr, "Packet buffer is null");
return reinterpret_cast<PacketType*>(packetBuffer_)[index].read(flag, maxSpinCount);
}
/// Retrieve data from packets in the local packet buffer (origin) and write to the local memory (target).
///
/// This function is intended to be collectively called by multiple threads. Each thread copies a part of data.
///
/// @tparam PacketType The packet type. It should be either @ref LL16Packet or @ref LL8Packet.
/// @param targetOffset The offset in bytes of the local address.
/// @param originOffset The offset in bytes of the local packet buffer.
/// @param originBytes Bytes of the origin to be copied.
/// @param threadId The index of the current thread among all threads running this function. This is different from
/// the `threadIdx` in CUDA.
/// @param numThreads The total number of threads that run this function.
/// @param flag The flag to write.
/// @param maxSpinCount The maximum number of spins before asserting. Never assert if negative.
///
template <typename PacketType = LL16Packet>
MSCCLPP_DEVICE_INLINE void unpackPackets(uint64_t targetOffset, uint64_t originOffset, uint64_t originBytes,
uint32_t threadId, uint32_t numThreads, uint32_t flag,
int64_t maxSpinCount = -1) {
static_assert(std::is_same<PacketType, LL16Packet>::value || std::is_same<PacketType, LL8Packet>::value,
"Unsupported packet type");
assert_device(packetBuffer_ != nullptr, "Packet buffer is null");
copyFromPackets<PacketType>(reinterpret_cast<char*>(src_) + targetOffset,
reinterpret_cast<char*>(packetBuffer_) + originOffset, originBytes, threadId,
numThreads, flag, maxSpinCount);
}
/// Wrapper of unpackPackets() with the same offset for target and origin.
template <typename PacketType = LL16Packet>
MSCCLPP_DEVICE_INLINE void unpackPackets(uint64_t offset, uint64_t originBytes, uint32_t threadId,
uint32_t numThreads, uint32_t flag, int64_t maxSpinCount = -1) {
unpackPackets<PacketType>(offset, offset, originBytes, threadId, numThreads, flag, maxSpinCount);
}
template <typename PacketType = LL16Packet>
[[deprecated("Use unpackPackets() instead.")]] MSCCLPP_DEVICE_INLINE void getPackets(
uint64_t targetOffset, uint64_t originOffset, uint64_t originBytes, uint32_t threadId, uint32_t numThreads,
uint32_t flag) {
unpackPackets<PacketType>(targetOffset, originOffset, originBytes, threadId, numThreads, flag, 100000000);
}
#endif // defined(MSCCLPP_DEVICE_COMPILE)
};

View File

@@ -148,156 +148,6 @@ union alignas(8) LL8Packet {
using LLPacket = LL16Packet;
#if defined(MSCCLPP_DEVICE_COMPILE)
/// Pack data from the origin buffer into LL16Packets in the target buffer.
///
/// @param targetPtr Base address of the packet (target) buffer.
/// @param targetOffset Byte offset into the target buffer; should be 8-byte aligned.
/// @param originPtr Base address of the data (origin) buffer.
/// @param originOffset Byte offset into the origin buffer; should be 8-byte aligned.
/// @param originBytes Number of bytes to pack; should be a multiple of 8.
/// @param threadId Index of the calling thread; must be less than @p numThreads.
/// @param numThreads Total number of threads cooperating on this call.
/// @param flag The flag value written into every packet.
///
MSCCLPP_DEVICE_INLINE void putLL16Packets(void* targetPtr, uint64_t targetOffset, const void* originPtr,
                                          uint64_t originOffset, uint64_t originBytes, uint32_t threadId,
                                          uint32_t numThreads, uint32_t flag) {
  // Each LL16Packet carries 8 bytes of payload, written as two 32-bit words.
  const uint32_t* srcWords =
      reinterpret_cast<const uint32_t*>(reinterpret_cast<const char*>(originPtr) + originOffset);
  LL16Packet* packets = reinterpret_cast<LL16Packet*>(reinterpret_cast<char*>(targetPtr) + targetOffset);
  const size_t numPackets = originBytes / sizeof(uint64_t);
  for (size_t idx = threadId; idx < numPackets; idx += numThreads) {
    packets[idx].write(srcWords[2 * idx], srcWords[2 * idx + 1], flag);
  }
}
/// Unpack LL16Packets from the target buffer and write the retrieved payloads to the origin buffer.
///
/// @param targetPtr Base address of the packet (target) buffer.
/// @param targetOffset Byte offset into the target buffer; should be 8-byte aligned.
/// @param originPtr Base address of the data (origin) buffer.
/// @param originOffset Byte offset into the origin buffer; should be 8-byte aligned.
/// @param originBytes Number of bytes to retrieve; should be a multiple of 8.
/// @param threadId Index of the calling thread; must be less than @p numThreads.
/// @param numThreads Total number of threads cooperating on this call.
/// @param flag The flag value each packet must carry before its payload is read.
///
MSCCLPP_DEVICE_INLINE void getLL16Packets(const void* targetPtr, uint64_t targetOffset, void* originPtr,
                                          uint64_t originOffset, uint64_t originBytes, uint32_t threadId,
                                          uint32_t numThreads, uint32_t flag) {
  // Each LL16Packet yields one uint2 (8 bytes) of payload.
  const LL16Packet* packets =
      reinterpret_cast<const LL16Packet*>(reinterpret_cast<const char*>(targetPtr) + targetOffset);
  uint2* dstPairs = reinterpret_cast<uint2*>(reinterpret_cast<char*>(originPtr) + originOffset);
  const size_t numPackets = originBytes / sizeof(uint2);
  for (size_t idx = threadId; idx < numPackets; idx += numThreads) {
    dstPairs[idx] = packets[idx].read(flag);
  }
}
/// Pack data from the origin buffer into LL8Packets in the target buffer.
///
/// @param targetPtr Base address of the packet (target) buffer.
/// @param targetOffset Byte offset into the target buffer; should be 4-byte aligned.
/// @param originPtr Base address of the data (origin) buffer.
/// @param originOffset Byte offset into the origin buffer; should be 4-byte aligned.
/// @param originBytes Number of bytes to pack; should be a multiple of 4.
/// @param threadId Index of the calling thread; must be less than @p numThreads.
/// @param numThreads Total number of threads cooperating on this call.
/// @param flag The flag value written into every packet.
///
MSCCLPP_DEVICE_INLINE void putLL8Packets(void* targetPtr, uint64_t targetOffset, const void* originPtr,
                                         uint64_t originOffset, uint64_t originBytes, uint32_t threadId,
                                         uint32_t numThreads, uint32_t flag) {
  // Each LL8Packet carries 4 bytes of payload (one 32-bit word).
  const uint32_t* srcWords =
      reinterpret_cast<const uint32_t*>(reinterpret_cast<const char*>(originPtr) + originOffset);
  LL8Packet* packets = reinterpret_cast<LL8Packet*>(reinterpret_cast<char*>(targetPtr) + targetOffset);
  const size_t numPackets = originBytes / sizeof(uint32_t);
  for (size_t idx = threadId; idx < numPackets; idx += numThreads) {
    packets[idx].write(srcWords[idx], flag);
  }
}
/// Unpack LL8Packets from the target buffer and write the retrieved payloads to the origin buffer.
///
/// @param targetPtr Base address of the packet (target) buffer.
/// @param targetOffset Byte offset into the target buffer; should be 4-byte aligned.
/// @param originPtr Base address of the data (origin) buffer.
/// @param originOffset Byte offset into the origin buffer; should be 4-byte aligned.
/// @param originBytes Number of bytes to retrieve; should be a multiple of 4.
/// @param threadId Index of the calling thread; must be less than @p numThreads.
/// @param numThreads Total number of threads cooperating on this call.
/// @param flag The flag value each packet must carry before its payload is read.
///
MSCCLPP_DEVICE_INLINE void getLL8Packets(const void* targetPtr, uint64_t targetOffset, void* originPtr,
                                         uint64_t originOffset, uint64_t originBytes, uint32_t threadId,
                                         uint32_t numThreads, uint32_t flag) {
  // Each LL8Packet yields one 32-bit word of payload.
  const LL8Packet* packets =
      reinterpret_cast<const LL8Packet*>(reinterpret_cast<const char*>(targetPtr) + targetOffset);
  uint32_t* dstWords = reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(originPtr) + originOffset);
  const size_t numPackets = originBytes / sizeof(uint32_t);
  for (size_t idx = threadId; idx < numPackets; idx += numThreads) {
    dstWords[idx] = packets[idx].read(flag);
  }
}
/// Pack data from the origin buffer into packets of the given type in the target buffer.
///
/// Dispatches to putLL16Packets() or putLL8Packets() according to @p PacketType.
///
/// @tparam PacketType Either @ref LL16Packet or @ref LL8Packet.
/// @param targetPtr Base address of the packet (target) buffer.
/// @param targetOffset Byte offset into the target buffer.
/// @param originPtr Base address of the data (origin) buffer.
/// @param originOffset Byte offset into the origin buffer.
/// @param originBytes Number of bytes to pack.
/// @param threadId Index of the calling thread; must be less than @p numThreads.
/// @param numThreads Total number of threads cooperating on this call.
/// @param flag The flag value written into every packet.
///
template <typename PacketType = LL16Packet>
MSCCLPP_DEVICE_INLINE void putPackets(void* targetPtr, uint64_t targetOffset, const void* originPtr,
                                      uint64_t originOffset, uint64_t originBytes, uint32_t threadId,
                                      uint32_t numThreads, uint32_t flag) {
  static_assert(std::is_same<PacketType, LL16Packet>::value || std::is_same<PacketType, LL8Packet>::value,
                "Unsupported packet type");
  if constexpr (std::is_same<PacketType, LL16Packet>::value) {
    putLL16Packets(targetPtr, targetOffset, originPtr, originOffset, originBytes, threadId, numThreads, flag);
  } else {
    putLL8Packets(targetPtr, targetOffset, originPtr, originOffset, originBytes, threadId, numThreads, flag);
  }
}
/// Unpack packets of the given type from the target buffer and write the payloads to the origin buffer.
///
/// Dispatches to getLL16Packets() or getLL8Packets() according to @p PacketType.
///
/// @tparam PacketType Either @ref LL16Packet or @ref LL8Packet.
/// @param targetPtr Base address of the packet (target) buffer.
/// @param targetOffset Byte offset into the target buffer.
/// @param originPtr Base address of the data (origin) buffer.
/// @param originOffset Byte offset into the origin buffer.
/// @param originBytes Number of bytes to retrieve into the origin buffer.
/// @param threadId Index of the calling thread; must be less than @p numThreads.
/// @param numThreads Total number of threads cooperating on this call.
/// @param flag The flag value each packet must carry before its payload is read.
///
template <typename PacketType = LL16Packet>
MSCCLPP_DEVICE_INLINE void getPackets(const void* targetPtr, uint64_t targetOffset, void* originPtr,
                                      uint64_t originOffset, uint64_t originBytes, uint32_t threadId,
                                      uint32_t numThreads, uint32_t flag) {
  static_assert(std::is_same<PacketType, LL16Packet>::value || std::is_same<PacketType, LL8Packet>::value,
                "Unsupported packet type");
  if constexpr (std::is_same<PacketType, LL16Packet>::value) {
    getLL16Packets(targetPtr, targetOffset, originPtr, originOffset, originBytes, threadId, numThreads, flag);
  } else {
    getLL8Packets(targetPtr, targetOffset, originPtr, originOffset, originBytes, threadId, numThreads, flag);
  }
}
#endif // defined(MSCCLPP_DEVICE_COMPILE)
}; // namespace mscclpp
} // namespace mscclpp
#endif // MSCCLPP_PACKET_DEVICE_HPP_

View File

@@ -4,24 +4,10 @@
#ifndef MSCCLPP_POLL_DEVICE_HPP_
#define MSCCLPP_POLL_DEVICE_HPP_
#include "device.hpp"
#include "assert_device.hpp"
#if defined(MSCCLPP_DEVICE_COMPILE)
#include <cstdint>
#if !defined(DEBUG_BUILD)
#define __assert_fail(__assertion, __file, __line, __function) ;
#else // defined(DEBUG_BUILD)
#if defined(MSCCLPP_DEVICE_HIP)
extern "C" __device__ void __assert_fail(const char *__assertion, const char *__file, unsigned int __line,
const char *__function);
#else // !defined(MSCCLPP_DEVICE_HIP)
extern "C" __host__ __device__ void __assert_fail(const char *__assertion, const char *__file, unsigned int __line,
const char *__function) __THROW;
#endif // !defined(MSCCLPP_DEVICE_HIP)
#endif // !defined(DEBUG_BUILD)
// If a spin is stuck, print a warning and keep spinning.
#define POLL_MAYBE_JAILBREAK(__cond, __max_spin_cnt) \
do { \

View File

@@ -95,17 +95,6 @@ struct MemoryDevice2DeviceSemaphoreDeviceHandle {
atomicStore(remoteInboundSemaphoreId, semaphoreGetLocal(), memoryOrderRelaxed);
}
/// Signal the remote device for copied packets.
///
/// Unlike @ref signal(), this function provides no guarantee on the completion of memory operations. This is
/// intended to be used with @ref putPackets() and @ref getPackets() that use flags inside packets to indicate the
/// completion of copies.
///
MSCCLPP_DEVICE_INLINE void signalPacket() {
semaphoreIncrement();
*remoteInboundSemaphoreId = semaphoreGetLocal();
}
/// Increase the counter of the local semaphore.
MSCCLPP_DEVICE_INLINE void semaphoreIncrement() { *outboundSemaphoreId += 1; }

View File

@@ -26,9 +26,9 @@ void register_memory_channel(nb::module_& m) {
nb::class_<MemoryChannel::DeviceHandle>(m, "MemoryChannelDeviceHandle")
.def(nb::init<>())
.def_rw("semaphore_", &MemoryChannel::DeviceHandle::semaphore_)
.def_rw("src_", &MemoryChannel::DeviceHandle::src_)
.def_rw("dst_", &MemoryChannel::DeviceHandle::dst_)
.def_rw("getPacketBuffer_", &MemoryChannel::DeviceHandle::getPacketBuffer_)
.def_rw("src_", &MemoryChannel::DeviceHandle::src_)
.def_rw("packetBuffer_", &MemoryChannel::DeviceHandle::packetBuffer_)
.def_prop_ro("raw", [](const MemoryChannel::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});

View File

@@ -16,7 +16,7 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
if (bid < nranks && bid != my_rank) {
if (use_packet) {
channels[bid].putPackets(2 * my_offset, my_offset, size_per_rank, tid, blockDim.x, flag);
channels[bid].getPackets(2 * my_nghr_offset, my_nghr_offset, size_per_rank, tid, blockDim.x, flag);
channels[bid].unpackPackets(2 * my_nghr_offset, my_nghr_offset, size_per_rank, tid, blockDim.x, flag);
} else {
channels[bid].put(my_offset, my_offset, size_per_rank, tid, blockDim.x);
__syncthreads();

View File

@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <mscclpp/packet_device.hpp>
#include <mscclpp/copy_device.hpp>
#include <mscclpp/port_channel_device.hpp>
// Be careful about using channels[my_rank]: it is invalid and exists only to simplify indexing.
@@ -18,14 +18,14 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
__syncthreads();
int flag = 123;
if (use_packet) {
mscclpp::putPackets(scratch, 2 * my_offset, data, my_offset, size_per_rank, tid, nthreads, flag);
mscclpp::copyToPackets((char*)scratch + 2 * my_offset, (char*)data + my_offset, size_per_rank, tid, nthreads, flag);
__syncthreads();
if (tid < nranks && tid != my_rank) {
channels[tid].put(2 * my_offset, 2 * my_offset, 2 * size_per_rank);
}
if (my_nghr != my_rank && my_nghr < nranks)
mscclpp::getPackets(scratch, 2 * my_nghr_offset, data, my_nghr_offset, size_per_rank, tid % nthreads_per_rank,
nthreads_per_rank, flag);
mscclpp::copyFromPackets((char*)data + my_nghr_offset, (char*)scratch + 2 * my_nghr_offset, size_per_rank,
tid % nthreads_per_rank, nthreads_per_rank, flag);
} else {
if (tid < nranks && tid != my_rank) {
channels[tid].putWithSignalAndFlush(my_offset, my_offset, size_per_rank);

View File

@@ -425,7 +425,8 @@ MSCCLPP_DEVICE_INLINE void handleTransformToPacket(void* dst, void* src, size_t
uint32_t srcOffset, size_t size, uint32_t flag) {
const size_t outputScratchBaseOffset = flag & 0x1 ? 0 : dstSize >> 1;
dstOffset = dstOffset * 2 + outputScratchBaseOffset;
mscclpp::putPackets<PacketType>(dst, dstOffset, src, srcOffset, size, threadIdx.x, blockDim.x, flag);
mscclpp::copyToPackets<PacketType>((char*)dst + dstOffset, (char*)src + srcOffset, size, threadIdx.x, blockDim.x,
flag);
}
template <typename T, bool SendToRemote = true>
@@ -477,7 +478,7 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T
MSCCLPP_DEVICE_INLINE void handleCopy(void* dst, void* src, uint32_t dstOffset, uint32_t srcOffset, size_t size) {
char* srcData = (char*)src + srcOffset;
char* dstData = (char*)dst + dstOffset;
Element::copy(dstData, srcData, size, threadIdx.x, blockDim.x);
mscclpp::copy(dstData, srcData, size, threadIdx.x, blockDim.x);
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900

View File

@@ -8,19 +8,23 @@
namespace mscclpp {
// Construct a BaseMemoryChannel that wraps only a device-to-device semaphore; it carries no
// memory addresses (those belong to the derived MemoryChannel).
MSCCLPP_API_CPP BaseMemoryChannel::BaseMemoryChannel(std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore)
    : semaphore_(semaphore) {}
MSCCLPP_API_CPP MemoryChannel::MemoryChannel(std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
RegisteredMemory dst, void* src, void* getPacketBuffer)
: semaphore_(semaphore), dst_(dst), src_(src), getPacketBuffer_(getPacketBuffer) {
RegisteredMemory dst, void* src, void* packetBuffer)
: BaseMemoryChannel(semaphore), dst_(dst), src_(src), packetBuffer_(packetBuffer) {
if (!dst.transports().has(Transport::CudaIpc)) {
throw Error("MemoryChannel: dst must be registered with CudaIpc", ErrorCode::InvalidUsage);
}
}
// Build the device-side handle for this channel. Only the semaphore's device handle is
// captured, since BaseMemoryChannel holds no data-buffer addresses.
MSCCLPP_API_CPP BaseMemoryChannel::DeviceHandle BaseMemoryChannel::deviceHandle() const {
  return BaseMemoryChannel::DeviceHandle(semaphore_->deviceHandle());
}
MSCCLPP_API_CPP MemoryChannel::DeviceHandle MemoryChannel::deviceHandle() const {
return DeviceHandle{.semaphore_ = semaphore_->deviceHandle(),
.src_ = src_,
.dst_ = dst_.data(),
.getPacketBuffer_ = getPacketBuffer_};
return MemoryChannel::DeviceHandle(semaphore_->deviceHandle(), dst_.data(), src_, packetBuffer_);
}
} // namespace mscclpp

View File

@@ -316,8 +316,8 @@ __global__ void kernelMemLL8PacketPingPong(int* buff, int rank, int nElem, int*
// __syncthreads();
memChan.putPackets<mscclpp::LL8Packet>(0, 0, nElem * sizeof(int), threadIdx.x, blockDim.x, flag);
} else {
memChan.getPackets<mscclpp::LL8Packet>(0, 0, nElem * sizeof(int), threadIdx.x, blockDim.x, flag);
// If each thread reads 8 bytes at once, we don't need a barrier after getPackets().
memChan.unpackPackets<mscclpp::LL8Packet>(0, 0, nElem * sizeof(int), threadIdx.x, blockDim.x, flag);
// If each thread reads 8 bytes at once, we don't need a barrier after unpackPackets().
// __syncthreads();
for (int j = threadIdx.x; j < nElem; j += blockDim.x) {
if (sendBuff[j] != getOffset + i + j) {
@@ -353,8 +353,8 @@ __global__ void kernelMemLL16PacketPingPong(int* buff, int rank, int nElem, int*
// __syncthreads();
memChan.putPackets<mscclpp::LL16Packet>(0, 0, nElem * sizeof(int), threadIdx.x, blockDim.x, flag);
} else {
memChan.getPackets<mscclpp::LL16Packet>(0, 0, nElem * sizeof(int), threadIdx.x, blockDim.x, flag);
// If each thread reads 8 bytes at once, we don't need a barrier after getPackets().
memChan.unpackPackets<mscclpp::LL16Packet>(0, 0, nElem * sizeof(int), threadIdx.x, blockDim.x, flag);
// If each thread reads 8 bytes at once, we don't need a barrier after unpackPackets().
// __syncthreads();
for (int j = threadIdx.x; j < nElem / 2; j += blockDim.x) {
if (sendBuff[2 * j] != getOffset + i + 2 * j) {

View File

@@ -294,14 +294,14 @@ __global__ void kernelProxyLLPingPong(int* buff, mscclpp::LLPacket* putPktBuf, m
// rank=1: 1, 0, 1, 0, ...
if ((rank ^ (i & 1)) == 0) {
if (CheckCorrectness) {
// If each thread writes 8 bytes at once, we don't need a barrier before putPackets().
// If each thread writes 8 bytes at once, we don't need a barrier before copyToPackets().
for (int j = threadId; j < nPkt; j += numThreads) {
buffPtr[2 * j] = putOffset + i + 2 * j;
buffPtr[2 * j + 1] = putOffset + i + 2 * j + 1;
}
// __syncthreads();
}
mscclpp::putPackets(putPktBuf, 0, buff, 0, nElem * sizeof(int), threadId, numThreads, flag);
mscclpp::copyToPackets(putPktBuf, buff, nElem * sizeof(int), threadId, numThreads, flag);
gChannelOneToOneTestPortChansSyncer.sync(gridDim.x);
if (threadId == 0) {
// Send data from the local putPacketBuffer to the remote getPacketBuffer
@@ -313,9 +313,9 @@ __global__ void kernelProxyLLPingPong(int* buff, mscclpp::LLPacket* putPktBuf, m
flusher = 0;
}
} else {
mscclpp::getPackets(getPktBuf, 0, buff, 0, nElem * sizeof(int), threadId, numThreads, flag);
mscclpp::copyFromPackets(buff, getPktBuf, nElem * sizeof(int), threadId, numThreads, flag);
if (CheckCorrectness) {
// If each thread reads 8 bytes at once, we don't need a barrier after getPackets().
// If each thread reads 8 bytes at once, we don't need a barrier after copyFromPackets().
// __syncthreads();
for (int j = threadId; j < nPkt; j += numThreads) {
if (buffPtr[2 * j] != getOffset + i + 2 * j) {

View File

@@ -451,7 +451,7 @@ __global__ void __launch_bounds__(1024, 1)
const size_t peerIdx = wid % nPeer;
const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
const size_t offset = bytesPerGPU * remoteRankLocalIndex + (wid / nPeer) * unitBytesPerWarp;
memChans[peerIdx].getPackets(scratchOffset + offset * 2, offset, unitBytesPerWarp, lid, WARP_SIZE, flag);
memChans[peerIdx].unpackPackets(scratchOffset + offset * 2, offset, unitBytesPerWarp, lid, WARP_SIZE, flag);
}
for (size_t i = 1; i < nLoop; ++i) {
@@ -466,7 +466,7 @@ __global__ void __launch_bounds__(1024, 1)
const size_t peerIdx = gWid % nPeer;
const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
const size_t offset = bytesPerGPU * remoteRankLocalIndex + (gWid / nPeer) * unitBytesPerWarp;
memChans[peerIdx].getPackets(scratchOffset + offset * 2, offset, unitBytesPerWarp, lid, WARP_SIZE, flag);
memChans[peerIdx].unpackPackets(scratchOffset + offset * 2, offset, unitBytesPerWarp, lid, WARP_SIZE, flag);
}
if (bytes % unitBytes > 0) {
@@ -491,7 +491,7 @@ __global__ void __launch_bounds__(1024, 1)
? ((bytesPerGPU > offsetWithinRank) ? (bytesPerGPU - offsetWithinRank) : 0)
: unitBytesPerWarp;
if (remainBytes > 0) {
memChans[peerIdx].getPackets(scratchOffset + offset * 2, offset, remainBytes, lid, WARP_SIZE, flag);
memChans[peerIdx].unpackPackets(scratchOffset + offset * 2, offset, remainBytes, lid, WARP_SIZE, flag);
}
}

View File

@@ -859,8 +859,8 @@ __global__ void __launch_bounds__(1024)
int2* src = (int2*)buff;
int2* res = (int2*)result;
// double buffering
size_t scratchOffset = (flag & 1) ? 0 : nPkts * max(numPeersPerNode, 1) * sizeof(mscclpp::LLPacket);
mscclpp::LLPacket* scratchPtr = (mscclpp::LLPacket*)((char*)scratch + scratchOffset);
size_t scratchBaseIndex = (flag & 1) ? 0 : nPkts * max(numPeersPerNode, 1);
size_t scratchOffset = scratchBaseIndex * sizeof(mscclpp::LLPacket);
size_t pktBufOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LLPacket);
mscclpp::LLPacket* getPktPtr = (mscclpp::LLPacket*)((char*)getPktBuf + pktBufOffset);
mscclpp::LLPacket* putPktPtr = (mscclpp::LLPacket*)((char*)putPktBuf + pktBufOffset);
@@ -887,18 +887,15 @@ __global__ void __launch_bounds__(1024)
int x = 0;
int y = 0;
for (int peerIdx = 0; peerIdx < numPeersPerNode / 2; ++peerIdx) {
mscclpp::LLPacket* pkt0 = scratchPtr + 2 * peerIdx * nPkts;
mscclpp::LLPacket* pkt1 = scratchPtr + (2 * peerIdx + 1) * nPkts;
uint2 data0 = pkt0[idx].read(flag);
uint2 data1 = pkt1[idx].read(flag);
uint2 data0 = memChan.unpackPacket(scratchBaseIndex + 2 * peerIdx * nPkts + idx, flag);
uint2 data1 = memChan.unpackPacket(scratchBaseIndex + (2 * peerIdx + 1) * nPkts + idx, flag);
x += (int)data0.x;
y += (int)data0.y;
x += (int)data1.x;
y += (int)data1.y;
}
if (numPeersPerNode & 1) {
mscclpp::LLPacket* pkt = scratchPtr + (numPeersPerNode - 1) * nPkts;
uint2 data = pkt[idx].read(flag);
uint2 data = memChan.unpackPacket(scratchBaseIndex + (numPeersPerNode - 1) * nPkts + idx, flag);
x += (int)data.x;
y += (int)data.y;
}
@@ -988,11 +985,10 @@ __global__ void __launch_bounds__(1024)
const int remoteRank = peerIdx < rank ? peerIdx : peerIdx + 1;
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
// double buffering
size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LLPacket);
void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset);
size_t scratchBaseIndex = (flag & 1) ? 0 : nPkts;
size_t scratchBaseOffset = scratchBaseIndex * sizeof(mscclpp::LLPacket);
size_t scratchOffset = scratchBaseOffset + rank * nPktsPerRank * sizeof(mscclpp::LLPacket);
size_t scratchResultOffset =
(flag & 1) ? 2 * nPkts * sizeof(mscclpp::LLPacket) : 3 * nPkts * sizeof(mscclpp::LLPacket);
size_t scratchResultIndex = (flag & 1) ? 2 * nPkts : 3 * nPkts;
size_t srcOffset = remoteRank * nelemsPerRank * sizeof(int);
uint2* src = (uint2*)((char*)buff + rank * nelemsPerRank * sizeof(int));
uint2* dst = (uint2*)((char*)resultBuff + rank * nelemsPerRank * sizeof(int));
@@ -1005,8 +1001,8 @@ __global__ void __launch_bounds__(1024)
uint2 data = make_uint2(0, 0);
for (int index = 0; index < nPeers; index++) {
const int remoteRank = index < rank ? index : index + 1;
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerRank;
uint2 val = dstPkt[idx].read(flag);
uint2 val =
constMemOutOfPlaceChans[peerIdx].unpackPacket(scratchBaseIndex + remoteRank * nPktsPerRank + idx, flag);
data.x += val.x;
data.y += val.y;
}
@@ -1019,17 +1015,16 @@ __global__ void __launch_bounds__(1024)
packet.flag1 = flag;
packet.data2 = data.y;
packet.flag2 = flag;
size_t offset = scratchResultOffset / sizeof(mscclpp::LLPacket) + (idx + rank * nPktsPerRank);
size_t offset = scratchResultIndex + (idx + rank * nPktsPerRank);
for (int index = 0; index < nPeers; index++) {
constMemOutOfPlaceChans[index].write(offset, packet);
}
}
// step 3: get data result from scratch buffer
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)((char*)scratch + scratchResultOffset);
const int dstOffset = remoteRank * nPktsPerRank;
uint2* result = (uint2*)((char*)resultBuff + remoteRank * nelemsPerRank * sizeof(int));
for (int idx = threadIdx.x + localBlockIdx * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * nBlocksPerPeer) {
uint2 data = dstPkt[idx + dstOffset].read(flag);
uint2 data = constMemOutOfPlaceChans[peerIdx].unpackPacket(scratchResultIndex + dstOffset + idx, flag);
result[idx].x = data.x;
result[idx].y = data.y;
}
@@ -1054,11 +1049,9 @@ __global__ void __launch_bounds__(1024)
const int remoteRank = peerIdx < rank ? peerIdx : peerIdx + 1;
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
// double buffering
size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LL8Packet);
void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset);
size_t scratchOffset = scratchBaseOffset + rank * nPktsPerRank * sizeof(mscclpp::LL8Packet);
size_t scratchResultOffset =
(flag & 1) ? 2 * nPkts * sizeof(mscclpp::LL8Packet) : 3 * nPkts * sizeof(mscclpp::LL8Packet);
size_t scratchBaseIndex = (flag & 1) ? 0 : nPkts;
size_t scratchOffset = (scratchBaseIndex + rank * nPktsPerRank) * sizeof(mscclpp::LL8Packet);
size_t scratchResultIndex = (flag & 1) ? 2 * nPkts : 3 * nPkts;
size_t srcOffset = remoteRank * nelemsPerRank * sizeof(int);
uint32_t* src = (uint32_t*)((char*)buff + rank * nelemsPerRank * sizeof(int));
uint32_t* dst = (uint32_t*)((char*)resultBuff + rank * nelemsPerRank * sizeof(int));
@@ -1071,8 +1064,8 @@ __global__ void __launch_bounds__(1024)
uint32_t data = 0;
for (int index = 0; index < nPeers; index++) {
const int remoteRank = index < rank ? index : index + 1;
mscclpp::LL8Packet* dstPkt = (mscclpp::LL8Packet*)scratchBuff + remoteRank * nPktsPerRank;
uint32_t val = dstPkt[idx].read(flag);
uint32_t val = constMemOutOfPlaceChans[peerIdx].unpackPacket<mscclpp::LL8Packet>(
scratchBaseIndex + remoteRank * nPktsPerRank + idx, flag);
data += val;
}
data += src[idx];
@@ -1081,17 +1074,17 @@ __global__ void __launch_bounds__(1024)
mscclpp::LL8Packet packet;
packet.data = data;
packet.flag = flag;
size_t offset = scratchResultOffset / sizeof(mscclpp::LL8Packet) + (idx + rank * nPktsPerRank);
size_t offset = scratchResultIndex + (idx + rank * nPktsPerRank);
for (int index = 0; index < nPeers; index++) {
constMemOutOfPlaceChans[index].write(offset, packet);
}
}
// step 3: get data result from scratch buffer
mscclpp::LL8Packet* dstPkt = (mscclpp::LL8Packet*)((char*)scratch + scratchResultOffset);
const int dstOffset = remoteRank * nPktsPerRank;
uint32_t* result = (uint32_t*)((char*)resultBuff + remoteRank * nelemsPerRank * sizeof(int));
for (int idx = threadIdx.x + localBlockIdx * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * nBlocksPerPeer) {
uint32_t data = dstPkt[idx + dstOffset].read(flag);
uint32_t data =
constMemOutOfPlaceChans[peerIdx].unpackPacket<mscclpp::LL8Packet>(scratchResultIndex + dstOffset + idx, flag);
result[idx] = data;
}
if (threadIdx.x == 0 && blockIdx.x == 0) {