mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
Merge ChannelTrigger with ProxyTrigger (#601)
This commit is contained in:
@@ -15,11 +15,82 @@
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
/// Bit flags selecting what a trigger asks the proxy to do. Flags may be OR-ed together
/// (e.g. `TriggerData | TriggerFlag`).
using TriggerType = uint64_t;
constexpr TriggerType TriggerData = 0x1;  // Trigger a data transfer.
constexpr TriggerType TriggerFlag = 0x2;  // Trigger a signaling.
constexpr TriggerType TriggerSync = 0x4;  // Trigger a flush.

// Widths (in bits) of the fields packed into the two 64-bit words of a trigger.
constexpr unsigned int TriggerBitsSize = 32;         // transfer size
constexpr unsigned int TriggerBitsOffset = 32;       // source / destination offset
constexpr unsigned int TriggerBitsMemoryId = 9;      // source / destination memory ID
constexpr unsigned int TriggerBitsType = 3;          // TriggerType flags
constexpr unsigned int TriggerBitsSemaphoreId = 10;  // semaphore ID
constexpr unsigned int TriggerBitsFifoReserved = 1;  // bit reserved for the FIFO

// The bit-field layout pads each word with an unnamed field of width `64 - (sum of widths)`. That
// subtraction is unsigned, so it would wrap around instead of failing if a width were ever enlarged
// past the budget. Enforce the budget explicitly.
static_assert(TriggerBitsSize + TriggerBitsOffset <= 64u, "fields of the first trigger word exceed 64 bits");
static_assert(TriggerBitsOffset + 2u * TriggerBitsMemoryId + TriggerBitsType + TriggerBitsSemaphoreId +
                      TriggerBitsFifoReserved <=
                  64u,
              "fields of the second trigger word exceed 64 bits");
static_assert((TriggerData | TriggerFlag | TriggerSync) < (TriggerType{1} << TriggerBitsType),
              "trigger type flags do not fit in TriggerBitsType bits");
|
||||
|
||||
/// Pair of 64-bit unsigned integers used as a trigger for the proxy.
|
||||
/// Used as a work element in the concurrent FIFO.
|
||||
/// Most significant bit of snd is reserved.
|
||||
struct alignas(16) ProxyTrigger {
|
||||
uint64_t fst, snd;
|
||||
union alignas(16) ProxyTrigger {
|
||||
struct {
|
||||
uint64_t fst;
|
||||
uint64_t snd;
|
||||
};
|
||||
// The summation of number of bits must be 128 or less.
|
||||
struct {
|
||||
// First 64 bits: value[0]
|
||||
uint64_t size : TriggerBitsSize;
|
||||
uint64_t srcOffset : TriggerBitsOffset;
|
||||
uint64_t : (64 - TriggerBitsSize - TriggerBitsOffset); // ensure 64-bit alignment
|
||||
// Second 64 bits: value[1]
|
||||
uint64_t dstOffset : TriggerBitsOffset;
|
||||
uint64_t srcMemoryId : TriggerBitsMemoryId;
|
||||
uint64_t dstMemoryId : TriggerBitsMemoryId;
|
||||
uint64_t type : TriggerBitsType;
|
||||
uint64_t semaphoreId : TriggerBitsSemaphoreId;
|
||||
uint64_t : (64 - TriggerBitsOffset - TriggerBitsMemoryId - TriggerBitsMemoryId - TriggerBitsType -
|
||||
TriggerBitsSemaphoreId - TriggerBitsFifoReserved); // ensure 64-bit alignment
|
||||
uint64_t reserved : TriggerBitsFifoReserved;
|
||||
} fields;
|
||||
|
||||
#if defined(MSCCLPP_DEVICE_COMPILE)
|
||||
/// Default constructor.
|
||||
MSCCLPP_INLINE ProxyTrigger() = default;
|
||||
|
||||
/// Constructor.
|
||||
/// @param type The type of the trigger.
|
||||
/// @param dstId The destination ID of memory region.
|
||||
/// @param dstOffset The offset into the destination memory region.
|
||||
/// @param srcId The source ID of memory region.
|
||||
/// @param srcOffset The offset into the source memory region.
|
||||
/// @param bytes The bytes of the transfer.
|
||||
/// @param semaphoreId The ID of the semaphore.
|
||||
MSCCLPP_DEVICE_INLINE ProxyTrigger(TriggerType type, uint32_t dstId, uint64_t dstOffset, uint32_t srcId,
|
||||
uint64_t srcOffset, uint64_t bytes, uint32_t semaphoreId) {
|
||||
MSCCLPP_ASSERT_DEVICE(type < (1ULL << TriggerBitsType), "type is too large");
|
||||
MSCCLPP_ASSERT_DEVICE(dstId < (1ULL << TriggerBitsMemoryId), "dstId is too large");
|
||||
MSCCLPP_ASSERT_DEVICE(dstOffset < (1ULL << TriggerBitsOffset), "dstOffset is too large");
|
||||
MSCCLPP_ASSERT_DEVICE(srcId < (1ULL << TriggerBitsMemoryId), "srcId is too large");
|
||||
MSCCLPP_ASSERT_DEVICE(srcOffset < (1ULL << TriggerBitsOffset), "srcOffset is too large");
|
||||
MSCCLPP_ASSERT_DEVICE(bytes != 0, "bytes must not be zero");
|
||||
MSCCLPP_ASSERT_DEVICE(bytes < (1ULL << TriggerBitsSize), "bytes is too large");
|
||||
MSCCLPP_ASSERT_DEVICE(semaphoreId < (1ULL << TriggerBitsSemaphoreId), "semaphoreId is too large");
|
||||
constexpr uint64_t maskSize = (1ULL << TriggerBitsSize) - 1;
|
||||
constexpr uint64_t maskSrcOffset = (1ULL << TriggerBitsOffset) - 1;
|
||||
constexpr uint64_t maskDstOffset = (1ULL << TriggerBitsOffset) - 1;
|
||||
constexpr uint64_t maskSrcMemoryId = (1ULL << TriggerBitsMemoryId) - 1;
|
||||
constexpr uint64_t maskDstMemoryId = (1ULL << TriggerBitsMemoryId) - 1;
|
||||
constexpr uint64_t maskType = (1ULL << TriggerBitsType) - 1;
|
||||
constexpr uint64_t maskSemaphoreId = (1ULL << TriggerBitsSemaphoreId) - 1;
|
||||
fst = (((srcOffset & maskSrcOffset) << TriggerBitsSize) + (bytes & maskSize));
|
||||
snd = (((((((((semaphoreId & maskSemaphoreId) << TriggerBitsType) + ((uint64_t)type & maskType))
|
||||
<< TriggerBitsMemoryId) +
|
||||
(dstId & maskDstMemoryId))
|
||||
<< TriggerBitsMemoryId) +
|
||||
(srcId & maskSrcMemoryId))
|
||||
<< TriggerBitsOffset) +
|
||||
(dstOffset & maskDstOffset));
|
||||
}
|
||||
#endif // defined(MSCCLPP_DEVICE_COMPILE)
|
||||
};
|
||||
|
||||
/// Concurrent FIFO where multiple device threads (the number of threads should not exceed the FIFO size) to push
|
||||
@@ -32,7 +103,7 @@ struct FifoDeviceHandle {
|
||||
/// @param trigger Trigger to push.
|
||||
/// @param maxSpinCount Max spin count before assert. Never assert if negative.
|
||||
/// @return Previous head of the FIFO where the trigger was pushed.
|
||||
MSCCLPP_DEVICE_INLINE uint64_t push(ProxyTrigger trigger, [[maybe_unused]] int64_t maxSpinCount = 1000000) {
|
||||
MSCCLPP_DEVICE_INLINE uint64_t push(ProxyTrigger trigger, int64_t maxSpinCount = 1000000) {
|
||||
uint64_t prevHead = atomicFetchAdd<uint64_t, scopeDevice>(head, 1, memoryOrderRelaxed);
|
||||
|
||||
// Flip the last bit for safe polling; host will revert.
|
||||
|
||||
@@ -17,82 +17,6 @@ using SemaphoreId = uint32_t;
|
||||
/// actual.
|
||||
using MemoryId = uint32_t;
|
||||
|
||||
/// Bit flags selecting what a trigger asks the proxy to do. Flags may be OR-ed together
/// (e.g. `TriggerData | TriggerFlag`).
using TriggerType = uint64_t;
constexpr TriggerType TriggerData = 0x1;  // Trigger a data transfer.
constexpr TriggerType TriggerFlag = 0x2;  // Trigger a signaling.
constexpr TriggerType TriggerSync = 0x4;  // Trigger a flush.

// Widths (in bits) of the fields packed into the two 64-bit words of a trigger.
constexpr unsigned int TriggerBitsSize = 32;         // transfer size
constexpr unsigned int TriggerBitsOffset = 32;       // source / destination offset
constexpr unsigned int TriggerBitsMemoryId = 9;      // source / destination memory ID
constexpr unsigned int TriggerBitsType = 3;          // TriggerType flags
constexpr unsigned int TriggerBitsSemaphoreId = 10;  // semaphore ID
constexpr unsigned int TriggerBitsFifoReserved = 1;  // bit reserved for the FIFO

// The bit-field layout pads each word with an unnamed field of width `64 - (sum of widths)`. That
// subtraction is unsigned, so it would wrap around instead of failing if a width were ever enlarged
// past the budget. Enforce the budget explicitly.
static_assert(TriggerBitsSize + TriggerBitsOffset <= 64u, "fields of the first trigger word exceed 64 bits");
static_assert(TriggerBitsOffset + 2u * TriggerBitsMemoryId + TriggerBitsType + TriggerBitsSemaphoreId +
                      TriggerBitsFifoReserved <=
                  64u,
              "fields of the second trigger word exceed 64 bits");
static_assert((TriggerData | TriggerFlag | TriggerSync) < (TriggerType{1} << TriggerBitsType),
              "trigger type flags do not fit in TriggerBitsType bits");
|
||||
|
||||
/// Basic structure of each work element in the FIFO.
|
||||
union ChannelTrigger {
|
||||
ProxyTrigger value;
|
||||
// The summation of number of bits must be 128 or less.
|
||||
struct {
|
||||
// First 64 bits: value[0]
|
||||
uint64_t size : TriggerBitsSize;
|
||||
uint64_t srcOffset : TriggerBitsOffset;
|
||||
uint64_t : (64 - TriggerBitsSize - TriggerBitsOffset); // ensure 64-bit alignment
|
||||
// Second 64 bits: value[1]
|
||||
uint64_t dstOffset : TriggerBitsOffset;
|
||||
uint64_t srcMemoryId : TriggerBitsMemoryId;
|
||||
uint64_t dstMemoryId : TriggerBitsMemoryId;
|
||||
uint64_t type : TriggerBitsType;
|
||||
uint64_t semaphoreId : TriggerBitsSemaphoreId;
|
||||
uint64_t : (64 - TriggerBitsOffset - TriggerBitsMemoryId - TriggerBitsMemoryId - TriggerBitsType -
|
||||
TriggerBitsSemaphoreId - TriggerBitsFifoReserved); // ensure 64-bit alignment
|
||||
uint64_t reserved : TriggerBitsFifoReserved;
|
||||
} fields;
|
||||
|
||||
#if defined(MSCCLPP_DEVICE_COMPILE)
|
||||
/// Default constructor.
|
||||
MSCCLPP_INLINE ChannelTrigger() = default;
|
||||
|
||||
/// Copy constructor.
|
||||
MSCCLPP_DEVICE_INLINE ChannelTrigger(ProxyTrigger value) : value(value) {}
|
||||
|
||||
/// Constructor.
|
||||
/// @param type The type of the trigger.
|
||||
/// @param dst The destination memory region.
|
||||
/// @param dstOffset The offset into the destination memory region.
|
||||
/// @param src The source memory region.
|
||||
/// @param srcOffset The offset into the source memory region.
|
||||
/// @param bytes The bytes of the transfer.
|
||||
/// @param semaphoreId The ID of the semaphore.
|
||||
MSCCLPP_DEVICE_INLINE ChannelTrigger(TriggerType type, MemoryId dst, uint64_t dstOffset, MemoryId src,
|
||||
uint64_t srcOffset, uint64_t bytes, int semaphoreId) {
|
||||
MSCCLPP_ASSERT_DEVICE(type < (1ULL << TriggerBitsType), "type is too large");
|
||||
MSCCLPP_ASSERT_DEVICE(dst < (1ULL << TriggerBitsMemoryId), "dst is too large");
|
||||
MSCCLPP_ASSERT_DEVICE(dstOffset < (1ULL << TriggerBitsOffset), "dstOffset is too large");
|
||||
MSCCLPP_ASSERT_DEVICE(src < (1ULL << TriggerBitsMemoryId), "src is too large");
|
||||
MSCCLPP_ASSERT_DEVICE(srcOffset < (1ULL << TriggerBitsOffset), "srcOffset is too large");
|
||||
MSCCLPP_ASSERT_DEVICE(bytes != 0, "bytes must not be zero");
|
||||
MSCCLPP_ASSERT_DEVICE(bytes < (1ULL << TriggerBitsSize), "bytes is too large");
|
||||
MSCCLPP_ASSERT_DEVICE(semaphoreId < (1ULL << TriggerBitsSemaphoreId), "semaphoreId is too large");
|
||||
constexpr uint64_t maskSize = (1ULL << TriggerBitsSize) - 1;
|
||||
constexpr uint64_t maskSrcOffset = (1ULL << TriggerBitsOffset) - 1;
|
||||
constexpr uint64_t maskDstOffset = (1ULL << TriggerBitsOffset) - 1;
|
||||
constexpr uint64_t maskSrcMemoryId = (1ULL << TriggerBitsMemoryId) - 1;
|
||||
constexpr uint64_t maskDstMemoryId = (1ULL << TriggerBitsMemoryId) - 1;
|
||||
constexpr uint64_t maskType = (1ULL << TriggerBitsType) - 1;
|
||||
constexpr uint64_t maskSemaphoreId = (1ULL << TriggerBitsSemaphoreId) - 1;
|
||||
value.fst = (((srcOffset & maskSrcOffset) << TriggerBitsSize) + (bytes & maskSize));
|
||||
value.snd = (((((((((semaphoreId & maskSemaphoreId) << TriggerBitsType) + ((uint64_t)type & maskType))
|
||||
<< TriggerBitsMemoryId) +
|
||||
(dst & maskDstMemoryId))
|
||||
<< TriggerBitsMemoryId) +
|
||||
(src & maskSrcMemoryId))
|
||||
<< TriggerBitsOffset) +
|
||||
(dstOffset & maskDstOffset));
|
||||
}
|
||||
#endif // defined(MSCCLPP_DEVICE_COMPILE)
|
||||
};
|
||||
|
||||
struct BasePortChannelDeviceHandle {
|
||||
SemaphoreId semaphoreId_;
|
||||
|
||||
@@ -111,77 +35,77 @@ struct BasePortChannelDeviceHandle {
|
||||
|
||||
#if defined(MSCCLPP_DEVICE_COMPILE)
|
||||
/// Push a TriggerData to the FIFO.
|
||||
/// @param dst The destination memory region.
|
||||
/// @param dstId The ID of destination memory region.
|
||||
/// @param dstOffset The offset into the destination memory region.
|
||||
/// @param src The source memory region.
|
||||
/// @param srcId The ID of source memory region.
|
||||
/// @param srcOffset The offset into the source memory region.
|
||||
/// @param size The size of the transfer.
|
||||
MSCCLPP_DEVICE_INLINE void put(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, uint64_t size) {
|
||||
fifo_.push(ChannelTrigger(TriggerData, dst, dstOffset, src, srcOffset, size, semaphoreId_).value);
|
||||
MSCCLPP_DEVICE_INLINE void put(MemoryId dstId, uint64_t dstOffset, MemoryId srcId, uint64_t srcOffset,
|
||||
uint64_t size) {
|
||||
fifo_.push({TriggerData, dstId, dstOffset, srcId, srcOffset, size, semaphoreId_});
|
||||
}
|
||||
|
||||
/// Push a TriggerData to the FIFO.
|
||||
/// @param dst The destination memory region.
|
||||
/// @param src The source memory region.
|
||||
/// @param dstId The ID of destination memory region.
|
||||
/// @param srcId The ID of source memory region.
|
||||
/// @param offset The common offset into the destination and source memory regions.
|
||||
/// @param size The size of the transfer.
|
||||
MSCCLPP_DEVICE_INLINE void put(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) {
|
||||
put(dst, offset, src, offset, size);
|
||||
MSCCLPP_DEVICE_INLINE void put(MemoryId dstId, MemoryId srcId, uint64_t offset, uint64_t size) {
|
||||
put(dstId, offset, srcId, offset, size);
|
||||
}
|
||||
|
||||
/// Push a TriggerFlag to the FIFO.
|
||||
MSCCLPP_DEVICE_INLINE void signal() { fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, semaphoreId_).value); }
|
||||
MSCCLPP_DEVICE_INLINE void signal() { fifo_.push({TriggerFlag, 0, 0, 0, 0, 1, semaphoreId_}); }
|
||||
|
||||
/// Push a TriggerData and a TriggerFlag at the same time to the FIFO.
|
||||
/// @param dst The destination memory region.
|
||||
/// @param dstId The ID of destination memory region.
|
||||
/// @param dstOffset The offset into the destination memory region.
|
||||
/// @param src The source memory region.
|
||||
/// @param srcId The ID of source memory region.
|
||||
/// @param srcOffset The offset into the source memory region.
|
||||
/// @param size The size of the transfer.
|
||||
MSCCLPP_DEVICE_INLINE void putWithSignal(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
|
||||
MSCCLPP_DEVICE_INLINE void putWithSignal(MemoryId dstId, uint64_t dstOffset, MemoryId srcId, uint64_t srcOffset,
|
||||
uint64_t size) {
|
||||
fifo_.push(ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, size, semaphoreId_).value);
|
||||
fifo_.push({TriggerData | TriggerFlag, dstId, dstOffset, srcId, srcOffset, size, semaphoreId_});
|
||||
}
|
||||
|
||||
/// Push a TriggerData and a TriggerFlag at the same time to the FIFO.
|
||||
/// @param dst The destination memory region.
|
||||
/// @param src The source memory region.
|
||||
/// @param dstId The ID of destination memory region.
|
||||
/// @param srcId The ID of source memory region.
|
||||
/// @param offset The common offset into the destination and source memory regions.
|
||||
/// @param size The size of the transfer.
|
||||
MSCCLPP_DEVICE_INLINE void putWithSignal(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) {
|
||||
putWithSignal(dst, offset, src, offset, size);
|
||||
MSCCLPP_DEVICE_INLINE void putWithSignal(MemoryId dstId, MemoryId srcId, uint64_t offset, uint64_t size) {
|
||||
putWithSignal(dstId, offset, srcId, offset, size);
|
||||
}
|
||||
|
||||
/// Push a TriggerData, a TriggerFlag, and a TriggerSync at the same time to the FIFO.
|
||||
/// @param dst The destination memory region.
|
||||
/// @param dstId The ID of destination memory region.
|
||||
/// @param dstOffset The offset into the destination memory region.
|
||||
/// @param src The source memory region.
|
||||
/// @param srcId The ID of source memory region.
|
||||
/// @param srcOffset The offset into the source memory region.
|
||||
/// @param size The size of the transfer.
|
||||
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
|
||||
MSCCLPP_DEVICE_INLINE void putWithSignalAndFlush(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
|
||||
uint64_t size, int64_t maxSpinCount = 1000000) {
|
||||
uint64_t curFifoHead = fifo_.push(
|
||||
ChannelTrigger(TriggerData | TriggerFlag | TriggerSync, dst, dstOffset, src, srcOffset, size, semaphoreId_)
|
||||
.value);
|
||||
MSCCLPP_DEVICE_INLINE void putWithSignalAndFlush(MemoryId dstId, uint64_t dstOffset, MemoryId srcId,
|
||||
uint64_t srcOffset, uint64_t size, int64_t maxSpinCount = 1000000) {
|
||||
uint64_t curFifoHead =
|
||||
fifo_.push({TriggerData | TriggerFlag | TriggerSync, dstId, dstOffset, srcId, srcOffset, size, semaphoreId_});
|
||||
fifo_.sync(curFifoHead, maxSpinCount);
|
||||
}
|
||||
|
||||
/// Push a TriggerData, a TriggerFlag, and a TriggerSync at the same time to the FIFO.
|
||||
/// @param dst The destination memory region.
|
||||
/// @param src The source memory region.
|
||||
/// @param dstId The ID of destination memory region.
|
||||
/// @param srcId The ID of source memory region.
|
||||
/// @param offset The common offset into the destination and source memory regions.
|
||||
/// @param size The size of the transfer.
|
||||
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
|
||||
MSCCLPP_DEVICE_INLINE void putWithSignalAndFlush(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size,
|
||||
MSCCLPP_DEVICE_INLINE void putWithSignalAndFlush(MemoryId dstId, MemoryId srcId, uint64_t offset, uint64_t size,
|
||||
int64_t maxSpinCount = 1000000) {
|
||||
putWithSignalAndFlush(dst, offset, src, offset, size, maxSpinCount);
|
||||
putWithSignalAndFlush(dstId, offset, srcId, offset, size, maxSpinCount);
|
||||
}
|
||||
|
||||
/// Push a TriggerSync to the FIFO.
|
||||
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
|
||||
MSCCLPP_DEVICE_INLINE void flush(int64_t maxSpinCount = 1000000) {
|
||||
uint64_t curFifoHead = fifo_.push(ChannelTrigger(TriggerSync, 0, 0, 0, 0, 1, semaphoreId_).value);
|
||||
uint64_t curFifoHead = fifo_.push({TriggerSync, 0, 0, 0, 0, 1, semaphoreId_});
|
||||
fifo_.sync(curFifoHead, maxSpinCount);
|
||||
}
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_fifo(nb::module_& m) {
|
||||
nb::class_<ProxyTrigger>(m, "ProxyTrigger").def_rw("fst", &ProxyTrigger::fst).def_rw("snd", &ProxyTrigger::snd);
|
||||
nb::class_<ProxyTrigger>(m, "ProxyTrigger");
|
||||
|
||||
nb::class_<FifoDeviceHandle>(m, "FifoDeviceHandle")
|
||||
.def_rw("triggers", &FifoDeviceHandle::triggers)
|
||||
|
||||
@@ -78,27 +78,25 @@ MSCCLPP_API_CPP void ProxyService::startProxy() { proxy_->start(); }
|
||||
|
||||
MSCCLPP_API_CPP void ProxyService::stopProxy() { proxy_->stop(); }
|
||||
|
||||
ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) {
|
||||
ChannelTrigger* trigger = reinterpret_cast<ChannelTrigger*>(&triggerRaw);
|
||||
std::shared_ptr<Host2DeviceSemaphore> semaphore = semaphores_[trigger->fields.semaphoreId];
|
||||
ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger trigger) {
|
||||
std::shared_ptr<Host2DeviceSemaphore> semaphore = semaphores_[trigger.fields.semaphoreId];
|
||||
|
||||
int maxWriteQueueSize = semaphore->connection()->getMaxWriteQueueSize();
|
||||
auto& numRequests = inflightRequests_[semaphore->connection()];
|
||||
|
||||
if (trigger->fields.type & TriggerData) {
|
||||
RegisteredMemory& dst = memories_[trigger->fields.dstMemoryId];
|
||||
RegisteredMemory& src = memories_[trigger->fields.srcMemoryId];
|
||||
semaphore->connection()->write(dst, trigger->fields.dstOffset, src, trigger->fields.srcOffset,
|
||||
trigger->fields.size);
|
||||
if (trigger.fields.type & TriggerData) {
|
||||
RegisteredMemory& dst = memories_[trigger.fields.dstMemoryId];
|
||||
RegisteredMemory& src = memories_[trigger.fields.srcMemoryId];
|
||||
semaphore->connection()->write(dst, trigger.fields.dstOffset, src, trigger.fields.srcOffset, trigger.fields.size);
|
||||
numRequests++;
|
||||
}
|
||||
|
||||
if (trigger->fields.type & TriggerFlag) {
|
||||
if (trigger.fields.type & TriggerFlag) {
|
||||
semaphore->signal();
|
||||
numRequests++;
|
||||
}
|
||||
|
||||
if (((trigger->fields.type & TriggerSync) && numRequests > 0) ||
|
||||
if (((trigger.fields.type & TriggerSync) && numRequests > 0) ||
|
||||
(maxWriteQueueSize != -1 && numRequests > maxWriteQueueSize)) {
|
||||
semaphore->connection()->flush();
|
||||
numRequests = 0;
|
||||
|
||||
Reference in New Issue
Block a user