mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
Saemal/atomic signal (#96)
* code complelete * fix correctness issue * Fix correctness issuee * fix lint * ass compile * Fix build issue * Fix runtime error * Fix correctness issue * Fix crash issue * minor change * Fix memory leak * Fix review comments * Finish allgather * address comments * load element to register first then store to remote address * Finish allGather * init * Build connections * allreduce_test works * Bug fix * Add CUDA flags * Add packet copy (LL) * Lint * Set tmpPtr from constructors * Lint * Multiple blocks per peer * Beautify * Temporal ring reduce * Ring reduce works correctly * Overlapping * Fix overlapping * Improve vector sum * figuring out how to use atomics * working now * wip * Enhance LL AllReduce * Support multiple blocks per peer * Fix a ring reduce bug * Fix a AllReduce kernel 2 bug * Bug fix * wip * Make it compilable * Lint * Lint * Minor changes * Unit test to reproduce memory consistency bugs * Unit test bug fixes * Fixes * Typo * wip * done with core * wip * wip * compiles * only the atomic is failing * almost working * all tests pass now * clang-12 * More jailbreaks * bug fix for common.cu * adding stdint to concurrency.hpp * Out-of-place for AllReduce kernel 2 * Optimize `sync()` * Fix mp_unit_tests * Init TestEngine with TestArgs * Change common.cu into common.cc * Cleanup common.hpp * Lint * fixes to the mscclpp-tests * fixed common.cc --------- Co-authored-by: Binyang Li <binyli@microsoft.com> Co-authored-by: Saeed Maleki <saemal@microsoft.com>
This commit is contained in:
@@ -146,32 +146,10 @@ struct DeviceChannel {
|
||||
put(dst, offset, src, offset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putDirect(void* dst, void* src, uint64_t dstOffset, uint64_t srcOffset, uint64_t size,
|
||||
uint32_t threadId, uint32_t numThreads) {
|
||||
// assume the memory is aligned to 8 bytes
|
||||
uint64_t* srcAddr = (uint64_t*)((char*)src + srcOffset);
|
||||
uint64_t* dstAddr = (uint64_t*)((char*)dst + dstOffset);
|
||||
uint64_t ele;
|
||||
size_t nElem = size % sizeof(uint64_t) ? (size + sizeof(uint64_t)) / sizeof(uint64_t) : size / sizeof(uint64_t);
|
||||
for (size_t i = threadId; i < nElem; i += numThreads) {
|
||||
// load to register first
|
||||
ele = srcAddr[i];
|
||||
dstAddr[i] = ele;
|
||||
}
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void signalDirect() { epoch_.signalDirect(); }
|
||||
|
||||
__forceinline__ __device__ void signalPacket() { epoch_.signalPacket(); }
|
||||
|
||||
__forceinline__ __device__ void signal() {
|
||||
epochIncrement();
|
||||
fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, channelId_).value);
|
||||
}
|
||||
__forceinline__ __device__ void signal() { fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, channelId_).value); }
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
|
||||
uint64_t size) {
|
||||
epochIncrement();
|
||||
fifo_.push(ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, size, channelId_).value);
|
||||
}
|
||||
|
||||
@@ -181,7 +159,6 @@ struct DeviceChannel {
|
||||
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, uint64_t dstOffset, MemoryId src,
|
||||
uint64_t srcOffset, uint64_t size) {
|
||||
epochIncrement();
|
||||
uint64_t curFifoHead = fifo_.push(
|
||||
ChannelTrigger(TriggerData | TriggerFlag | TriggerSync, dst, dstOffset, src, srcOffset, size, channelId_)
|
||||
.value);
|
||||
@@ -199,47 +176,6 @@ struct DeviceChannel {
|
||||
|
||||
__forceinline__ __device__ void wait() { epoch_.wait(); }
|
||||
|
||||
__forceinline__ __device__ void putPacket(void* dst, void* src, uint64_t dstOffset, uint64_t srcOffset, uint64_t size,
|
||||
uint32_t threadId, uint32_t numThreads, uint32_t flag) {
|
||||
// Offsets should be aligned to 8 bytes & size should be a multiple of 8 bytes
|
||||
uint32_t* srcBase = (uint32_t*)((char*)src + srcOffset);
|
||||
ChannelPacket* dstBase = (ChannelPacket*)((char*)dst + dstOffset);
|
||||
size_t nElem = size / sizeof(uint64_t);
|
||||
for (size_t i = threadId; i < nElem; i += numThreads) {
|
||||
ChannelPacket* pkt = &dstBase[i];
|
||||
pkt->write(srcBase[2 * i], srcBase[2 * i + 1], flag);
|
||||
}
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putPacket(void* dst, void* src, uint64_t dstOffset, uint64_t srcOffset, uint64_t size,
|
||||
uint32_t threadId, uint32_t numThreads) {
|
||||
// Offsets should be aligned to 8 bytes & size should be a multiple of 8 bytes
|
||||
uint32_t* srcBase = (uint32_t*)((char*)src + srcOffset);
|
||||
ChannelPacket* dstBase = (ChannelPacket*)((char*)dst + dstOffset);
|
||||
size_t nElem = size / sizeof(uint64_t);
|
||||
for (size_t i = threadId; i < nElem; i += numThreads) {
|
||||
ChannelPacket* pkt = &dstBase[i];
|
||||
pkt->write(srcBase[2 * i], srcBase[2 * i + 1]);
|
||||
}
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void getPacket(void* dst, void* src, uint64_t dstOffset, uint64_t srcOffset, uint64_t size,
|
||||
uint32_t threadId, uint32_t numThreads, uint32_t flag) {
|
||||
// Offsets should be aligned to 8 bytes & size should be a multiple of 8 bytes
|
||||
ChannelPacket* srcBase = (ChannelPacket*)((char*)src + srcOffset);
|
||||
uint2* dstBase = (uint2*)((char*)dst + dstOffset);
|
||||
size_t nElem = size / sizeof(uint2);
|
||||
for (size_t i = threadId; i < nElem; i += numThreads) {
|
||||
ChannelPacket* pkt = &srcBase[i];
|
||||
dstBase[i] = pkt->read(flag);
|
||||
// for future reuse
|
||||
pkt->clear();
|
||||
}
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void epochIncrement() { epoch_.epochIncrement(); }
|
||||
|
||||
__forceinline__ __device__ uint64_t epochGetLocal() const { return epoch_.epochGetLocal(); }
|
||||
#endif // __CUDACC__
|
||||
|
||||
ChannelId channelId_;
|
||||
@@ -282,35 +218,21 @@ struct SimpleDeviceChannel {
|
||||
|
||||
SimpleDeviceChannel(DeviceChannel devChan, MemoryId dst, MemoryId src) : devChan_(devChan), dst_(dst), src_(src) {}
|
||||
|
||||
SimpleDeviceChannel(DeviceChannel devChan, void* dstPtr, void* srcPtr, void* tmpPtr = nullptr)
|
||||
: devChan_(devChan), dstPtr_(dstPtr), srcPtr_(srcPtr), tmpPtr_(tmpPtr) {}
|
||||
|
||||
SimpleDeviceChannel(DeviceChannel devChan, MemoryId dst, MemoryId src, void* dstPtr, void* srcPtr,
|
||||
void* tmpPtr = nullptr)
|
||||
: devChan_(devChan), dst_(dst), src_(src), dstPtr_(dstPtr), srcPtr_(srcPtr), tmpPtr_(tmpPtr) {}
|
||||
SimpleDeviceChannel(DeviceChannel devChan) : devChan_(devChan) {}
|
||||
|
||||
SimpleDeviceChannel(const SimpleDeviceChannel& other) = default;
|
||||
|
||||
SimpleDeviceChannel& operator=(SimpleDeviceChannel& other) = default;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
|
||||
__forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
|
||||
devChan_.put(dst_, dstOffset, src_, srcOffset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void put(uint64_t offset, uint64_t size) { put(offset, offset, size); }
|
||||
|
||||
__forceinline__ __device__ void putDirect(uint64_t offset, uint64_t size, uint32_t threadId, uint32_t numThreads) {
|
||||
devChan_.putDirect(dstPtr_, srcPtr_, offset, offset, size, threadId, numThreads);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void signal() { devChan_.signal(); }
|
||||
|
||||
__forceinline__ __device__ void signalDirect() { devChan_.signalDirect(); }
|
||||
|
||||
__forceinline__ __device__ void signalPacket() { devChan_.signalPacket(); }
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
|
||||
devChan_.putWithSignal(dst_, dstOffset, src_, srcOffset, size);
|
||||
}
|
||||
@@ -329,37 +251,93 @@ struct SimpleDeviceChannel {
|
||||
|
||||
__forceinline__ __device__ void wait() { devChan_.wait(); }
|
||||
|
||||
__forceinline__ __device__ void putPacket(uint64_t dstOffset, uint64_t srcOffset, uint64_t size, uint32_t threadId,
|
||||
uint32_t numThreads, uint32_t flag) {
|
||||
devChan_.putPacket(dstPtr_, srcPtr_, dstOffset, srcOffset, size, threadId, numThreads, flag);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putPacket(uint64_t dstOffset, uint64_t srcOffset, uint64_t size, uint32_t threadId,
|
||||
uint32_t numThreads) {
|
||||
devChan_.putPacket(dstPtr_, srcPtr_, dstOffset, srcOffset, size, threadId, numThreads);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void getPacket(uint64_t dstOffset, uint64_t srcOffset, uint64_t size, uint32_t threadId,
|
||||
uint32_t numThreads, uint32_t flag) {
|
||||
devChan_.getPacket(srcPtr_, tmpPtr_, dstOffset, srcOffset, size, threadId, numThreads, flag);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void epochIncrement() { devChan_.epochIncrement(); }
|
||||
|
||||
__forceinline__ __device__ uint64_t epochGetLocal() const { return devChan_.epochGetLocal(); }
|
||||
|
||||
#endif // __CUDACC__
|
||||
|
||||
DeviceChannel devChan_;
|
||||
MemoryId dst_;
|
||||
MemoryId src_;
|
||||
};
|
||||
|
||||
// these are used for direct copy
|
||||
void* dstPtr_;
|
||||
void* srcPtr_;
|
||||
// A direct version of DeviceChannel only for CudaIpc
|
||||
struct DirectChannel {
|
||||
public:
|
||||
DirectChannel() = default;
|
||||
DirectChannel(DirectEpoch::DeviceHandle epoch, RegisteredMemory dst, void* src, void* tmp = nullptr)
|
||||
: epoch_(epoch), src_(src), tmp_(tmp) {
|
||||
if (!dst.transports().has(Transport::CudaIpc)) {
|
||||
throw Error("DirectChannel: dst must be registered with CudaIpc", ErrorCode::InvalidUsage);
|
||||
}
|
||||
dst_ = dst.data();
|
||||
};
|
||||
|
||||
// extra local buffer for out-of-place copy
|
||||
void* tmpPtr_;
|
||||
#ifdef __CUDACC__
|
||||
__forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size, uint32_t threadId,
|
||||
uint32_t numThreads) {
|
||||
// assume the memory is aligned to 8 bytes
|
||||
uint64_t* srcAddr = (uint64_t*)((char*)src_ + srcOffset);
|
||||
uint64_t* dstAddr = (uint64_t*)((char*)dst_ + dstOffset);
|
||||
uint64_t ele;
|
||||
size_t nElem = size % sizeof(uint64_t) ? (size + sizeof(uint64_t)) / sizeof(uint64_t) : size / sizeof(uint64_t);
|
||||
for (size_t i = threadId; i < nElem; i += numThreads) {
|
||||
// load to register first
|
||||
ele = srcAddr[i];
|
||||
dstAddr[i] = ele;
|
||||
}
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putPacket(uint64_t dstOffset, uint64_t srcOffset, uint64_t size, uint32_t threadId,
|
||||
uint32_t numThreads, uint32_t flag) {
|
||||
// Offsets should be aligned to 8 bytes & size should be a multiple of 8 bytes
|
||||
uint32_t* srcBase = (uint32_t*)((char*)src_ + srcOffset);
|
||||
ChannelPacket* dstBase = (ChannelPacket*)((char*)dst_ + dstOffset);
|
||||
size_t nElem = size / sizeof(uint64_t);
|
||||
for (size_t i = threadId; i < nElem; i += numThreads) {
|
||||
ChannelPacket* pkt = &dstBase[i];
|
||||
pkt->write(srcBase[2 * i], srcBase[2 * i + 1], flag);
|
||||
}
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putPacket(uint64_t dstOffset, uint64_t srcOffset, uint64_t size, uint32_t threadId,
|
||||
uint32_t numThreads) {
|
||||
// Offsets should be aligned to 8 bytes & size should be a multiple of 8 bytes
|
||||
uint32_t* srcBase = (uint32_t*)((char*)src_ + srcOffset);
|
||||
ChannelPacket* dstBase = (ChannelPacket*)((char*)dst_ + dstOffset);
|
||||
size_t nElem = size / sizeof(uint64_t);
|
||||
for (size_t i = threadId; i < nElem; i += numThreads) {
|
||||
ChannelPacket* pkt = &dstBase[i];
|
||||
pkt->write(srcBase[2 * i], srcBase[2 * i + 1]);
|
||||
}
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void getPacket(uint64_t dstOffset, uint64_t srcOffset, uint64_t size, uint32_t threadId,
|
||||
uint32_t numThreads, uint32_t flag) {
|
||||
// Offsets should be aligned to 8 bytes & size should be a multiple of 8 bytes
|
||||
ChannelPacket* tmpBase = (ChannelPacket*)((char*)tmp_ + srcOffset);
|
||||
uint2* srcBase = (uint2*)((char*)src_ + dstOffset);
|
||||
size_t nElem = size / sizeof(uint2);
|
||||
for (size_t i = threadId; i < nElem; i += numThreads) {
|
||||
ChannelPacket* pkt = &tmpBase[i];
|
||||
srcBase[i] = pkt->read(flag);
|
||||
// for future reuse
|
||||
pkt->clear();
|
||||
}
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void signal() { epoch_.signal(); }
|
||||
|
||||
__forceinline__ __device__ void signalPacket() { epoch_.signalPacket(); }
|
||||
|
||||
__forceinline__ __device__ void epochIncrement() { epoch_.epochIncrement(); }
|
||||
|
||||
__forceinline__ __device__ uint64_t epochGetLocal() const { return epoch_.epochGetLocal(); }
|
||||
|
||||
__forceinline__ __device__ void wait() { epoch_.wait(); }
|
||||
#endif // __CUDACC__
|
||||
private:
|
||||
DirectEpoch::DeviceHandle epoch_;
|
||||
void* src_;
|
||||
void* dst_;
|
||||
void* tmp_;
|
||||
};
|
||||
|
||||
} // namespace channel
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
#ifndef MSCCLPP_CONCURRENCY_HPP_
|
||||
#define MSCCLPP_CONCURRENCY_HPP_
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include <mscclpp/poll.hpp>
|
||||
|
||||
namespace mscclpp {
|
||||
struct DeviceSyncer {
|
||||
public:
|
||||
@@ -21,14 +25,12 @@ struct DeviceSyncer {
|
||||
if (atomicAdd(&count_, 1) == maxOldCnt) {
|
||||
flag_ = 1;
|
||||
}
|
||||
while (!flag_) {
|
||||
}
|
||||
POLL_MAYBE_JAILBREAK(!flag_, 1000000000);
|
||||
} else {
|
||||
if (atomicSub(&count_, 1) == 1) {
|
||||
flag_ = 0;
|
||||
}
|
||||
while (flag_) {
|
||||
}
|
||||
POLL_MAYBE_JAILBREAK(flag_, 1000000000);
|
||||
}
|
||||
isAdd_ = tmpIsAdd;
|
||||
}
|
||||
|
||||
@@ -154,10 +154,8 @@ class Communicator;
|
||||
class Connection;
|
||||
|
||||
class RegisteredMemory {
|
||||
protected:
|
||||
struct Impl;
|
||||
// A shared_ptr is used since RegisteredMemory is functionally immutable, although internally some state is populated
|
||||
// lazily.
|
||||
std::shared_ptr<Impl> pimpl;
|
||||
|
||||
public:
|
||||
RegisteredMemory() = default;
|
||||
@@ -173,7 +171,13 @@ class RegisteredMemory {
|
||||
static RegisteredMemory deserialize(const std::vector<char>& data);
|
||||
|
||||
friend class Connection;
|
||||
friend class IBConnection;
|
||||
friend class Communicator;
|
||||
|
||||
private:
|
||||
// A shared_ptr is used since RegisteredMemory is functionally immutable, although internally some state is populated
|
||||
// lazily.
|
||||
std::shared_ptr<Impl> pimpl;
|
||||
};
|
||||
|
||||
class Connection {
|
||||
@@ -181,6 +185,9 @@ class Connection {
|
||||
virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
|
||||
uint64_t size) = 0;
|
||||
|
||||
// src must be a CPU memory
|
||||
virtual void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) = 0;
|
||||
|
||||
virtual void flush() = 0;
|
||||
|
||||
virtual int remoteRank() = 0;
|
||||
|
||||
@@ -8,87 +8,107 @@
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
struct alignas(16) EpochIds {
|
||||
uint64_t outbound;
|
||||
uint64_t inboundReplica;
|
||||
};
|
||||
|
||||
template <template <typename> typename Deleter>
|
||||
class BaseEpoch {
|
||||
private:
|
||||
std::shared_ptr<Connection> connection_;
|
||||
RegisteredMemory localEpochIdsRegMem_;
|
||||
|
||||
protected:
|
||||
NonblockingFuture<RegisteredMemory> remoteEpochIdsRegMem_;
|
||||
std::unique_ptr<EpochIds, Deleter<EpochIds>> epochIds_;
|
||||
std::unique_ptr<uint64_t, Deleter<uint64_t>> expectedInboundEpochId_;
|
||||
NonblockingFuture<RegisteredMemory> remoteInboundEpochIdsRegMem_;
|
||||
uint64_t outBoundEpochId_; // always on the host
|
||||
std::unique_ptr<uint64_t, Deleter<uint64_t>> inboundEpochId_; // could be device or host
|
||||
std::unique_ptr<uint64_t, Deleter<uint64_t>> expectedInboundEpochId_; // could be device or host
|
||||
|
||||
public:
|
||||
BaseEpoch(std::shared_ptr<Connection> connection, std::unique_ptr<EpochIds, Deleter<EpochIds>> epochIds,
|
||||
BaseEpoch(std::shared_ptr<Connection> connection, std::unique_ptr<uint64_t, Deleter<uint64_t>> inboundEpochId,
|
||||
std::unique_ptr<uint64_t, Deleter<uint64_t>> expectedInboundEpochId)
|
||||
: connection_(connection),
|
||||
epochIds_(std::move(epochIds)),
|
||||
outBoundEpochId_(0),
|
||||
inboundEpochId_(std::move(inboundEpochId)),
|
||||
expectedInboundEpochId_(std::move(expectedInboundEpochId)) {}
|
||||
|
||||
void setup(Communicator& communicator) {
|
||||
localEpochIdsRegMem_ = communicator.registerMemory(epochIds_.get(), sizeof(epochIds_), connection_->transport());
|
||||
communicator.sendMemoryOnSetup(localEpochIdsRegMem_, connection_->remoteRank(), connection_->tag());
|
||||
remoteEpochIdsRegMem_ = communicator.recvMemoryOnSetup(connection_->remoteRank(), connection_->tag());
|
||||
auto localInboundEpochIdsRegMem =
|
||||
communicator.registerMemory(inboundEpochId_.get(), sizeof(uint64_t), connection_->transport());
|
||||
communicator.sendMemoryOnSetup(localInboundEpochIdsRegMem, connection_->remoteRank(), connection_->tag());
|
||||
remoteInboundEpochIdsRegMem_ = communicator.recvMemoryOnSetup(connection_->remoteRank(), connection_->tag());
|
||||
}
|
||||
|
||||
void signal() {
|
||||
connection_->write(remoteEpochIdsRegMem_.get(), offsetof(EpochIds, inboundReplica), localEpochIdsRegMem_,
|
||||
offsetof(EpochIds, outbound), sizeof(epochIds_));
|
||||
connection_->updateAndSync(remoteInboundEpochIdsRegMem_.get(), 0, &outBoundEpochId_, outBoundEpochId_ + 1);
|
||||
}
|
||||
};
|
||||
|
||||
class DeviceEpoch : BaseEpoch<CudaDeleter> {
|
||||
class DeviceEpoch : public BaseEpoch<CudaDeleter> {
|
||||
public:
|
||||
DeviceEpoch(Communicator& communicator, std::shared_ptr<Connection> connection);
|
||||
void signal();
|
||||
|
||||
struct DeviceHandle {
|
||||
#ifdef __CUDACC__
|
||||
__forceinline__ __device__ void wait() {
|
||||
(*expectedInboundEpochId) += 1;
|
||||
POLL_MAYBE_JAILBREAK(*(volatile uint64_t*)&(epochIds->inboundReplica) < (*expectedInboundEpochId), 1000000000);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void epochIncrement() { *(volatile uint64_t*)&(epochIds->outbound) += 1; }
|
||||
|
||||
__forceinline__ __device__ uint64_t epochGetLocal() const { return epochIds->outbound; }
|
||||
|
||||
__forceinline__ __device__ void signalDirect() {
|
||||
// This fence ensures that the writes from a preceding putDirect() are visible on the peer GPU before the
|
||||
// incremented epoch id is visible.
|
||||
__threadfence_system();
|
||||
epochIncrement();
|
||||
*(volatile uint64_t*)&(remoteEpochIds->inboundReplica) = epochIds->outbound;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void signalPacket() {
|
||||
epochIncrement();
|
||||
*(volatile uint64_t*)&(remoteEpochIds->inboundReplica) = epochIds->outbound;
|
||||
POLL_MAYBE_JAILBREAK(*(volatile uint64_t*)(inboundEpochId) < (*expectedInboundEpochId), 1000000);
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
|
||||
EpochIds* epochIds;
|
||||
EpochIds* remoteEpochIds;
|
||||
uint64_t* inboundEpochId;
|
||||
uint64_t* expectedInboundEpochId;
|
||||
};
|
||||
|
||||
DeviceHandle deviceHandle();
|
||||
};
|
||||
|
||||
class HostEpoch : BaseEpoch<std::default_delete> {
|
||||
class HostEpoch : public BaseEpoch<std::default_delete> {
|
||||
public:
|
||||
HostEpoch(Communicator& communicator, std::shared_ptr<Connection> connection);
|
||||
|
||||
void incrementAndSignal();
|
||||
// void incrementAndSignal();
|
||||
void wait();
|
||||
};
|
||||
|
||||
class DirectEpoch {
|
||||
NonblockingFuture<RegisteredMemory> remoteInboundEpochIdsRegMem_;
|
||||
std::unique_ptr<uint64_t, CudaDeleter<uint64_t>> localInboundEpochId_;
|
||||
std::unique_ptr<uint64_t, CudaDeleter<uint64_t>> expectedInboundEpochId_;
|
||||
std::unique_ptr<uint64_t, CudaDeleter<uint64_t>> outboundEpochId_;
|
||||
|
||||
public:
|
||||
DirectEpoch(Communicator& communicator, std::shared_ptr<Connection> connection);
|
||||
DirectEpoch() = default;
|
||||
struct DeviceHandle {
|
||||
#ifdef __CUDACC__
|
||||
__forceinline__ __device__ void wait() {
|
||||
(*expectedInboundEpochId) += 1;
|
||||
POLL_MAYBE_JAILBREAK(*inboundEpochId < (*expectedInboundEpochId), 1000000);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void signal() {
|
||||
// This fence ensures that the writes from a preceding putDirect() are visible on the peer GPU before the
|
||||
// incremented epoch id is visible.
|
||||
__threadfence_system();
|
||||
epochIncrement();
|
||||
*remoteInboundEpochId = *outboundEpochId;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void signalPacket() {
|
||||
epochIncrement();
|
||||
*remoteInboundEpochId = *outboundEpochId;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void epochIncrement() { *outboundEpochId += 1; }
|
||||
|
||||
__forceinline__ __device__ uint64_t epochGetLocal() const { return *outboundEpochId; }
|
||||
#endif // __CUDACC__
|
||||
|
||||
volatile uint64_t* inboundEpochId;
|
||||
uint64_t* outboundEpochId;
|
||||
volatile uint64_t* remoteInboundEpochId;
|
||||
uint64_t* expectedInboundEpochId;
|
||||
};
|
||||
|
||||
DeviceHandle deviceHandle();
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_EPOCH_HPP_
|
||||
|
||||
@@ -54,6 +54,19 @@ void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, Register
|
||||
// npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)size);
|
||||
}
|
||||
|
||||
void CudaIpcConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) {
|
||||
validateTransport(dst, remoteTransport());
|
||||
uint64_t oldValue = *src;
|
||||
*src = newValue;
|
||||
uint64_t* dstPtr = (uint64_t*)dst.data();
|
||||
|
||||
MSCCLPP_CUDATHROW(cudaMemcpyAsync(dstPtr + dstOffset, src, sizeof(uint64_t), cudaMemcpyHostToDevice, stream_));
|
||||
INFO(MSCCLPP_P2P, "CudaIpcConnection atomic write: from %p to %p, %lu -> %lu", src, dstPtr + dstOffset, oldValue,
|
||||
newValue);
|
||||
|
||||
// npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)size);
|
||||
}
|
||||
|
||||
void CudaIpcConnection::flush() {
|
||||
MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream_));
|
||||
// npkitCollectExitEvents(conn, NPKIT_EVENT_DMA_SEND_EXIT);
|
||||
@@ -65,8 +78,17 @@ IBConnection::IBConnection(int remoteRank, int tag, Transport transport, Communi
|
||||
: ConnectionBase(remoteRank, tag),
|
||||
transport_(transport),
|
||||
remoteTransport_(Transport::Unknown),
|
||||
numSignaledSends(0) {
|
||||
numSignaledSends(0),
|
||||
dummyAtomicSource_(std::make_unique<uint64_t>(0)) {
|
||||
qp = commImpl.getIbContext(transport)->createQp();
|
||||
dummyAtomicSourceMem_ = RegisteredMemory(std::make_shared<RegisteredMemory::Impl>(
|
||||
dummyAtomicSource_.get(), sizeof(uint64_t), commImpl.bootstrap_->getRank(), transport, commImpl));
|
||||
validateTransport(dummyAtomicSourceMem_, transport);
|
||||
dstTransportInfo_ = getRegisteredMemoryImpl(dummyAtomicSourceMem_)->getTransportInfo(transport);
|
||||
|
||||
if (!dstTransportInfo_.ibLocal) {
|
||||
throw Error("dummyAtomicSource_ is remote, which is not supported", ErrorCode::InternalError);
|
||||
}
|
||||
}
|
||||
|
||||
Transport IBConnection::transport() { return transport_; }
|
||||
@@ -93,12 +115,31 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem
|
||||
qp->stageSend(srcMr, dstMrInfo, (uint32_t)size, /*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset,
|
||||
/*signaled=*/true);
|
||||
numSignaledSends++;
|
||||
|
||||
qp->postSend();
|
||||
INFO(MSCCLPP_NET, "IBConnection write: from %p to %p, size %lu", (uint8_t*)srcMr->getBuff() + srcOffset,
|
||||
(uint8_t*)dstMrInfo.addr + dstOffset, size);
|
||||
// npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)size);
|
||||
}
|
||||
|
||||
void IBConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) {
|
||||
validateTransport(dst, remoteTransport());
|
||||
auto dstTransportInfo = getRegisteredMemoryImpl(dst)->getTransportInfo(remoteTransport());
|
||||
if (dstTransportInfo.ibLocal) {
|
||||
throw Error("dst is local, which is not supported", ErrorCode::InvalidUsage);
|
||||
}
|
||||
|
||||
auto dstMrInfo = dstTransportInfo.ibMrInfo;
|
||||
// assert that src is on host
|
||||
uint64_t oldValue = *src;
|
||||
*src = newValue;
|
||||
|
||||
qp->stageAtomicAdd(dstTransportInfo_.ibMr, dstMrInfo, /*wrId=*/0, dstOffset, newValue - oldValue);
|
||||
qp->postSend();
|
||||
INFO(MSCCLPP_NET, "IBConnection atomic Write: from %p to %p, %lu -> %lu", src, (uint8_t*)dstMrInfo.addr + dstOffset,
|
||||
oldValue, newValue);
|
||||
}
|
||||
|
||||
void IBConnection::flush() {
|
||||
Timer timer;
|
||||
while (numSignaledSends) {
|
||||
|
||||
41
src/epoch.cc
41
src/epoch.cc
@@ -5,37 +5,52 @@
|
||||
namespace mscclpp {
|
||||
|
||||
MSCCLPP_API_CPP DeviceEpoch::DeviceEpoch(Communicator& communicator, std::shared_ptr<Connection> connection)
|
||||
: BaseEpoch(connection, allocUniqueCuda<EpochIds>(), allocUniqueCuda<uint64_t>()) {
|
||||
: BaseEpoch(connection, allocUniqueCuda<uint64_t>(), allocUniqueCuda<uint64_t>()) {
|
||||
setup(communicator);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void DeviceEpoch::signal() { BaseEpoch::signal(); }
|
||||
|
||||
MSCCLPP_API_CPP DeviceEpoch::DeviceHandle DeviceEpoch::deviceHandle() {
|
||||
DeviceEpoch::DeviceHandle device;
|
||||
device.remoteEpochIds = reinterpret_cast<EpochIds*>(remoteEpochIdsRegMem_.get().data());
|
||||
device.epochIds = epochIds_.get();
|
||||
device.inboundEpochId = inboundEpochId_.get();
|
||||
device.expectedInboundEpochId = expectedInboundEpochId_.get();
|
||||
return device;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP HostEpoch::HostEpoch(Communicator& communicator, std::shared_ptr<Connection> connection)
|
||||
: BaseEpoch(connection, std::make_unique<EpochIds>(), std::make_unique<uint64_t>()) {
|
||||
: BaseEpoch(connection, std::make_unique<uint64_t>(), std::make_unique<uint64_t>()) {
|
||||
if (connection->transport() == Transport::CudaIpc) {
|
||||
throw Error("HostEpoch cannot be used with CudaIpc transport", ErrorCode::InvalidUsage);
|
||||
}
|
||||
setup(communicator);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void HostEpoch::incrementAndSignal() {
|
||||
*(volatile uint64_t*)&(epochIds_->outbound) += 1;
|
||||
signal();
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void HostEpoch::wait() {
|
||||
(*expectedInboundEpochId_) += 1;
|
||||
while (*(volatile uint64_t*)&(epochIds_->inboundReplica) < (*expectedInboundEpochId_))
|
||||
;
|
||||
while (*(volatile uint64_t*)&(inboundEpochId_) < (*expectedInboundEpochId_)) {
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP DirectEpoch::DirectEpoch(Communicator& communicator, std::shared_ptr<Connection> connection)
|
||||
: expectedInboundEpochId_(allocUniqueCuda<uint64_t>()),
|
||||
outboundEpochId_(allocUniqueCuda<uint64_t>()),
|
||||
localInboundEpochId_(allocUniqueCuda<uint64_t>()) {
|
||||
if (connection->transport() != Transport::CudaIpc) {
|
||||
throw Error("DirectEpoch can only be used with CudaIpc transport", ErrorCode::InvalidUsage);
|
||||
}
|
||||
auto localInboundEpochIdsRegMem =
|
||||
communicator.registerMemory(localInboundEpochId_.get(), sizeof(uint64_t), connection->transport());
|
||||
|
||||
communicator.sendMemoryOnSetup(localInboundEpochIdsRegMem, connection->remoteRank(), connection->tag());
|
||||
remoteInboundEpochIdsRegMem_ = communicator.recvMemoryOnSetup(connection->remoteRank(), connection->tag());
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP DirectEpoch::DeviceHandle DirectEpoch::deviceHandle() {
|
||||
DirectEpoch::DeviceHandle device;
|
||||
device.remoteInboundEpochId = reinterpret_cast<uint64_t*>(remoteInboundEpochIdsRegMem_.get().data());
|
||||
device.inboundEpochId = localInboundEpochId_.get();
|
||||
device.expectedInboundEpochId = expectedInboundEpochId_.get();
|
||||
device.outboundEpochId = outboundEpochId_.get();
|
||||
return device;
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
76
src/ib.cc
76
src/ib.cc
@@ -30,9 +30,9 @@ IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) {
|
||||
}
|
||||
uintptr_t addr = reinterpret_cast<uintptr_t>(buff) & -pageSize;
|
||||
std::size_t pages = (size + (reinterpret_cast<uintptr_t>(buff) - addr) + pageSize - 1) / pageSize;
|
||||
this->mr = ibv_reg_mr(
|
||||
pd, reinterpret_cast<void*>(addr), pages * pageSize,
|
||||
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_RELAXED_ORDERING);
|
||||
this->mr = ibv_reg_mr(pd, reinterpret_cast<void*>(addr), pages * pageSize,
|
||||
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ |
|
||||
IBV_ACCESS_RELAXED_ORDERING | IBV_ACCESS_REMOTE_ATOMIC);
|
||||
if (this->mr == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_reg_mr failed (errno " << errno << ")";
|
||||
@@ -110,7 +110,7 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port) {
|
||||
qpAttr.qp_state = IBV_QPS_INIT;
|
||||
qpAttr.pkey_index = 0;
|
||||
qpAttr.port_num = port;
|
||||
qpAttr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ;
|
||||
qpAttr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC;
|
||||
if (ibv_modify_qp(_qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_modify_qp failed (errno " << errno << ")";
|
||||
@@ -181,39 +181,65 @@ void IbQp::rts() {
|
||||
}
|
||||
}
|
||||
|
||||
int IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
|
||||
uint64_t dstOffset, bool signaled) {
|
||||
IbQp::WrInfo IbQp::getNewWrInfo() {
|
||||
if (this->wrn >= MSCCLPP_IB_MAX_SENDS) {
|
||||
return -1;
|
||||
std::stringstream err;
|
||||
err << "too many outstanding work requests. limit is " << MSCCLPP_IB_MAX_SENDS;
|
||||
throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage);
|
||||
}
|
||||
int wrn = this->wrn;
|
||||
|
||||
struct ibv_send_wr* wr_ = &this->wrs[wrn];
|
||||
struct ibv_sge* sge_ = &this->sges[wrn];
|
||||
wr_->wr_id = wrId;
|
||||
ibv_send_wr* wr_ = &this->wrs[wrn];
|
||||
ibv_sge* sge_ = &this->sges[wrn];
|
||||
wr_->sg_list = sge_;
|
||||
wr_->num_sge = 1;
|
||||
wr_->opcode = IBV_WR_RDMA_WRITE;
|
||||
wr_->send_flags = signaled ? IBV_SEND_SIGNALED : 0;
|
||||
wr_->wr.rdma.remote_addr = (uint64_t)(info.addr) + dstOffset;
|
||||
wr_->wr.rdma.rkey = info.rkey;
|
||||
wr_->next = nullptr;
|
||||
sge_->addr = (uint64_t)(mr->getBuff()) + srcOffset;
|
||||
sge_->length = size;
|
||||
sge_->lkey = mr->getLkey();
|
||||
if (wrn > 0) {
|
||||
this->wrs[wrn - 1].next = wr_;
|
||||
}
|
||||
this->wrn++;
|
||||
return this->wrn;
|
||||
return IbQp::WrInfo{wr_, sge_};
|
||||
}
|
||||
|
||||
int IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
|
||||
uint64_t dstOffset, bool signaled, unsigned int immData) {
|
||||
int wrn = this->stageSend(mr, info, size, wrId, srcOffset, dstOffset, signaled);
|
||||
this->wrs[wrn - 1].opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
|
||||
this->wrs[wrn - 1].imm_data = immData;
|
||||
return wrn;
|
||||
void IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
|
||||
uint64_t dstOffset, bool signaled) {
|
||||
auto wrInfo = this->getNewWrInfo();
|
||||
wrInfo.wr->wr_id = wrId;
|
||||
wrInfo.wr->opcode = IBV_WR_RDMA_WRITE;
|
||||
wrInfo.wr->send_flags = signaled ? IBV_SEND_SIGNALED : 0;
|
||||
wrInfo.wr->wr.rdma.remote_addr = (uint64_t)(info.addr) + dstOffset;
|
||||
wrInfo.wr->wr.rdma.rkey = info.rkey;
|
||||
wrInfo.wr->next = nullptr;
|
||||
wrInfo.sge->addr = (uint64_t)(mr->getBuff()) + srcOffset;
|
||||
wrInfo.sge->length = size;
|
||||
wrInfo.sge->lkey = mr->getLkey();
|
||||
}
|
||||
|
||||
void IbQp::stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint64_t dstOffset, uint64_t addVal) {
|
||||
auto wrInfo = this->getNewWrInfo();
|
||||
wrInfo.wr->wr_id = wrId;
|
||||
wrInfo.wr->opcode = IBV_WR_ATOMIC_FETCH_AND_ADD;
|
||||
wrInfo.wr->send_flags = 0; // atomic op cannot be signaled
|
||||
wrInfo.wr->wr.atomic.remote_addr = (uint64_t)(info.addr) + dstOffset;
|
||||
wrInfo.wr->wr.atomic.rkey = info.rkey;
|
||||
wrInfo.wr->wr.atomic.compare_add = addVal;
|
||||
wrInfo.sge->addr = (uint64_t)(mr->getBuff());
|
||||
wrInfo.sge->length = sizeof(uint64_t); // atomic op is always on uint64_t
|
||||
wrInfo.sge->lkey = mr->getLkey();
|
||||
}
|
||||
|
||||
void IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
|
||||
uint64_t dstOffset, bool signaled, unsigned int immData) {
|
||||
auto wrInfo = this->getNewWrInfo();
|
||||
wrInfo.wr->wr_id = wrId;
|
||||
wrInfo.wr->opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
|
||||
wrInfo.wr->send_flags = signaled ? IBV_SEND_SIGNALED : 0;
|
||||
wrInfo.wr->wr.rdma.remote_addr = (uint64_t)(info.addr) + dstOffset;
|
||||
wrInfo.wr->wr.rdma.rkey = info.rkey;
|
||||
wrInfo.wr->next = nullptr;
|
||||
wrInfo.wr->imm_data = immData;
|
||||
wrInfo.sge->addr = (uint64_t)(mr->getBuff()) + srcOffset;
|
||||
wrInfo.sge->length = size;
|
||||
wrInfo.sge->lkey = mr->getLkey();
|
||||
}
|
||||
|
||||
void IbQp::postSend() {
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
#include "communicator.hpp"
|
||||
#include "ib.hpp"
|
||||
#include "registered_memory.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
@@ -40,6 +41,7 @@ class CudaIpcConnection : public ConnectionBase {
|
||||
|
||||
void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
|
||||
uint64_t size) override;
|
||||
void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) override;
|
||||
|
||||
void flush() override;
|
||||
};
|
||||
@@ -49,6 +51,9 @@ class IBConnection : public ConnectionBase {
|
||||
Transport remoteTransport_;
|
||||
IbQp* qp;
|
||||
int numSignaledSends;
|
||||
std::unique_ptr<uint64_t> dummyAtomicSource_; // not used anywhere but IB needs a source
|
||||
RegisteredMemory dummyAtomicSourceMem_;
|
||||
mscclpp::TransportInfo dstTransportInfo_;
|
||||
|
||||
public:
|
||||
IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl);
|
||||
@@ -59,6 +64,7 @@ class IBConnection : public ConnectionBase {
|
||||
|
||||
void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
|
||||
uint64_t size) override;
|
||||
void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) override;
|
||||
|
||||
void flush() override;
|
||||
|
||||
|
||||
@@ -63,10 +63,11 @@ class IbQp {
|
||||
|
||||
void rtr(const IbQpInfo& info);
|
||||
void rts();
|
||||
int stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
|
||||
uint64_t dstOffset, bool signaled);
|
||||
int stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
|
||||
uint64_t dstOffset, bool signaled, unsigned int immData);
|
||||
void stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
|
||||
uint64_t dstOffset, bool signaled);
|
||||
void stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint64_t dstOffset, uint64_t addVal);
|
||||
void stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
|
||||
uint64_t dstOffset, bool signaled, unsigned int immData);
|
||||
void postSend();
|
||||
void postRecv(uint64_t wrId);
|
||||
int pollCq();
|
||||
@@ -75,7 +76,13 @@ class IbQp {
|
||||
const ibv_wc* getWc(int idx) const;
|
||||
|
||||
private:
|
||||
struct WrInfo {
|
||||
ibv_send_wr* wr;
|
||||
ibv_sge* sge;
|
||||
};
|
||||
|
||||
IbQp(ibv_context* ctx, ibv_pd* pd, int port);
|
||||
WrInfo getNewWrInfo();
|
||||
|
||||
IbQpInfo info;
|
||||
|
||||
@@ -112,4 +119,4 @@ class IbCtx {
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_IB_HPP_
|
||||
#endif // MSCCLPP_IB_HPP_
|
||||
@@ -45,7 +45,6 @@ static double getTime(void) {
|
||||
__global__ void kernel(int r, int nranks, mscclpp::DeviceProxyFifo fifo, mscclpp::DeviceEpoch::DeviceHandle* handles,
|
||||
int handleIndex) {
|
||||
int tid = threadIdx.x;
|
||||
if (tid != r) handles[tid].epochIncrement();
|
||||
__syncthreads();
|
||||
// uint64_t tail;
|
||||
if (tid == 0) {
|
||||
|
||||
@@ -80,7 +80,13 @@ class MultiProcessTestEnv : public ::testing::Environment {
|
||||
|
||||
MultiProcessTestEnv* gEnv = nullptr;
|
||||
|
||||
class MultiProcessTest : public ::testing::Test {};
|
||||
class MultiProcessTest : public ::testing::Test {
|
||||
protected:
|
||||
void TearDown() override {
|
||||
// Wait for all ranks to finish the previous test
|
||||
MPI_Barrier(MPI_COMM_WORLD);
|
||||
}
|
||||
};
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
@@ -258,7 +264,7 @@ static mscclpp::Transport ibIdToTransport(int id) {
|
||||
return IBs[id];
|
||||
}
|
||||
|
||||
class IbTest : public MultiProcessTest {
|
||||
class IbTestBase : public MultiProcessTest {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
MSCCLPP_CUDATHROW(cudaGetDeviceCount(&cudaDevNum));
|
||||
@@ -274,49 +280,92 @@ class IbTest : public MultiProcessTest {
|
||||
std::string ibDevName;
|
||||
};
|
||||
|
||||
TEST_F(IbTest, SimpleSendRecv) {
|
||||
class IbPeerToPeerTest : public IbTestBase {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
IbTestBase::SetUp();
|
||||
|
||||
mscclpp::UniqueId id;
|
||||
|
||||
if (gEnv->rank < 2) {
|
||||
// This test needs only two ranks
|
||||
bootstrap = std::make_shared<mscclpp::Bootstrap>(gEnv->rank, 2);
|
||||
if (bootstrap->getRank() == 0) id = bootstrap->createUniqueId();
|
||||
}
|
||||
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
|
||||
if (gEnv->rank >= 2) {
|
||||
// This test needs only two ranks
|
||||
return;
|
||||
}
|
||||
|
||||
bootstrap->initialize(id);
|
||||
|
||||
ibCtx = std::make_shared<mscclpp::IbCtx>(ibDevName);
|
||||
qp = ibCtx->createQp();
|
||||
|
||||
qpInfo[gEnv->rank] = qp->getInfo();
|
||||
bootstrap->allGather(qpInfo.data(), sizeof(mscclpp::IbQpInfo));
|
||||
}
|
||||
|
||||
void registerBufferAndConnect(void* buf, size_t size) {
|
||||
bufSize = size;
|
||||
mr = ibCtx->registerMr(buf, size);
|
||||
mrInfo[gEnv->rank] = mr->getInfo();
|
||||
bootstrap->allGather(mrInfo.data(), sizeof(mscclpp::IbMrInfo));
|
||||
|
||||
for (int i = 0; i < bootstrap->getNranks(); ++i) {
|
||||
if (i == gEnv->rank) continue;
|
||||
qp->rtr(qpInfo[i]);
|
||||
qp->rts();
|
||||
break;
|
||||
}
|
||||
bootstrap->barrier();
|
||||
}
|
||||
|
||||
void stageSend(uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled) {
|
||||
const mscclpp::IbMrInfo& remoteMrInfo = mrInfo[(gEnv->rank == 1) ? 0 : 1];
|
||||
qp->stageSend(mr, remoteMrInfo, size, wrId, srcOffset, dstOffset, signaled);
|
||||
}
|
||||
|
||||
void stageAtomicAdd(uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, uint64_t addVal) {
|
||||
const mscclpp::IbMrInfo& remoteMrInfo = mrInfo[(gEnv->rank == 1) ? 0 : 1];
|
||||
qp->stageAtomicAdd(mr, remoteMrInfo, wrId, dstOffset, addVal);
|
||||
}
|
||||
|
||||
void stageSendWithImm(uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled,
|
||||
unsigned int immData) {
|
||||
const mscclpp::IbMrInfo& remoteMrInfo = mrInfo[(gEnv->rank == 1) ? 0 : 1];
|
||||
qp->stageSendWithImm(mr, remoteMrInfo, size, wrId, srcOffset, dstOffset, signaled, immData);
|
||||
}
|
||||
|
||||
std::shared_ptr<mscclpp::Bootstrap> bootstrap;
|
||||
std::shared_ptr<mscclpp::IbCtx> ibCtx;
|
||||
mscclpp::IbQp* qp;
|
||||
const mscclpp::IbMr* mr;
|
||||
size_t bufSize;
|
||||
|
||||
std::array<mscclpp::IbQpInfo, 2> qpInfo;
|
||||
std::array<mscclpp::IbMrInfo, 2> mrInfo;
|
||||
};
|
||||
|
||||
TEST_F(IbPeerToPeerTest, SimpleSendRecv) {
|
||||
if (gEnv->rank >= 2) {
|
||||
// This test needs only two ranks
|
||||
return;
|
||||
}
|
||||
|
||||
mscclpp::Timer timer(3);
|
||||
mscclpp::Timer timeout(3);
|
||||
|
||||
const int maxIter = 100000;
|
||||
const int nelem = 1;
|
||||
auto data = mscclpp::allocUniqueCuda<int>(nelem);
|
||||
|
||||
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(gEnv->rank, 2);
|
||||
mscclpp::UniqueId id;
|
||||
if (bootstrap->getRank() == 0) id = bootstrap->createUniqueId();
|
||||
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
|
||||
bootstrap->initialize(id);
|
||||
|
||||
mscclpp::IbCtx ctx(ibDevName);
|
||||
mscclpp::IbQp* qp = ctx.createQp();
|
||||
const mscclpp::IbMr* mr = ctx.registerMr(data.get(), sizeof(int) * nelem);
|
||||
|
||||
std::array<mscclpp::IbQpInfo, 2> qpInfo;
|
||||
qpInfo[gEnv->rank] = qp->getInfo();
|
||||
|
||||
std::array<mscclpp::IbMrInfo, 2> mrInfo;
|
||||
mrInfo[gEnv->rank] = mr->getInfo();
|
||||
|
||||
bootstrap->allGather(qpInfo.data(), sizeof(mscclpp::IbQpInfo));
|
||||
bootstrap->allGather(mrInfo.data(), sizeof(mscclpp::IbMrInfo));
|
||||
|
||||
for (int i = 0; i < bootstrap->getNranks(); ++i) {
|
||||
if (i == gEnv->rank) continue;
|
||||
qp->rtr(qpInfo[i]);
|
||||
qp->rts();
|
||||
break;
|
||||
}
|
||||
bootstrap->barrier();
|
||||
registerBufferAndConnect(data.get(), sizeof(int) * nelem);
|
||||
|
||||
if (gEnv->rank == 1) {
|
||||
mscclpp::Timer timer;
|
||||
for (int iter = 0; iter < maxIter; ++iter) {
|
||||
qp->stageSend(mr, mrInfo[0], sizeof(int) * nelem, 0, 0, 0, true);
|
||||
stageSend(sizeof(int) * nelem, 0, 0, 0, true);
|
||||
qp->postSend();
|
||||
bool waiting = true;
|
||||
int spin = 0;
|
||||
@@ -335,11 +384,186 @@ TEST_F(IbTest, SimpleSendRecv) {
|
||||
}
|
||||
}
|
||||
float us = (float)timer.elapsed();
|
||||
std::cout << "IbTest.SimpleSendRecv: " << us / maxIter << " us/iter" << std::endl;
|
||||
std::cout << "IbPeerToPeerTest.SimpleSendRecv: " << us / maxIter << " us/iter" << std::endl;
|
||||
}
|
||||
bootstrap->barrier();
|
||||
}
|
||||
|
||||
__global__ void kernelMemoryConsistency(uint64_t* data, volatile uint64_t* curIter, volatile int* result,
|
||||
uint64_t nelem, uint64_t maxIter) {
|
||||
if (blockIdx.x != 0) return;
|
||||
|
||||
constexpr int FlagWrong = 1;
|
||||
constexpr int FlagAbort = 2;
|
||||
|
||||
volatile uint64_t* ptr = data;
|
||||
for (uint64_t iter = 1; iter < maxIter + 1; ++iter) {
|
||||
int err = 0;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
*curIter = iter;
|
||||
|
||||
// Wait for the first element arrival (expect equal to iter). Expect that the first element is delivered in
|
||||
// a special way that guarantees all other elements are completely delivered.
|
||||
uint64_t spin = 0;
|
||||
while (ptr[0] != iter) {
|
||||
if (spin++ == 1000000) {
|
||||
// Assume the program is stuck. Set the abort flag and escape the loop.
|
||||
*result |= FlagAbort;
|
||||
err = 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Check results (expect equal to iter) in backward that is more likely to see the wrong result.
|
||||
for (size_t i = nelem - 1 + threadIdx.x; i >= blockDim.x; i -= blockDim.x) {
|
||||
if (data[i - blockDim.x] != iter) {
|
||||
#if 1
|
||||
*result |= FlagWrong;
|
||||
err = 1;
|
||||
break;
|
||||
#else
|
||||
// For debugging purposes: try waiting for the correct result.
|
||||
uint64_t spin = 0;
|
||||
while (ptr[i - blockDim.x] != iter) {
|
||||
if (spin++ == 1000000) {
|
||||
*result |= FlagAbort;
|
||||
err = 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (spin >= 1000000) {
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
__threadfence();
|
||||
__syncthreads();
|
||||
|
||||
// Shuffle err
|
||||
for (int i = 16; i > 0; i /= 2) {
|
||||
err += __shfl_xor_sync(0xffffffff, err, i);
|
||||
}
|
||||
|
||||
if (err > 0) {
|
||||
// Exit if any error is detected.
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (threadIdx.x == 0) {
|
||||
*curIter = maxIter + 1;
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(IbPeerToPeerTest, MemoryConsistency) {
|
||||
if (gEnv->rank >= 2) {
|
||||
// This test needs only two ranks
|
||||
return;
|
||||
}
|
||||
|
||||
const uint64_t signalPeriod = 1024;
|
||||
const uint64_t maxIter = 10000;
|
||||
const uint64_t nelem = 65536 + 1;
|
||||
auto data = mscclpp::allocUniqueCuda<uint64_t>(nelem);
|
||||
|
||||
registerBufferAndConnect(data.get(), sizeof(uint64_t) * nelem);
|
||||
|
||||
uint64_t res = 0;
|
||||
uint64_t iter = 0;
|
||||
|
||||
if (gEnv->rank == 0) {
|
||||
// Receiver
|
||||
auto curIter = mscclpp::makeUniqueCudaHost<uint64_t>(0);
|
||||
auto result = mscclpp::makeUniqueCudaHost<int>(0);
|
||||
|
||||
volatile uint64_t* ptrCurIter = (volatile uint64_t*)curIter.get();
|
||||
volatile int* ptrResult = (volatile int*)result.get();
|
||||
|
||||
ASSERT_EQ(*ptrCurIter, 0);
|
||||
ASSERT_EQ(*ptrResult, 0);
|
||||
|
||||
kernelMemoryConsistency<<<1, 1024>>>(data.get(), ptrCurIter, ptrResult, nelem, maxIter);
|
||||
MSCCLPP_CUDATHROW(cudaGetLastError());
|
||||
|
||||
for (iter = 1; iter < maxIter + 1; ++iter) {
|
||||
mscclpp::Timer timeout(5);
|
||||
|
||||
while (*ptrCurIter != iter + 1) {
|
||||
res = *ptrResult;
|
||||
if (res != 0) break;
|
||||
}
|
||||
|
||||
// Send the result to the sender
|
||||
res = *ptrResult;
|
||||
uint64_t tmp[2];
|
||||
tmp[0] = res;
|
||||
bootstrap->allGather(tmp, sizeof(uint64_t));
|
||||
|
||||
if (res != 0) break;
|
||||
}
|
||||
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
} else if (gEnv->rank == 1) {
|
||||
// Sender
|
||||
std::vector<uint64_t> hostBuffer(nelem, 0);
|
||||
|
||||
for (iter = 1; iter < maxIter + 1; ++iter) {
|
||||
mscclpp::Timer timeout(5);
|
||||
|
||||
// Set data
|
||||
for (uint64_t i = 0; i < nelem; i++) {
|
||||
hostBuffer[i] = iter;
|
||||
}
|
||||
mscclpp::memcpyCuda<uint64_t>(data.get(), hostBuffer.data(), nelem, cudaMemcpyHostToDevice);
|
||||
|
||||
// Need to signal from time to time to empty the IB send queue
|
||||
bool signaled = (iter % signalPeriod == 0);
|
||||
|
||||
// Send from the second element to the last
|
||||
stageSend(sizeof(uint64_t) * (nelem - 1), 0, sizeof(uint64_t), sizeof(uint64_t), signaled);
|
||||
qp->postSend();
|
||||
|
||||
#if 0
|
||||
// Send the first element using a normal send. This should occasionally see the wrong result.
|
||||
stageSend(sizeof(uint64_t), 0, 0, 0, false);
|
||||
qp->postSend();
|
||||
#else
|
||||
// For reference: send the first element using AtomicAdd. This should see the correct result.
|
||||
stageAtomicAdd(0, 0, 0, 1);
|
||||
qp->postSend();
|
||||
#endif
|
||||
|
||||
if (signaled) {
|
||||
int wcNum = qp->pollCq();
|
||||
while (wcNum == 0) {
|
||||
wcNum = qp->pollCq();
|
||||
}
|
||||
ASSERT_EQ(wcNum, 1);
|
||||
const ibv_wc* wc = qp->getWc(0);
|
||||
ASSERT_EQ(wc->status, IBV_WC_SUCCESS);
|
||||
}
|
||||
|
||||
// Get the result from the receiver
|
||||
uint64_t tmp[2];
|
||||
bootstrap->allGather(tmp, sizeof(uint64_t));
|
||||
res = tmp[0];
|
||||
|
||||
if (res != 0) break;
|
||||
}
|
||||
}
|
||||
|
||||
if (res & 2) {
|
||||
FAIL() << "The receiver is stuck at iteration " << iter << ".";
|
||||
} else if (res != 0 && res != 1) {
|
||||
FAIL() << "Unknown error is detected at iteration " << iter << ". res =" << res;
|
||||
}
|
||||
|
||||
EXPECT_EQ(res, 0);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Communicator tests
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -537,13 +761,6 @@ TEST_F(CommunicatorTest, BasicWrite) {
|
||||
communicator->bootstrapper()->barrier();
|
||||
}
|
||||
|
||||
__global__ void kernelIncEpochs(mscclpp::DeviceEpoch::DeviceHandle* deviceEpochs, int rank, int worldSize) {
|
||||
int tid = threadIdx.x;
|
||||
if (tid != rank && tid < worldSize) {
|
||||
deviceEpochs[tid].epochIncrement();
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void kernelWaitEpochs(mscclpp::DeviceEpoch::DeviceHandle* deviceEpochs, int rank, int worldSize) {
|
||||
int tid = threadIdx.x;
|
||||
if (tid != rank && tid < worldSize) {
|
||||
@@ -577,9 +794,6 @@ TEST_F(CommunicatorTest, WriteWithDeviceEpochs) {
|
||||
|
||||
writeToRemote(deviceBufferSize / sizeof(int) / gEnv->worldSize);
|
||||
|
||||
kernelIncEpochs<<<1, gEnv->worldSize>>>(deviceEpochHandles.get(), gEnv->rank, gEnv->worldSize);
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
for (int i = 0; i < gEnv->worldSize; i++) {
|
||||
if (i != gEnv->rank) {
|
||||
epochs[i]->signal();
|
||||
@@ -613,7 +827,19 @@ TEST_F(CommunicatorTest, WriteWithHostEpochs) {
|
||||
|
||||
for (int i = 0; i < gEnv->worldSize; i++) {
|
||||
if (i != gEnv->rank && connections[i]->transport() != mscclpp::Transport::CudaIpc) {
|
||||
epochs[i]->incrementAndSignal();
|
||||
epochs[i]->signal();
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < gEnv->worldSize; i++) {
|
||||
if (i != gEnv->rank && connections[i]->transport() != mscclpp::Transport::CudaIpc) {
|
||||
epochs[i]->wait();
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < gEnv->worldSize; i++) {
|
||||
if (i != gEnv->rank && connections[i]->transport() != mscclpp::Transport::CudaIpc) {
|
||||
epochs[i]->signal();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -657,7 +883,7 @@ class ChannelOneToOneTest : public CommunicatorTestBase {
|
||||
}
|
||||
mscclpp::RegisteredMemory sendMemory;
|
||||
mscclpp::RegisteredMemory remoteMemory;
|
||||
void* tmpBuff = nullptr;
|
||||
// void* tmpBuff = nullptr;
|
||||
|
||||
if (isInPlace) {
|
||||
registerMemoryPair(sendBuff, sendBuffBytes, transport, 0, r, sendMemory, remoteMemory);
|
||||
@@ -665,14 +891,14 @@ class ChannelOneToOneTest : public CommunicatorTestBase {
|
||||
sendMemory = communicator->registerMemory(recvBuff, recvBuffBytes, transport);
|
||||
mscclpp::RegisteredMemory recvMemory;
|
||||
registerMemoryPair(recvBuff, recvBuffBytes, transport, 0, r, recvMemory, remoteMemory);
|
||||
tmpBuff = recvMemory.data();
|
||||
// tmpBuff = recvMemory.data();
|
||||
}
|
||||
|
||||
mscclpp::channel::ChannelId cid = channelService->addChannel(connections[r]);
|
||||
communicator->setup();
|
||||
|
||||
devChannels.emplace_back(channelService->deviceChannel(cid), channelService->addMemory(remoteMemory),
|
||||
channelService->addMemory(sendMemory), remoteMemory.data(), sendMemory.data(), tmpBuff);
|
||||
channelService->addMemory(sendMemory));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -681,9 +907,9 @@ class ChannelOneToOneTest : public CommunicatorTestBase {
|
||||
|
||||
__constant__ mscclpp::channel::SimpleDeviceChannel gChannelOneToOneTestConstDevChans;
|
||||
|
||||
__global__ void kernelPingPong(int rank, int nElem) {
|
||||
__global__ void kernelPingPong(int* buff, int rank, int nElem) {
|
||||
mscclpp::channel::SimpleDeviceChannel& devChan = gChannelOneToOneTestConstDevChans;
|
||||
volatile int* sendBuff = (volatile int*)devChan.srcPtr_;
|
||||
volatile int* sendBuff = (volatile int*)buff;
|
||||
int nTries = 1000;
|
||||
int flusher = 0;
|
||||
int rank1Offset = 10000000;
|
||||
@@ -745,16 +971,16 @@ TEST_F(ChannelOneToOneTest, PingPongIb) {
|
||||
|
||||
channelService->startProxy();
|
||||
|
||||
kernelPingPong<<<1, 1024>>>(gEnv->rank, 1);
|
||||
kernelPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1);
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
kernelPingPong<<<1, 1024>>>(gEnv->rank, 1024);
|
||||
kernelPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024);
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
kernelPingPong<<<1, 1024>>>(gEnv->rank, 1024 * 1024);
|
||||
kernelPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024 * 1024);
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
kernelPingPong<<<1, 1024>>>(gEnv->rank, 4 * 1024 * 1024);
|
||||
kernelPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 4 * 1024 * 1024);
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
channelService->stopProxy();
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
function(add_mscclpp_test_executable name sources)
|
||||
add_executable(${name} ${sources} common.cu)
|
||||
add_executable(${name} ${sources} common.cc)
|
||||
target_link_libraries(${name} mscclpp MPI::MPI_CXX CUDA::cudart CUDA::cuda_driver)
|
||||
endfunction()
|
||||
|
||||
|
||||
@@ -190,10 +190,12 @@ class AllGatherTestEngine : public BaseTestEngine {
|
||||
void allocateBuffer() override;
|
||||
void setupConnections() override;
|
||||
|
||||
private:
|
||||
std::vector<void*> getSendBuff() override;
|
||||
void* getExpectedBuff() override;
|
||||
void* getRecvBuff() override;
|
||||
void* getScratchBuff() override;
|
||||
|
||||
private:
|
||||
void* getExpectedBuff() override;
|
||||
|
||||
std::shared_ptr<int> sendBuff_;
|
||||
std::shared_ptr<int[]> expectedBuff_;
|
||||
@@ -224,7 +226,10 @@ void* AllGatherTestEngine::getRecvBuff() {
|
||||
return sendBuff_.get();
|
||||
}
|
||||
|
||||
void* AllGatherTestEngine::getScratchBuff() { return nullptr; }
|
||||
|
||||
std::shared_ptr<BaseTestEngine> getTestEngine(const TestArgs& args) {
|
||||
return std::make_shared<AllGatherTestEngine>(args);
|
||||
}
|
||||
|
||||
std::shared_ptr<BaseTestColl> getTestColl() { return std::make_shared<AllGatherTestColl>(); }
|
||||
|
||||
@@ -10,7 +10,12 @@
|
||||
__constant__ mscclpp::channel::SimpleDeviceChannel constDevFstRoundChans[16];
|
||||
__constant__ mscclpp::channel::SimpleDeviceChannel constDevSndRoundChans[16];
|
||||
|
||||
static void* resultBuffer = nullptr;
|
||||
__constant__ mscclpp::channel::DirectChannel constDirChans[16];
|
||||
|
||||
// TODO(chhwang): need an interface for this.
|
||||
static void* resultBuff = nullptr;
|
||||
int* inputBuff;
|
||||
int* scratchBuff;
|
||||
|
||||
struct Chunk {
|
||||
size_t offset;
|
||||
@@ -55,7 +60,7 @@ __device__ void vectorSumSingleBlock(int* dst, int* src, size_t nElem) {
|
||||
|
||||
__device__ mscclpp::DeviceSyncer deviceSyncer;
|
||||
|
||||
__device__ void allreduce0(int rank, int worldSize, size_t nelems, size_t scratchDataCount) {
|
||||
__device__ void allreduce0(int* buff, int* scratch, int rank, int worldSize, size_t nelems, size_t scratchDataCount) {
|
||||
int peerId = blockIdx.x / BLOCKS_PER_PEER;
|
||||
int isComm = (threadIdx.x == 0) && (blockIdx.x % BLOCKS_PER_PEER == 0);
|
||||
int remoteRank = (peerId < rank) ? peerId : peerId + 1;
|
||||
@@ -81,14 +86,14 @@ __device__ void allreduce0(int rank, int worldSize, size_t nelems, size_t scratc
|
||||
// Local reduction: every block reduces a slice of each chunk in the scratch buffer into the user buffer
|
||||
mscclpp::channel::SimpleDeviceChannel& devSndRoundChan = constDevSndRoundChans[peerId];
|
||||
Chunk rankChunk = getChunk(nelems, worldSize, rank);
|
||||
int* chunk = (int*)devSndRoundChan.srcPtr_ + rankChunk.offset;
|
||||
int* chunk = buff + rankChunk.offset;
|
||||
int numPeers = gridDim.x / BLOCKS_PER_PEER;
|
||||
int numBlocks = gridDim.x;
|
||||
Chunk blockUserChunk = getChunk(rankChunk.size, numBlocks, blockIdx.x);
|
||||
size_t scratchDataCountPerPeer = scratchDataCount / numPeers;
|
||||
Chunk blockScratchChunk = getChunk(scratchDataCountPerPeer, numBlocks, blockIdx.x);
|
||||
for (int peerIdx = 0; peerIdx < numPeers; ++peerIdx) {
|
||||
int* scratchChunk = (int*)devFstRoundChan.tmpPtr_ + peerIdx * scratchDataCountPerPeer;
|
||||
int* scratchChunk = scratch + peerIdx * scratchDataCountPerPeer;
|
||||
vectorSumSingleBlock(chunk + blockUserChunk.offset, scratchChunk + blockScratchChunk.offset,
|
||||
blockScratchChunk.size);
|
||||
}
|
||||
@@ -106,7 +111,7 @@ __device__ void allreduce0(int rank, int worldSize, size_t nelems, size_t scratc
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void allreduce1(int rank, int worldSize, size_t nelems, size_t scratchDataCount) {
|
||||
__device__ void allreduce1(int* buff, int* scratch, int rank, int worldSize, size_t nelems, size_t scratchDataCount) {
|
||||
int isComm = (threadIdx.x == 0) && (blockIdx.x == 0);
|
||||
int remoteSendRank = (rank + 1) % worldSize;
|
||||
int remoteRecvRank = (rank + worldSize - 1) % worldSize;
|
||||
@@ -142,8 +147,8 @@ __device__ void allreduce1(int rank, int worldSize, size_t nelems, size_t scratc
|
||||
// Reduce
|
||||
chunkIndex = (rank + worldSize - step) % worldSize;
|
||||
offset = chunkIndex * chunkNelem * sizeof(int);
|
||||
int* dst = (int*)((char*)devFstSendChan.srcPtr_ + offset);
|
||||
int* src = (int*)((char*)devFstRecvChan.tmpPtr_ + offset);
|
||||
int* dst = (int*)((char*)buff + offset);
|
||||
int* src = (int*)((char*)scratch + offset);
|
||||
vectorSum(dst, src, chunkNelem / 2);
|
||||
|
||||
if (isComm) {
|
||||
@@ -171,8 +176,8 @@ __device__ void allreduce1(int rank, int worldSize, size_t nelems, size_t scratc
|
||||
deviceSyncer.sync(gridDim.x);
|
||||
|
||||
offset = rank * chunkNelem * sizeof(int);
|
||||
int* dst = (int*)((char*)devFstSendChan.srcPtr_ + offset);
|
||||
int* src = (int*)((char*)devFstRecvChan.tmpPtr_ + offset);
|
||||
int* dst = (int*)((char*)buff + offset);
|
||||
int* src = (int*)((char*)scratch + offset);
|
||||
vectorSum(dst, src, chunkNelem / 2);
|
||||
|
||||
if (isComm) {
|
||||
@@ -217,26 +222,25 @@ __device__ void allreduce1(int rank, int worldSize, size_t nelems, size_t scratc
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void allreduce2(int rank, int worldSize, size_t nelems, void* resultBuff) {
|
||||
__device__ void allreduce2(int* buff, int* scratch, void* result, int rank, int worldSize, size_t nelems) {
|
||||
int chanIdx = blockIdx.x / BLOCKS_PER_PEER;
|
||||
int numPeers = worldSize - 1;
|
||||
size_t nPkts = nelems / 2; // 2 elems per packet, assume nelems is even
|
||||
size_t pktBytes = nPkts * sizeof(mscclpp::channel::ChannelPacket);
|
||||
mscclpp::channel::SimpleDeviceChannel devFstRoundChan = constDevFstRoundChans[chanIdx];
|
||||
uint32_t flag = (uint32_t)devFstRoundChan.epochGetLocal() + 1; // +1 as flag should be non-zero
|
||||
mscclpp::channel::DirectChannel devDirChan = constDirChans[chanIdx];
|
||||
uint32_t flag = (uint32_t)devDirChan.epochGetLocal() + 1; // +1 as flag should be non-zero
|
||||
size_t srcOffset =
|
||||
((blockIdx.x % BLOCKS_PER_PEER) * nelems * sizeof(int) / BLOCKS_PER_PEER); // offset for this block
|
||||
size_t dstOffset = ((flag & 1) ? 0 : pktBytes * numPeers) + // double buffering
|
||||
((chanIdx < rank ? rank - 1 : rank) * pktBytes) + // offset for this rank
|
||||
(srcOffset * 2); // offset for this block: twice of srcOffset because 2 elems per packet
|
||||
|
||||
devFstRoundChan.putPacket(dstOffset, srcOffset, nelems / BLOCKS_PER_PEER * sizeof(int), threadIdx.x, blockDim.x,
|
||||
flag);
|
||||
devDirChan.putPacket(dstOffset, srcOffset, nelems / BLOCKS_PER_PEER * sizeof(int), threadIdx.x, blockDim.x, flag);
|
||||
|
||||
int2* src = (int2*)devFstRoundChan.srcPtr_;
|
||||
int2* res = (int2*)resultBuff; // cumulate into here
|
||||
mscclpp::channel::ChannelPacket* tmpPtr = (mscclpp::channel::ChannelPacket*)devFstRoundChan.tmpPtr_ +
|
||||
((flag & 1) ? 0 : nPkts * numPeers); // double buffering
|
||||
int2* src = (int2*)buff;
|
||||
int2* res = (int2*)result; // cumulate into here
|
||||
mscclpp::channel::ChannelPacket* tmpPtr =
|
||||
(mscclpp::channel::ChannelPacket*)scratch + ((flag & 1) ? 0 : nPkts * numPeers); // double buffering
|
||||
for (size_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPkts; idx += blockDim.x * gridDim.x) {
|
||||
int x = 0;
|
||||
int y = 0;
|
||||
@@ -261,17 +265,18 @@ __device__ void allreduce2(int rank, int worldSize, size_t nelems, void* resultB
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0 && (blockIdx.x % BLOCKS_PER_PEER) == 0) {
|
||||
devFstRoundChan.epochIncrement();
|
||||
devDirChan.epochIncrement();
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void kernel(int rank, int worldSize, size_t nelems, size_t scratchDataCount, void* resultBuff, int kernel) {
|
||||
__global__ void kernel(int* buff, int* scratch, void* result, int rank, int worldSize, size_t nelems,
|
||||
size_t scratchDataCount, int kernel) {
|
||||
if (kernel == 0)
|
||||
allreduce0(rank, worldSize, nelems, scratchDataCount);
|
||||
allreduce0(buff, scratch, rank, worldSize, nelems, scratchDataCount);
|
||||
else if (kernel == 1)
|
||||
allreduce1(rank, worldSize, nelems, scratchDataCount);
|
||||
allreduce1(buff, scratch, rank, worldSize, nelems, scratchDataCount);
|
||||
else if (kernel == 2)
|
||||
allreduce2(rank, worldSize, nelems, resultBuff);
|
||||
allreduce2(buff, scratch, result, rank, worldSize, nelems);
|
||||
}
|
||||
|
||||
class AllReduceTestColl : public BaseTestColl {
|
||||
@@ -293,7 +298,8 @@ void AllReduceTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
|
||||
const Chunk chunk = getChunk(paramCount_, worldSize, rank);
|
||||
const size_t scratchDataCount = chunk.size * nPeers;
|
||||
const int nBlocks = (kernelNum == 1) ? 24 : nPeers * BLOCKS_PER_PEER;
|
||||
kernel<<<nBlocks, 1024, 0, stream>>>(rank, worldSize, paramCount_, scratchDataCount, resultBuffer, kernelNum);
|
||||
kernel<<<nBlocks, 1024, 0, stream>>>(inputBuff, scratchBuff, resultBuff, rank, worldSize, paramCount_,
|
||||
scratchDataCount, kernelNum);
|
||||
}
|
||||
|
||||
void AllReduceTestColl::initData(const TestArgs& args, std::vector<void*> sendBuff, void* expectedBuff) {
|
||||
@@ -336,12 +342,15 @@ class AllReduceTestEngine : public BaseTestEngine {
|
||||
void allocateBuffer() override;
|
||||
void setupConnections() override;
|
||||
|
||||
private:
|
||||
bool isUsePacket() const;
|
||||
bool isInPlace() const;
|
||||
|
||||
std::vector<void*> getSendBuff() override;
|
||||
void* getExpectedBuff() override;
|
||||
void* getRecvBuff() override;
|
||||
void* getScratchBuff() override;
|
||||
|
||||
private:
|
||||
void* getExpectedBuff() override;
|
||||
|
||||
std::shared_ptr<int> sendBuff_;
|
||||
std::shared_ptr<int> scratchBuff_;
|
||||
@@ -363,30 +372,37 @@ void AllReduceTestEngine::allocateBuffer() {
|
||||
resultBuff_ = mscclpp::allocSharedCuda<int>(args_.maxBytes / sizeof(int));
|
||||
expectedBuff_ = std::shared_ptr<int[]>(new int[args_.maxBytes / sizeof(int)]);
|
||||
|
||||
inputBuff = sendBuff_.get();
|
||||
scratchBuff = scratchBuff_.get();
|
||||
// TODO(chhwang): need a new interface for this.
|
||||
resultBuffer = resultBuff_.get();
|
||||
resultBuff = resultBuff_.get();
|
||||
}
|
||||
|
||||
void AllReduceTestEngine::setupConnections() {
|
||||
std::vector<mscclpp::channel::SimpleDeviceChannel> fstRoundChannels;
|
||||
std::vector<mscclpp::channel::SimpleDeviceChannel> sndRoundChannels;
|
||||
|
||||
// Send data from local sendBuff to remote scratchBuff (out-of-place)
|
||||
setupMeshConnections(fstRoundChannels, sendBuff_.get(), args_.maxBytes, scratchBuff_.get(), args_.maxBytes);
|
||||
assert(fstRoundChannels.size() < sizeof(constDevFstRoundChans) / sizeof(mscclpp::channel::SimpleDeviceChannel));
|
||||
CUDATHROW(cudaMemcpyToSymbol(constDevFstRoundChans, fstRoundChannels.data(),
|
||||
sizeof(mscclpp::channel::SimpleDeviceChannel) * fstRoundChannels.size()));
|
||||
|
||||
if (isUsePacket()) {
|
||||
// Send data from local sendBuff to remote scratchBuff (out-of-place)
|
||||
setupMeshConnections(sndRoundChannels, sendBuff_.get(), args_.maxBytes, scratchBuff_.get(), args_.maxBytes);
|
||||
std::vector<mscclpp::channel::DirectChannel> dirChannels;
|
||||
|
||||
setupMeshConnections(dirChannels, sendBuff_.get(), args_.maxBytes, scratchBuff_.get(), args_.maxBytes);
|
||||
|
||||
assert(dirChannels.size() < sizeof(constDirChans) / sizeof(mscclpp::channel::DirectChannel));
|
||||
CUDATHROW(cudaMemcpyToSymbol(constDirChans, dirChannels.data(),
|
||||
sizeof(mscclpp::channel::DirectChannel) * dirChannels.size()));
|
||||
} else {
|
||||
std::vector<mscclpp::channel::SimpleDeviceChannel> fstRoundChannels;
|
||||
std::vector<mscclpp::channel::SimpleDeviceChannel> sndRoundChannels;
|
||||
|
||||
// Send data from local sendBuff to remote scratchBuff (out-of-place)
|
||||
setupMeshConnections(fstRoundChannels, sendBuff_.get(), args_.maxBytes, scratchBuff_.get(), args_.maxBytes);
|
||||
assert(fstRoundChannels.size() < sizeof(constDevFstRoundChans) / sizeof(mscclpp::channel::SimpleDeviceChannel));
|
||||
CUDATHROW(cudaMemcpyToSymbol(constDevFstRoundChans, fstRoundChannels.data(),
|
||||
sizeof(mscclpp::channel::SimpleDeviceChannel) * fstRoundChannels.size()));
|
||||
|
||||
// Send data from local sendBuff to remote sendBuff (in-place)
|
||||
setupMeshConnections(sndRoundChannels, sendBuff_.get(), args_.maxBytes);
|
||||
assert(sndRoundChannels.size() < sizeof(constDevSndRoundChans) / sizeof(mscclpp::channel::SimpleDeviceChannel));
|
||||
CUDATHROW(cudaMemcpyToSymbol(constDevSndRoundChans, sndRoundChannels.data(),
|
||||
sizeof(mscclpp::channel::SimpleDeviceChannel) * sndRoundChannels.size()));
|
||||
}
|
||||
assert(sndRoundChannels.size() < sizeof(constDevSndRoundChans) / sizeof(mscclpp::channel::SimpleDeviceChannel));
|
||||
CUDATHROW(cudaMemcpyToSymbol(constDevSndRoundChans, sndRoundChannels.data(),
|
||||
sizeof(mscclpp::channel::SimpleDeviceChannel) * sndRoundChannels.size()));
|
||||
}
|
||||
|
||||
std::vector<void*> AllReduceTestEngine::getSendBuff() { return {sendBuff_.get()}; }
|
||||
@@ -395,7 +411,10 @@ void* AllReduceTestEngine::getExpectedBuff() { return expectedBuff_.get(); }
|
||||
|
||||
void* AllReduceTestEngine::getRecvBuff() { return isInPlace() ? sendBuff_.get() : resultBuff_.get(); }
|
||||
|
||||
void* AllReduceTestEngine::getScratchBuff() { return scratchBuff_.get(); }
|
||||
|
||||
std::shared_ptr<BaseTestEngine> getTestEngine(const TestArgs& args) {
|
||||
return std::make_shared<AllReduceTestEngine>(args);
|
||||
}
|
||||
|
||||
std::shared_ptr<BaseTestColl> getTestColl() { return std::make_shared<AllReduceTestColl>(); }
|
||||
|
||||
@@ -118,10 +118,12 @@ class AllToAllTestEngine : public BaseTestEngine {
|
||||
void allocateBuffer() override;
|
||||
void setupConnections() override;
|
||||
|
||||
private:
|
||||
std::vector<void*> getSendBuff() override;
|
||||
void* getExpectedBuff() override;
|
||||
void* getRecvBuff() override;
|
||||
void* getScratchBuff() override;
|
||||
|
||||
private:
|
||||
void* getExpectedBuff() override;
|
||||
|
||||
std::shared_ptr<int> sendBuff_;
|
||||
std::shared_ptr<int> recvBuff_;
|
||||
@@ -151,6 +153,7 @@ void AllToAllTestEngine::setupConnections() {
|
||||
std::vector<void*> AllToAllTestEngine::getSendBuff() { return {sendBuff_.get()}; }
|
||||
void* AllToAllTestEngine::getExpectedBuff() { return expectedBuff_.get(); }
|
||||
void* AllToAllTestEngine::getRecvBuff() { return recvBuff_.get(); }
|
||||
void* AllToAllTestEngine::getScratchBuff() { return nullptr; }
|
||||
|
||||
std::shared_ptr<BaseTestEngine> getTestEngine(const TestArgs& args) {
|
||||
return std::make_shared<AllToAllTestEngine>(args);
|
||||
|
||||
@@ -1,23 +1,33 @@
|
||||
#include "common.hpp"
|
||||
|
||||
#include <cuda.h>
|
||||
#include <getopt.h>
|
||||
#include <libgen.h>
|
||||
#include <mpi.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <mscclpp/utils.hpp>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
int is_main_proc = 0;
|
||||
int isMainProc = 0;
|
||||
|
||||
mscclpp::Transport IBs[] = {mscclpp::Transport::IB0, mscclpp::Transport::IB1, mscclpp::Transport::IB2,
|
||||
mscclpp::Transport::IB3, mscclpp::Transport::IB4, mscclpp::Transport::IB5,
|
||||
mscclpp::Transport::IB6, mscclpp::Transport::IB7};
|
||||
|
||||
#define PRINT(__message) \
|
||||
do { \
|
||||
if (isMainProc) std::cout << __message; \
|
||||
} while (0);
|
||||
|
||||
#define PRECISION(__val) std::fixed << std::setprecision(2) << __val
|
||||
|
||||
namespace {
|
||||
|
||||
// Command line parameter defaults
|
||||
@@ -149,19 +159,24 @@ void BaseTestEngine::runTest() {
|
||||
}
|
||||
CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
PRINT("#\n");
|
||||
PRINT("# %10s %12s in-place out-of-place \n", "", "");
|
||||
PRINT("# %10s %12s %7s %6s %6s %6s %7s %6s %6s %6s\n", "size", "count", "time", "algbw", "busbw", "#wrong",
|
||||
"time", "algbw", "busbw", "#wrong");
|
||||
PRINT("# %10s %12s %7s %6s %6s %5s %7s %6s %6s %5s\n", "(B)", "(elements)", "(us)", "(GB/s)", "(GB/s)", "",
|
||||
"(us)", "(GB/s)", "(GB/s)", "");
|
||||
std::stringstream ss;
|
||||
ss << "#\n";
|
||||
ss << "# in-place out-of-place\n";
|
||||
ss << "# size count time algbw busbw #wrong time algbw busbw #wrong\n";
|
||||
ss << "# (B) (elements) (us) (GB/s) (GB/s) (us) (GB/s) (GB/s)\n";
|
||||
PRINT(ss.str());
|
||||
|
||||
ss.str(std::string());
|
||||
|
||||
// Benchmark
|
||||
for (size_t size = args_.minBytes; size <= args_.maxBytes;
|
||||
size = ((args_.stepFactor > 1) ? size * args_.stepFactor : size + args_.stepBytes)) {
|
||||
coll_->setupCollTest(args_, size);
|
||||
this->coll_->initData(this->args_, this->getSendBuff(), this->getExpectedBuff());
|
||||
PRINT("%12li %12li", max(coll_->getSendBytes(), coll_->getExpectedBytes()), coll_->getParamBytes() / sizeof(int));
|
||||
|
||||
ss << std::setw(12) << std::max(coll_->getSendBytes(), coll_->getExpectedBytes()) << " " << std::setw(12)
|
||||
<< coll_->getParamBytes() / sizeof(int);
|
||||
|
||||
double deltaSec = benchTime();
|
||||
|
||||
size_t nErrors = 0;
|
||||
@@ -191,14 +206,18 @@ void BaseTestEngine::runTest() {
|
||||
double algBw, busBw;
|
||||
this->coll_->getBw(deltaSec, algBw, busBw);
|
||||
if (!this->inPlace_) {
|
||||
PRINT(" ");
|
||||
ss << " ";
|
||||
}
|
||||
if (args_.reportErrors) {
|
||||
PRINT(" %7s %6.2f %6.2f %5g", timeStr, algBw, busBw, (double)nErrors);
|
||||
ss << " " << std::setw(7) << timeStr << " " << std::setw(6) << PRECISION(algBw) << " " << std::setw(6)
|
||||
<< PRECISION(busBw) << " " << std::setw(5) << nErrors;
|
||||
} else {
|
||||
PRINT(" %7s %6.2f %6.2f %5s", timeStr, algBw, busBw, "N/A");
|
||||
ss << " " << std::setw(7) << timeStr << " " << std::setw(6) << PRECISION(algBw) << " " << std::setw(6)
|
||||
<< PRECISION(busBw);
|
||||
}
|
||||
PRINT("\n");
|
||||
ss << "\n";
|
||||
PRINT(ss.str());
|
||||
ss.str(std::string());
|
||||
}
|
||||
PRINT("\n");
|
||||
}
|
||||
@@ -234,20 +253,22 @@ size_t BaseTestEngine::checkData() {
|
||||
return nErrors;
|
||||
}
|
||||
|
||||
// Create mesh connections between all ranks. If recvBuff is nullptr, assume in-place.
|
||||
void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::channel::SimpleDeviceChannel>& devChannels,
|
||||
void* sendBuff, size_t sendBuffBytes, void* recvBuff, size_t recvBuffBytes) {
|
||||
void BaseTestEngine::setupMeshConnectionsInternal(
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>>& connections, mscclpp::RegisteredMemory& inputBufRegMem,
|
||||
mscclpp::RegisteredMemory& outputBufRegMem,
|
||||
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>>& remoteRegMemories, void* inputBuff,
|
||||
size_t inputBuffBytes, void* outputBuff, size_t outputBuffBytes) {
|
||||
const int worldSize = args_.totalRanks;
|
||||
const int rank = args_.rank;
|
||||
const int nRanksPerNode = args_.nRanksPerNode;
|
||||
const int thisNode = rank / nRanksPerNode;
|
||||
const mscclpp::Transport ibTransport = IBs[args_.gpuNum];
|
||||
const bool isOutPlace = (recvBuff != nullptr);
|
||||
const bool isOutPlace = (outputBuff != nullptr);
|
||||
|
||||
std::vector<mscclpp::channel::ChannelId> channelIds;
|
||||
std::vector<mscclpp::RegisteredMemory> localMemories;
|
||||
std::vector<mscclpp::RegisteredMemory> localTmpMemories;
|
||||
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteMemories;
|
||||
inputBufRegMem = comm_->registerMemory(inputBuff, inputBuffBytes, mscclpp::Transport::CudaIpc | ibTransport);
|
||||
if (isOutPlace) {
|
||||
outputBufRegMem = comm_->registerMemory(outputBuff, outputBuffBytes, mscclpp::Transport::CudaIpc | ibTransport);
|
||||
}
|
||||
|
||||
auto rankToNode = [&](int rank) { return rank / nRanksPerNode; };
|
||||
for (int r = 0; r < worldSize; r++) {
|
||||
@@ -261,25 +282,61 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::channel::SimpleDe
|
||||
transport = ibTransport;
|
||||
}
|
||||
// Connect with all other ranks
|
||||
channelIds.push_back(chanService_->addChannel(comm_->connectOnSetup(r, 0, transport)));
|
||||
auto sendMemory = comm_->registerMemory(sendBuff, sendBuffBytes, mscclpp::Transport::CudaIpc | ibTransport);
|
||||
localMemories.push_back(sendMemory);
|
||||
connections.push_back(comm_->connectOnSetup(r, 0, transport));
|
||||
|
||||
if (isOutPlace) {
|
||||
auto recvMemory = comm_->registerMemory(recvBuff, recvBuffBytes, mscclpp::Transport::CudaIpc | ibTransport);
|
||||
comm_->sendMemoryOnSetup(recvMemory, r, 0);
|
||||
localTmpMemories.push_back(recvMemory);
|
||||
comm_->sendMemoryOnSetup(outputBufRegMem, r, 0);
|
||||
} else {
|
||||
comm_->sendMemoryOnSetup(sendMemory, r, 0);
|
||||
comm_->sendMemoryOnSetup(inputBufRegMem, r, 0);
|
||||
}
|
||||
remoteMemories.push_back(comm_->recvMemoryOnSetup(r, 0));
|
||||
auto remoteMemory = comm_->recvMemoryOnSetup(r, 0);
|
||||
remoteRegMemories.push_back(remoteMemory);
|
||||
}
|
||||
comm_->setup();
|
||||
}
|
||||
|
||||
// Create mesh connections between all ranks. If recvBuff is nullptr, assume in-place.
|
||||
// TODO(saemal): retrun the actual vector instead of void
|
||||
void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::channel::SimpleDeviceChannel>& devChannels,
|
||||
void* inputBuff, size_t inputBuffBytes, void* outputBuff,
|
||||
size_t outputBuffBytes) {
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
|
||||
mscclpp::RegisteredMemory inputBufRegMem;
|
||||
mscclpp::RegisteredMemory outputBufRegMem;
|
||||
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteRegMemories;
|
||||
|
||||
setupMeshConnectionsInternal(connections, inputBufRegMem, outputBufRegMem, remoteRegMemories, inputBuff,
|
||||
inputBuffBytes, outputBuff, outputBuffBytes);
|
||||
|
||||
for (size_t i = 0; i < connections.size(); ++i) {
|
||||
devChannels.push_back(mscclpp::channel::SimpleDeviceChannel(
|
||||
chanService_->deviceChannel(chanService_->addChannel(connections[i])),
|
||||
chanService_->addMemory(remoteRegMemories[i].get()), chanService_->addMemory(inputBufRegMem)));
|
||||
}
|
||||
|
||||
comm_->setup();
|
||||
}
|
||||
|
||||
void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::channel::DirectChannel>& dirChannels, void* inputBuff,
|
||||
size_t inputBuffBytes, void* outputBuff, size_t outputBuffBytes) {
|
||||
const bool isOutPlace = (outputBuff != nullptr);
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
|
||||
mscclpp::RegisteredMemory inputBufRegMem;
|
||||
mscclpp::RegisteredMemory outputBufRegMem;
|
||||
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteRegMemories;
|
||||
|
||||
setupMeshConnectionsInternal(connections, inputBufRegMem, outputBufRegMem, remoteRegMemories, inputBuff,
|
||||
inputBuffBytes, outputBuff, outputBuffBytes);
|
||||
|
||||
std::vector<std::shared_ptr<mscclpp::DirectEpoch>> dirEpochs;
|
||||
for (auto& conn : connections) {
|
||||
dirEpochs.emplace_back(std::make_shared<mscclpp::DirectEpoch>(*comm_, conn));
|
||||
}
|
||||
comm_->setup();
|
||||
|
||||
for (size_t i = 0; i < channelIds.size(); ++i) {
|
||||
devChannels.push_back(mscclpp::channel::SimpleDeviceChannel(
|
||||
chanService_->deviceChannel(channelIds[i]), chanService_->addMemory(remoteMemories[i].get()),
|
||||
chanService_->addMemory(localMemories[i]), remoteMemories[i].get().data(), localMemories[i].data(),
|
||||
(isOutPlace ? localTmpMemories[i].data() : nullptr)));
|
||||
for (size_t i = 0; i < dirEpochs.size(); ++i) {
|
||||
dirChannels.emplace_back(dirEpochs[i]->deviceHandle(), remoteRegMemories[i].get(), inputBufRegMem.data(),
|
||||
(isOutPlace ? outputBufRegMem.data() : nullptr));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -398,15 +455,16 @@ void run(int argc, char* argv[]) {
|
||||
MPI_Comm_size(shmcomm, &nRanksPerNode);
|
||||
MPI_Comm_free(&shmcomm);
|
||||
localRank = rank % nRanksPerNode;
|
||||
is_main_proc = (rank == 0) ? 1 : 0;
|
||||
isMainProc = (rank == 0) ? 1 : 0;
|
||||
|
||||
PRINT(
|
||||
"# minBytes %ld maxBytes %ld step: %ld(%s) warmup iters: %d iters: %d validation: %d graph: %d, "
|
||||
"kernel num: %d\n",
|
||||
minBytes, maxBytes, (stepFactor > 1) ? stepFactor : stepBytes, (stepFactor > 1) ? "factor" : "bytes",
|
||||
warmup_iters, iters, datacheck, cudaGraphLaunches, kernel_num);
|
||||
PRINT("#\n");
|
||||
PRINT("# Using devices\n");
|
||||
std::stringstream ss;
|
||||
ss << "# minBytes " << minBytes << " maxBytes " << maxBytes
|
||||
<< " step: " << ((stepFactor > 1) ? stepFactor : stepBytes) << "(" << ((stepFactor > 1) ? "factor" : "bytes")
|
||||
<< ") warmup iters: " << warmup_iters << " iters: " << iters << " validation: " << datacheck
|
||||
<< " graph: " << cudaGraphLaunches << " kernel num: " << kernel_num << "\n";
|
||||
ss << "#\n# Using devices\n";
|
||||
PRINT(ss.str());
|
||||
ss.str(std::string());
|
||||
|
||||
constexpr int MAX_LINE = 2048;
|
||||
char line[MAX_LINE];
|
||||
@@ -426,7 +484,11 @@ void run(int argc, char* argv[]) {
|
||||
// Gather all output in rank order to root (0)
|
||||
MPI_Gather(line, MAX_LINE, MPI_BYTE, lines.get(), MAX_LINE, MPI_BYTE, 0, MPI_COMM_WORLD);
|
||||
if (rank == 0) {
|
||||
for (int r = 0; r < totalRanks; r++) PRINT("%s", &lines[MAX_LINE * r]);
|
||||
for (int r = 0; r < totalRanks; r++) {
|
||||
ss << &lines[MAX_LINE * r];
|
||||
}
|
||||
PRINT(ss.str());
|
||||
ss.str(std::string());
|
||||
}
|
||||
MPI_Allreduce(MPI_IN_PLACE, &maxMem, 1, MPI_LONG, MPI_MIN, MPI_COMM_WORLD);
|
||||
|
||||
@@ -434,17 +496,22 @@ void run(int argc, char* argv[]) {
|
||||
size_t memMaxBytes = (maxMem - (1 << 30)) / (datacheck ? 3 : 2);
|
||||
if (maxBytes > memMaxBytes) {
|
||||
maxBytes = memMaxBytes;
|
||||
PRINT("#\n# Reducing maxBytes to %ld due to memory limitation\n", maxBytes);
|
||||
ss << "#\n# Reducing maxBytes to " << maxBytes << " due to memory limitation\n";
|
||||
PRINT(ss.str());
|
||||
ss.str(std::string());
|
||||
}
|
||||
|
||||
CUDATHROW(cudaSetDevice(cudaDev));
|
||||
TestArgs args = {minBytes, maxBytes, stepBytes, stepFactor, totalRanks, rank,
|
||||
cudaDev, localRank, nRanksPerNode, kernel_num, datacheck};
|
||||
PRINT("#\n");
|
||||
PRINT("# Initializing MSCCL++\n");
|
||||
|
||||
PRINT("#\n# Initializing MSCCL++\n");
|
||||
|
||||
auto testEngine = getTestEngine(args);
|
||||
testEngine->bootstrap();
|
||||
testEngine->allocateBuffer();
|
||||
int* inputBuff = (int*)testEngine->getSendBuff()[0];
|
||||
int* scratchBuff = (int*)testEngine->getScratchBuff();
|
||||
PRINT("# Setting up the connection in MSCCL++\n");
|
||||
testEngine->setupTest();
|
||||
testEngine->barrier();
|
||||
@@ -455,8 +522,8 @@ void run(int argc, char* argv[]) {
|
||||
int error = testEngine->getTestErrors();
|
||||
MPI_Allreduce(MPI_IN_PLACE, &error, 1, MPI_INT, MPI_SUM, MPI_COMM_WORLD);
|
||||
|
||||
PRINT("# Out of bounds values : %d %s\n", error, error ? "FAILED" : "OK");
|
||||
PRINT("#\n");
|
||||
ss << "# Out of bounds values : " << error << " " << (error ? "FAILED" : "OK") << "\n#\n";
|
||||
PRINT(ss.str());
|
||||
|
||||
MPI_Finalize();
|
||||
}
|
||||
@@ -1,12 +1,6 @@
|
||||
#ifndef MSCCLPP_TESTS_COMMON_H_
|
||||
#define MSCCLPP_TESTS_COMMON_H_
|
||||
|
||||
#include <mpi.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <mscclpp/channel.hpp>
|
||||
#include <mscclpp/core.hpp>
|
||||
|
||||
@@ -74,17 +68,27 @@ class BaseTestEngine {
|
||||
void barrier();
|
||||
size_t checkData();
|
||||
|
||||
virtual std::vector<void*> getSendBuff() = 0;
|
||||
virtual void* getRecvBuff() = 0;
|
||||
virtual void* getScratchBuff() = 0;
|
||||
|
||||
private:
|
||||
virtual void setupConnections() = 0;
|
||||
virtual std::vector<void*> getSendBuff() = 0;
|
||||
virtual void* getExpectedBuff() = 0;
|
||||
virtual void* getRecvBuff() = 0;
|
||||
|
||||
double benchTime();
|
||||
|
||||
void setupMeshConnectionsInternal(
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>>& connections, mscclpp::RegisteredMemory& inputBufRegMem,
|
||||
mscclpp::RegisteredMemory& outputBufRegMem,
|
||||
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>>& remoteRegMemories, void* inputBuff,
|
||||
size_t inputBuffBytes, void* outputBuff, size_t outputBuffBytes);
|
||||
|
||||
protected:
|
||||
void setupMeshConnections(std::vector<mscclpp::channel::SimpleDeviceChannel>& devChannels, void* sendBuff,
|
||||
size_t sendBuffBytes, void* recvBuff = nullptr, size_t recvBuffBytes = 0);
|
||||
void setupMeshConnections(std::vector<mscclpp::channel::SimpleDeviceChannel>& devChannels, void* inputBuff,
|
||||
size_t inputBuffBytes, void* outputBuff = nullptr, size_t outputBuffBytes = 0);
|
||||
void setupMeshConnections(std::vector<mscclpp::channel::DirectChannel>& dirChannels, void* inputBuff,
|
||||
size_t inputBuffBytes, void* outputBuff, size_t outputBuffBytes = 0);
|
||||
|
||||
const TestArgs args_;
|
||||
bool inPlace_;
|
||||
@@ -99,7 +103,4 @@ extern std::shared_ptr<BaseTestEngine> getTestEngine(const TestArgs& args);
|
||||
extern std::shared_ptr<BaseTestColl> getTestColl();
|
||||
extern mscclpp::Transport IBs[];
|
||||
|
||||
#define PRINT \
|
||||
if (is_main_proc) printf
|
||||
|
||||
#endif // MSCCLPP_TESTS_COMMON_H_
|
||||
|
||||
@@ -19,7 +19,7 @@ constexpr size_t MAX_BLOCKS_NUM = 32;
|
||||
|
||||
#define ALIGN 4
|
||||
|
||||
__constant__ mscclpp::channel::SimpleDeviceChannel constDevChans[2];
|
||||
__constant__ mscclpp::channel::DirectChannel constDirChans[2];
|
||||
|
||||
inline int getBlockNum(size_t count) {
|
||||
return std::min((count + THRES_BYTES_PER_BLOCK - 1) / THRES_BYTES_PER_BLOCK, MAX_BLOCKS_NUM);
|
||||
@@ -36,13 +36,13 @@ __global__ void kernel(int rank, size_t dataSize, size_t dataPerBlock) {
|
||||
size_t blockDataSize = min(dataSize - startIndex, dataPerBlock);
|
||||
int globalIndex = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
mscclpp::channel::SimpleDeviceChannel sendConn = constDevChans[0];
|
||||
mscclpp::channel::SimpleDeviceChannel recvConn = constDevChans[1];
|
||||
mscclpp::channel::DirectChannel sendConn = constDirChans[0];
|
||||
mscclpp::channel::DirectChannel recvConn = constDirChans[1];
|
||||
|
||||
sendConn.putDirect(startIndex, blockDataSize, threadIdx.x, blockDim.x);
|
||||
sendConn.put(startIndex, startIndex, blockDataSize, threadIdx.x, blockDim.x);
|
||||
deviceSyncer.sync(gridDim.x);
|
||||
if (globalIndex == 0) {
|
||||
sendConn.signalDirect();
|
||||
sendConn.signal();
|
||||
recvConn.wait();
|
||||
}
|
||||
}
|
||||
@@ -109,10 +109,12 @@ class SendRecvTestEngine : public BaseTestEngine {
|
||||
void allocateBuffer() override;
|
||||
void setupConnections() override;
|
||||
|
||||
private:
|
||||
std::vector<void*> getSendBuff() override;
|
||||
void* getExpectedBuff() override;
|
||||
void* getRecvBuff() override;
|
||||
void* getScratchBuff() override;
|
||||
|
||||
private:
|
||||
void* getExpectedBuff() override;
|
||||
|
||||
std::vector<std::shared_ptr<int>> devicePtrs_;
|
||||
std::shared_ptr<int[]> expectedBuff_;
|
||||
@@ -136,16 +138,18 @@ void SendRecvTestEngine::setupConnections() {
|
||||
int recvFromRank = (args_.rank - 1 + worldSize) % worldSize;
|
||||
std::array<int, 2> ranks = {sendToRank, recvFromRank};
|
||||
|
||||
std::vector<mscclpp::channel::ChannelId> chanIds;
|
||||
std::vector<std::shared_ptr<mscclpp::DirectEpoch>> directEpochs;
|
||||
|
||||
chanIds.push_back(chanService_->addChannel(
|
||||
comm_->connectOnSetup(sendToRank, 0, getTransport(args_.rank, sendToRank, args_.nRanksPerNode, ibDevice))));
|
||||
auto sendConn =
|
||||
comm_->connectOnSetup(sendToRank, 0, getTransport(args_.rank, sendToRank, args_.nRanksPerNode, ibDevice));
|
||||
directEpochs.push_back(std::make_shared<mscclpp::DirectEpoch>(*comm_, sendConn));
|
||||
if (recvFromRank != sendToRank) {
|
||||
chanIds.push_back(chanService_->addChannel(
|
||||
comm_->connectOnSetup(recvFromRank, 0, getTransport(args_.rank, recvFromRank, args_.nRanksPerNode, ibDevice))));
|
||||
auto recvConn =
|
||||
comm_->connectOnSetup(recvFromRank, 0, getTransport(args_.rank, recvFromRank, args_.nRanksPerNode, ibDevice));
|
||||
directEpochs.push_back(std::make_shared<mscclpp::DirectEpoch>(*comm_, recvConn));
|
||||
} else {
|
||||
// reuse the send channel if worldSize is 2
|
||||
chanIds.push_back(chanIds[0]);
|
||||
directEpochs.push_back(directEpochs[0]);
|
||||
}
|
||||
comm_->setup();
|
||||
|
||||
@@ -162,14 +166,13 @@ void SendRecvTestEngine::setupConnections() {
|
||||
|
||||
// swap to make sure devicePtrs_[0] in local rank write to devicePtrs_[1] in remote rank
|
||||
std::swap(futureRemoteMemory[0], futureRemoteMemory[1]);
|
||||
std::vector<mscclpp::channel::SimpleDeviceChannel> devChannels;
|
||||
std::vector<mscclpp::channel::DirectChannel> dirChannels;
|
||||
for (int i : {0, 1}) {
|
||||
// We assume ranks in the same node
|
||||
devChannels.push_back(mscclpp::channel::SimpleDeviceChannel(
|
||||
chanService_->deviceChannel(chanIds[i]), futureRemoteMemory[i].get().data(), localMemories[i].data()));
|
||||
dirChannels.emplace_back(directEpochs[i]->deviceHandle(), futureRemoteMemory[i].get(),
|
||||
(void*)localMemories[i].data());
|
||||
}
|
||||
cudaMemcpyToSymbol(constDevChans, devChannels.data(),
|
||||
sizeof(mscclpp::channel::SimpleDeviceChannel) * devChannels.size());
|
||||
cudaMemcpyToSymbol(constDirChans, dirChannels.data(), sizeof(mscclpp::channel::DirectChannel) * dirChannels.size());
|
||||
}
|
||||
|
||||
std::vector<void*> SendRecvTestEngine::getSendBuff() { return {devicePtrs_[0].get()}; }
|
||||
@@ -178,7 +181,10 @@ void* SendRecvTestEngine::getExpectedBuff() { return expectedBuff_.get(); }
|
||||
|
||||
void* SendRecvTestEngine::getRecvBuff() { return devicePtrs_[1].get(); }
|
||||
|
||||
void* SendRecvTestEngine::getScratchBuff() { return nullptr; }
|
||||
|
||||
std::shared_ptr<BaseTestEngine> getTestEngine(const TestArgs& args) {
|
||||
return std::make_shared<SendRecvTestEngine>(args);
|
||||
}
|
||||
|
||||
std::shared_ptr<BaseTestColl> getTestColl() { return std::make_shared<SendRecvTestColl>(); }
|
||||
|
||||
Reference in New Issue
Block a user