mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-24 06:44:40 +00:00
FIFO improvements (#557)
* Revert `MSCCLPP_FIFO_USE_TAIL_REPLICA=1` back to the default. * Optimize `FifoDeviceHandle`. * Do not use `cudaHostAllocWriteCombined` that increases latency. * Pin host memory for `Host2DeviceSemaphore::outboundSemaphore_`. * Fix proxy NUMA binding issues. * Prevent graph capture inside proxy threads. * Now `CudaIpcConnection` skips stream sync when unnecessary. * Now any type of connection needs to hold a shared pointer to the context for memory safety. * Now a context should be always managed by a shared pointer for memory safety. * Minor docs & interface improvements. * Minor fix in `mscclpp-test` correctness test.
This commit is contained in:
@@ -11,7 +11,7 @@ namespace mscclpp {
|
||||
Communicator::Impl::Impl(std::shared_ptr<Bootstrap> bootstrap, std::shared_ptr<Context> context)
|
||||
: bootstrap_(bootstrap) {
|
||||
if (!context) {
|
||||
context_ = std::make_shared<Context>();
|
||||
context_ = Context::create();
|
||||
} else {
|
||||
context_ = context;
|
||||
}
|
||||
|
||||
@@ -37,13 +37,13 @@ std::string Connection::getTransportName() const {
|
||||
TransportNames[static_cast<int>(this->remoteTransport())];
|
||||
}
|
||||
|
||||
int Connection::getMaxWriteQueueSize() const { return maxWriteQueueSize; }
|
||||
int Connection::getMaxWriteQueueSize() const { return maxWriteQueueSize_; }
|
||||
|
||||
// CudaIpcConnection
|
||||
|
||||
CudaIpcConnection::CudaIpcConnection(Endpoint localEndpoint, Endpoint remoteEndpoint,
|
||||
std::shared_ptr<CudaStreamWithFlags> stream)
|
||||
: Connection(localEndpoint.maxWriteQueueSize()), stream_(stream) {
|
||||
CudaIpcConnection::CudaIpcConnection(std::shared_ptr<Context> context, Endpoint localEndpoint, Endpoint remoteEndpoint,
|
||||
std::shared_ptr<CudaIpcStream> stream)
|
||||
: Connection(context, localEndpoint.maxWriteQueueSize()), stream_(stream) {
|
||||
if (localEndpoint.transport() != Transport::CudaIpc) {
|
||||
throw mscclpp::Error("Cuda IPC connection can only be made from a Cuda IPC endpoint", ErrorCode::InvalidUsage);
|
||||
}
|
||||
@@ -76,9 +76,8 @@ void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, Register
|
||||
char* dstPtr = (char*)dst.data();
|
||||
char* srcPtr = (char*)src.data();
|
||||
|
||||
if (!env()->cudaIpcUseDefaultStream && stream_->empty()) stream_->set(cudaStreamNonBlocking);
|
||||
stream_->memcpyD2D(dstPtr + dstOffset, srcPtr + srcOffset, size);
|
||||
|
||||
MSCCLPP_CUDATHROW(cudaMemcpyAsync(dstPtr + dstOffset, srcPtr + srcOffset, size, cudaMemcpyDeviceToDevice, *stream_));
|
||||
INFO(MSCCLPP_P2P, "CudaIpcConnection write: from %p to %p, size %lu", srcPtr + srcOffset, dstPtr + dstOffset, size);
|
||||
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_CONN_CUDA_IPC_WRITE_EXIT)
|
||||
@@ -96,9 +95,8 @@ void CudaIpcConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset,
|
||||
*src = newValue;
|
||||
uint64_t* dstPtr = reinterpret_cast<uint64_t*>(reinterpret_cast<char*>(dst.data()) + dstOffset);
|
||||
|
||||
if (!env()->cudaIpcUseDefaultStream && stream_->empty()) stream_->set(cudaStreamNonBlocking);
|
||||
stream_->memcpyH2D(dstPtr + dstOffset, src, sizeof(uint64_t));
|
||||
|
||||
MSCCLPP_CUDATHROW(cudaMemcpyAsync(dstPtr, src, sizeof(uint64_t), cudaMemcpyHostToDevice, *stream_));
|
||||
INFO(MSCCLPP_P2P, "CudaIpcConnection atomic write: from %p to %p, %lu -> %lu", src, dstPtr + dstOffset, oldValue,
|
||||
newValue);
|
||||
|
||||
@@ -116,10 +114,8 @@ void CudaIpcConnection::flush(int64_t timeoutUsec) {
|
||||
INFO(MSCCLPP_P2P, "CudaIpcConnection flush: timeout is not supported, ignored");
|
||||
}
|
||||
|
||||
if (!env()->cudaIpcUseDefaultStream && stream_->empty()) stream_->set(cudaStreamNonBlocking);
|
||||
stream_->sync();
|
||||
|
||||
AvoidCudaGraphCaptureGuard guard;
|
||||
MSCCLPP_CUDATHROW(cudaStreamSynchronize(*stream_));
|
||||
INFO(MSCCLPP_P2P, "CudaIpcConnection flushing connection");
|
||||
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_CONN_CUDA_IPC_FLUSH_EXIT)
|
||||
@@ -129,16 +125,16 @@ void CudaIpcConnection::flush(int64_t timeoutUsec) {
|
||||
|
||||
// IBConnection
|
||||
|
||||
IBConnection::IBConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, Context& context)
|
||||
: Connection(localEndpoint.maxWriteQueueSize() != -1 ? localEndpoint.maxWriteQueueSize()
|
||||
: EndpointConfig::DefaultMaxCqSize),
|
||||
IBConnection::IBConnection(std::shared_ptr<Context> context, Endpoint localEndpoint, Endpoint remoteEndpoint)
|
||||
: Connection(context, localEndpoint.maxWriteQueueSize() != -1 ? localEndpoint.maxWriteQueueSize()
|
||||
: EndpointConfig::DefaultMaxCqSize),
|
||||
transport_(localEndpoint.transport()),
|
||||
remoteTransport_(remoteEndpoint.transport()),
|
||||
dummyAtomicSource_(std::make_unique<uint64_t>(0)) {
|
||||
qp = getImpl(localEndpoint)->ibQp_;
|
||||
qp->rtr(getImpl(remoteEndpoint)->ibQpInfo_);
|
||||
qp->rts();
|
||||
dummyAtomicSourceMem_ = context.registerMemory(dummyAtomicSource_.get(), sizeof(uint64_t), transport_);
|
||||
qp_ = getImpl(localEndpoint)->ibQp_;
|
||||
qp_->rtr(getImpl(remoteEndpoint)->ibQpInfo_);
|
||||
qp_->rts();
|
||||
dummyAtomicSourceMem_ = context->registerMemory(dummyAtomicSource_.get(), sizeof(uint64_t), transport_);
|
||||
validateTransport(dummyAtomicSourceMem_, transport_);
|
||||
dstTransportInfo_ = getImpl(dummyAtomicSourceMem_)->getTransportInfo(transport_);
|
||||
INFO(MSCCLPP_NET, "IB connection via %s created", getIBDeviceName(transport_).c_str());
|
||||
@@ -169,10 +165,10 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem
|
||||
auto dstMrInfo = dstTransportInfo.ibMrInfo;
|
||||
auto srcMr = srcTransportInfo.ibMr;
|
||||
|
||||
qp->stageSend(srcMr, dstMrInfo, (uint32_t)size, /*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset,
|
||||
/*signaled=*/true);
|
||||
qp_->stageSend(srcMr, dstMrInfo, (uint32_t)size, /*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset,
|
||||
/*signaled=*/true);
|
||||
|
||||
qp->postSend();
|
||||
qp_->postSend();
|
||||
INFO(MSCCLPP_NET, "IBConnection write: from %p to %p, size %lu", (uint8_t*)srcMr->getBuff() + srcOffset,
|
||||
(uint8_t*)dstMrInfo.addr + dstOffset, size);
|
||||
|
||||
@@ -197,9 +193,9 @@ void IBConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint6
|
||||
uint64_t oldValue = *src;
|
||||
*src = newValue;
|
||||
|
||||
qp->stageAtomicAdd(dstTransportInfo_.ibMr, dstMrInfo, /*wrId=*/0, dstOffset, newValue - oldValue, /*signaled=*/true);
|
||||
qp_->stageAtomicAdd(dstTransportInfo_.ibMr, dstMrInfo, /*wrId=*/0, dstOffset, newValue - oldValue, /*signaled=*/true);
|
||||
|
||||
qp->postSend();
|
||||
qp_->postSend();
|
||||
INFO(MSCCLPP_NET, "IBConnection atomic Write: from %p to %p, %lu -> %lu", src, (uint8_t*)dstMrInfo.addr + dstOffset,
|
||||
oldValue, newValue);
|
||||
|
||||
@@ -214,20 +210,20 @@ void IBConnection::flush(int64_t timeoutUsec) {
|
||||
#endif
|
||||
|
||||
Timer timer;
|
||||
while (qp->getNumCqItems()) {
|
||||
int wcNum = qp->pollCq();
|
||||
while (qp_->getNumCqItems()) {
|
||||
int wcNum = qp_->pollCq();
|
||||
if (wcNum < 0) {
|
||||
throw mscclpp::IbError("pollCq failed: error no " + std::to_string(errno), errno);
|
||||
} else if (timeoutUsec >= 0) {
|
||||
auto elapsed = timer.elapsed();
|
||||
if (elapsed > timeoutUsec) {
|
||||
throw Error("pollCq timed out: waited for " + std::to_string(elapsed / 1e6) + " seconds. Expected " +
|
||||
std::to_string(qp->getNumCqItems()) + " signals",
|
||||
std::to_string(qp_->getNumCqItems()) + " signals",
|
||||
ErrorCode::Timeout);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < wcNum; ++i) {
|
||||
int status = qp->getWcStatus(i);
|
||||
int status = qp_->getWcStatus(i);
|
||||
if (status != static_cast<int>(WsStatus::Success)) {
|
||||
throw mscclpp::IbError("a work item failed: status " + std::to_string(status), status);
|
||||
}
|
||||
@@ -242,9 +238,9 @@ void IBConnection::flush(int64_t timeoutUsec) {
|
||||
|
||||
// EthernetConnection
|
||||
|
||||
EthernetConnection::EthernetConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, uint64_t sendBufferSize,
|
||||
uint64_t recvBufferSize)
|
||||
: Connection(localEndpoint.maxWriteQueueSize()),
|
||||
EthernetConnection::EthernetConnection(std::shared_ptr<Context> context, Endpoint localEndpoint,
|
||||
Endpoint remoteEndpoint, uint64_t sendBufferSize, uint64_t recvBufferSize)
|
||||
: Connection(context, localEndpoint.maxWriteQueueSize()),
|
||||
abortFlag_(0),
|
||||
sendBufferSize_(sendBufferSize),
|
||||
recvBufferSize_(recvBufferSize) {
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
|
||||
#include "context.hpp"
|
||||
|
||||
#include <mscclpp/env.hpp>
|
||||
|
||||
#include "api.h"
|
||||
#include "connection.hpp"
|
||||
#include "debug.h"
|
||||
@@ -11,9 +13,35 @@
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
CudaIpcStream::CudaIpcStream() : stream_(std::make_shared<CudaStreamWithFlags>()), dirty_(false) {}
|
||||
|
||||
void CudaIpcStream::setStreamIfNeeded() {
|
||||
if (!env()->cudaIpcUseDefaultStream && stream_->empty()) stream_->set(cudaStreamNonBlocking);
|
||||
}
|
||||
|
||||
void CudaIpcStream::memcpyD2D(void *dst, const void *src, size_t nbytes) {
|
||||
setStreamIfNeeded();
|
||||
MSCCLPP_CUDATHROW(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyDeviceToDevice, *stream_));
|
||||
dirty_ = true;
|
||||
}
|
||||
|
||||
void CudaIpcStream::memcpyH2D(void *dst, const void *src, size_t nbytes) {
|
||||
setStreamIfNeeded();
|
||||
MSCCLPP_CUDATHROW(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyHostToDevice, *stream_));
|
||||
dirty_ = true;
|
||||
}
|
||||
|
||||
void CudaIpcStream::sync() {
|
||||
setStreamIfNeeded();
|
||||
if (dirty_) {
|
||||
MSCCLPP_CUDATHROW(cudaStreamSynchronize(*stream_));
|
||||
dirty_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
Context::Impl::Impl() {}
|
||||
|
||||
IbCtx* Context::Impl::getIbContext(Transport ibTransport) {
|
||||
IbCtx *Context::Impl::getIbContext(Transport ibTransport) {
|
||||
// Find IB context or create it
|
||||
auto it = ibContexts_.find(ibTransport);
|
||||
if (it == ibContexts_.end()) {
|
||||
@@ -29,7 +57,7 @@ MSCCLPP_API_CPP Context::Context() : pimpl_(std::make_unique<Impl>()) {}
|
||||
|
||||
MSCCLPP_API_CPP Context::~Context() = default;
|
||||
|
||||
MSCCLPP_API_CPP RegisteredMemory Context::registerMemory(void* ptr, size_t size, TransportFlags transports) {
|
||||
MSCCLPP_API_CPP RegisteredMemory Context::registerMemory(void *ptr, size_t size, TransportFlags transports) {
|
||||
return RegisteredMemory(std::make_shared<RegisteredMemory::Impl>(ptr, size, transports, *pimpl_));
|
||||
}
|
||||
|
||||
@@ -43,24 +71,25 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Context::connect(Endpoint localEndpo
|
||||
if (remoteEndpoint.transport() != Transport::CudaIpc) {
|
||||
throw mscclpp::Error("Local transport is CudaIpc but remote is not", ErrorCode::InvalidUsage);
|
||||
}
|
||||
#if defined(__HIP_PLATFORM_AMD__)
|
||||
pimpl_->ipcStreams_.emplace_back(std::make_shared<CudaStreamWithFlags>());
|
||||
#else
|
||||
#if defined(MSCCLPP_DEVICE_HIP)
|
||||
pimpl_->ipcStreams_.emplace_back(std::make_shared<CudaIpcStream>());
|
||||
#else // !defined(MSCCLPP_DEVICE_HIP)
|
||||
if (pimpl_->ipcStreams_.empty()) {
|
||||
pimpl_->ipcStreams_.emplace_back(std::make_shared<CudaStreamWithFlags>());
|
||||
pimpl_->ipcStreams_.emplace_back(std::make_shared<CudaIpcStream>());
|
||||
}
|
||||
#endif
|
||||
conn = std::make_shared<CudaIpcConnection>(localEndpoint, remoteEndpoint, pimpl_->ipcStreams_.back());
|
||||
#endif // !defined(MSCCLPP_DEVICE_HIP)
|
||||
conn = std::make_shared<CudaIpcConnection>(shared_from_this(), localEndpoint, remoteEndpoint,
|
||||
pimpl_->ipcStreams_.back());
|
||||
} else if (AllIBTransports.has(localEndpoint.transport())) {
|
||||
if (!AllIBTransports.has(remoteEndpoint.transport())) {
|
||||
throw mscclpp::Error("Local transport is IB but remote is not", ErrorCode::InvalidUsage);
|
||||
}
|
||||
conn = std::make_shared<IBConnection>(localEndpoint, remoteEndpoint, *this);
|
||||
conn = std::make_shared<IBConnection>(shared_from_this(), localEndpoint, remoteEndpoint);
|
||||
} else if (localEndpoint.transport() == Transport::Ethernet) {
|
||||
if (remoteEndpoint.transport() != Transport::Ethernet) {
|
||||
throw mscclpp::Error("Local transport is Ethernet but remote is not", ErrorCode::InvalidUsage);
|
||||
}
|
||||
conn = std::make_shared<EthernetConnection>(localEndpoint, remoteEndpoint);
|
||||
conn = std::make_shared<EthernetConnection>(shared_from_this(), localEndpoint, remoteEndpoint);
|
||||
} else {
|
||||
throw mscclpp::Error("Unsupported transport", ErrorCode::InternalError);
|
||||
}
|
||||
|
||||
@@ -67,7 +67,7 @@ Env::Env()
|
||||
enableNcclFallback(readEnv<bool>("MSCCLPP_ENABLE_NCCL_FALLBACK", false)),
|
||||
disableChannelCache(readEnv<bool>("MSCCLPP_DISABLE_CHANNEL_CACHE", false)),
|
||||
forceDisableNvls(readEnv<bool>("MSCCLPP_FORCE_DISABLE_NVLS", false)),
|
||||
fifoUseTailReplica(readEnv<bool>("MSCCLPP_FIFO_USE_TAIL_REPLICA", false)) {}
|
||||
fifoUseTailReplica(readEnv<bool>("MSCCLPP_FIFO_USE_TAIL_REPLICA", true)) {}
|
||||
|
||||
std::shared_ptr<Env> env() {
|
||||
static std::shared_ptr<Env> globalEnv = std::shared_ptr<Env>(new Env());
|
||||
|
||||
57
src/fifo.cc
57
src/fifo.cc
@@ -4,6 +4,7 @@
|
||||
#include <mscclpp/env.hpp>
|
||||
#include <mscclpp/fifo.hpp>
|
||||
#include <mscclpp/gpu_utils.hpp>
|
||||
#include <mscclpp/numa.hpp>
|
||||
|
||||
#include "api.h"
|
||||
#include "atomic.hpp"
|
||||
@@ -13,31 +14,40 @@ namespace mscclpp {
|
||||
struct Fifo::Impl {
|
||||
detail::UniqueGpuHostPtr<ProxyTrigger> triggers;
|
||||
detail::UniqueGpuPtr<uint64_t> head;
|
||||
std::shared_ptr<uint64_t> tailHost;
|
||||
detail::UniqueGpuPtr<uint64_t> tailReplica;
|
||||
const int size;
|
||||
|
||||
// The original tail of this fifo allocated on the host. If a tail replica is used
|
||||
// (when `env()->fifoUseTailReplica == true`), it always holds that *tailReplica <= *hostTail.
|
||||
std::shared_ptr<uint64_t> hostTail;
|
||||
|
||||
// for transferring fifo tail
|
||||
CudaStreamWithFlags stream;
|
||||
|
||||
Impl(int size)
|
||||
: triggers(detail::gpuCallocHostUnique<ProxyTrigger>(size)),
|
||||
head(detail::gpuCallocUnique<uint64_t>()),
|
||||
tailHost(env()->fifoUseTailReplica ? std::make_shared<uint64_t>(0) : detail::gpuCallocHostShared<uint64_t>()),
|
||||
tailReplica(env()->fifoUseTailReplica ? detail::gpuCallocUnique<uint64_t>() : nullptr),
|
||||
size(size),
|
||||
hostTail(env()->fifoUseTailReplica ? std::make_shared<uint64_t>(0) : detail::gpuCallocHostShared<uint64_t>()),
|
||||
stream(cudaStreamNonBlocking) {}
|
||||
size(size) {
|
||||
if (env()->fifoUseTailReplica) {
|
||||
stream.set(cudaStreamNonBlocking);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
MSCCLPP_API_CPP Fifo::Fifo(int size) : pimpl(std::make_unique<Impl>(size)) {}
|
||||
MSCCLPP_API_CPP Fifo::Fifo(int size) {
|
||||
int device;
|
||||
MSCCLPP_CUDATHROW(cudaGetDevice(&device));
|
||||
int numaNode = getDeviceNumaNode(device);
|
||||
if (numaNode >= 0) {
|
||||
numaBind(numaNode);
|
||||
}
|
||||
pimpl_ = std::make_unique<Impl>(size);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Fifo::~Fifo() = default;
|
||||
|
||||
MSCCLPP_API_CPP ProxyTrigger Fifo::poll() {
|
||||
ProxyTrigger trigger;
|
||||
ProxyTrigger* ptr = &pimpl->triggers.get()[*(pimpl->hostTail) % pimpl->size];
|
||||
ProxyTrigger* ptr = &pimpl_->triggers.get()[*(pimpl_->tailHost) % pimpl_->size];
|
||||
// we are loading fst first. if fst is non-zero then snd is also valid
|
||||
trigger.fst = atomicLoad(&(ptr->fst), memoryOrderAcquire);
|
||||
trigger.snd = ptr->snd;
|
||||
@@ -45,39 +55,34 @@ MSCCLPP_API_CPP ProxyTrigger Fifo::poll() {
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Fifo::pop() {
|
||||
uint64_t curTail = *(pimpl->hostTail);
|
||||
atomicStore(&(pimpl->triggers.get()[curTail % pimpl->size].fst), uint64_t{0}, memoryOrderRelease);
|
||||
*(pimpl->hostTail) = curTail + 1;
|
||||
uint64_t curTail = *(pimpl_->tailHost);
|
||||
pimpl_->triggers.get()[curTail % pimpl_->size].fst = 0;
|
||||
*(pimpl_->tailHost) = curTail + 1;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Fifo::flushTail(bool sync) {
|
||||
MSCCLPP_API_CPP void Fifo::flushTail([[maybe_unused]] bool sync) {
|
||||
if (!env()->fifoUseTailReplica) {
|
||||
// Nothing to flush if the tail is not replicated.
|
||||
return;
|
||||
}
|
||||
#if defined(MSCCLPP_DEVICE_HIP)
|
||||
*(pimpl->tailReplica.get()) = *(pimpl->hostTail.get());
|
||||
#else // !defined(MSCCLPP_DEVICE_HIP)
|
||||
// Flush the tail to device memory. This is either triggered every ProxyFlushPeriod to make sure that the fifo can
|
||||
// make progress even if there is no request mscclppSync. However, mscclppSync type is for flush request.
|
||||
AvoidCudaGraphCaptureGuard cgcGuard;
|
||||
MSCCLPP_CUDATHROW(cudaMemcpyAsync(pimpl->tailReplica.get(), pimpl->hostTail.get(), sizeof(uint64_t),
|
||||
cudaMemcpyHostToDevice, pimpl->stream));
|
||||
MSCCLPP_CUDATHROW(cudaMemcpyAsync(pimpl_->tailReplica.get(), pimpl_->tailHost.get(), sizeof(uint64_t),
|
||||
cudaMemcpyHostToDevice, pimpl_->stream));
|
||||
if (sync) {
|
||||
MSCCLPP_CUDATHROW(cudaStreamSynchronize(pimpl->stream));
|
||||
MSCCLPP_CUDATHROW(cudaStreamSynchronize(pimpl_->stream));
|
||||
}
|
||||
#endif // !defined(MSCCLPP_DEVICE_HIP)
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP int Fifo::size() const { return pimpl->size; }
|
||||
MSCCLPP_API_CPP int Fifo::size() const { return pimpl_->size; }
|
||||
|
||||
MSCCLPP_API_CPP FifoDeviceHandle Fifo::deviceHandle() const {
|
||||
FifoDeviceHandle deviceHandle;
|
||||
deviceHandle.triggers = pimpl->triggers.get();
|
||||
deviceHandle.head = pimpl->head.get();
|
||||
deviceHandle.triggers = pimpl_->triggers.get();
|
||||
deviceHandle.head = pimpl_->head.get();
|
||||
// tailReplica refers to the original tail if `fifoUseTailReplica == false`.
|
||||
deviceHandle.tailReplica = env()->fifoUseTailReplica ? pimpl->tailReplica.get() : pimpl->hostTail.get();
|
||||
deviceHandle.size = pimpl->size;
|
||||
deviceHandle.tailReplica = env()->fifoUseTailReplica ? pimpl_->tailReplica.get() : pimpl_->tailHost.get();
|
||||
deviceHandle.size = pimpl_->size;
|
||||
return deviceHandle;
|
||||
}
|
||||
|
||||
|
||||
@@ -83,10 +83,10 @@ void* gpuCalloc(size_t bytes) {
|
||||
return ptr;
|
||||
}
|
||||
|
||||
void* gpuCallocHost(size_t bytes) {
|
||||
void* gpuCallocHost(size_t bytes, unsigned int flags) {
|
||||
AvoidCudaGraphCaptureGuard cgcGuard;
|
||||
void* ptr;
|
||||
MSCCLPP_CUDATHROW(cudaHostAlloc(&ptr, bytes, cudaHostAllocMapped | cudaHostAllocWriteCombined));
|
||||
MSCCLPP_CUDATHROW(cudaHostAlloc(&ptr, bytes, flags));
|
||||
::memset(ptr, 0, bytes);
|
||||
return ptr;
|
||||
}
|
||||
|
||||
@@ -16,10 +16,12 @@
|
||||
namespace mscclpp {
|
||||
|
||||
class CudaIpcConnection : public Connection {
|
||||
std::shared_ptr<CudaStreamWithFlags> stream_;
|
||||
private:
|
||||
std::shared_ptr<CudaIpcStream> stream_;
|
||||
|
||||
public:
|
||||
CudaIpcConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, std::shared_ptr<CudaStreamWithFlags> stream);
|
||||
CudaIpcConnection(std::shared_ptr<Context> context, Endpoint localEndpoint, Endpoint remoteEndpoint,
|
||||
std::shared_ptr<CudaIpcStream> stream);
|
||||
|
||||
Transport transport() const override;
|
||||
|
||||
@@ -33,15 +35,16 @@ class CudaIpcConnection : public Connection {
|
||||
};
|
||||
|
||||
class IBConnection : public Connection {
|
||||
private:
|
||||
Transport transport_;
|
||||
Transport remoteTransport_;
|
||||
IbQp* qp;
|
||||
IbQp* qp_;
|
||||
std::unique_ptr<uint64_t> dummyAtomicSource_; // not used anywhere but IB needs a source
|
||||
RegisteredMemory dummyAtomicSourceMem_;
|
||||
mscclpp::TransportInfo dstTransportInfo_;
|
||||
|
||||
public:
|
||||
IBConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, Context& context);
|
||||
IBConnection(std::shared_ptr<Context> context, Endpoint localEndpoint, Endpoint remoteEndpoint);
|
||||
|
||||
Transport transport() const override;
|
||||
|
||||
@@ -55,6 +58,7 @@ class IBConnection : public Connection {
|
||||
};
|
||||
|
||||
class EthernetConnection : public Connection {
|
||||
private:
|
||||
std::unique_ptr<Socket> sendSocket_;
|
||||
std::unique_ptr<Socket> recvSocket_;
|
||||
std::thread threadRecvMessages_;
|
||||
@@ -64,9 +68,12 @@ class EthernetConnection : public Connection {
|
||||
std::vector<char> sendBuffer_;
|
||||
std::vector<char> recvBuffer_;
|
||||
|
||||
void recvMessages();
|
||||
void sendMessage();
|
||||
|
||||
public:
|
||||
EthernetConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, uint64_t sendBufferSize = 256 * 1024 * 1024,
|
||||
uint64_t recvBufferSize = 256 * 1024 * 1024);
|
||||
EthernetConnection(std::shared_ptr<Context> context, Endpoint localEndpoint, Endpoint remoteEndpoint,
|
||||
uint64_t sendBufferSize = 256 * 1024 * 1024, uint64_t recvBufferSize = 256 * 1024 * 1024);
|
||||
|
||||
~EthernetConnection();
|
||||
|
||||
@@ -79,11 +86,6 @@ class EthernetConnection : public Connection {
|
||||
void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) override;
|
||||
|
||||
void flush(int64_t timeoutUsec) override;
|
||||
|
||||
private:
|
||||
void recvMessages();
|
||||
|
||||
void sendMessage();
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -13,15 +13,34 @@
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
class CudaIpcStream {
|
||||
private:
|
||||
std::shared_ptr<CudaStreamWithFlags> stream_;
|
||||
bool dirty_;
|
||||
|
||||
void setStreamIfNeeded();
|
||||
|
||||
public:
|
||||
CudaIpcStream();
|
||||
|
||||
void memcpyD2D(void *dst, const void *src, size_t nbytes);
|
||||
|
||||
void memcpyH2D(void *dst, const void *src, size_t nbytes);
|
||||
|
||||
void sync();
|
||||
|
||||
operator cudaStream_t() const { return *stream_; }
|
||||
};
|
||||
|
||||
struct Context::Impl {
|
||||
std::vector<std::shared_ptr<Connection>> connections_;
|
||||
std::unordered_map<Transport, std::unique_ptr<IbCtx>> ibContexts_;
|
||||
std::vector<std::shared_ptr<CudaStreamWithFlags>> ipcStreams_;
|
||||
std::vector<std::shared_ptr<CudaIpcStream>> ipcStreams_;
|
||||
CUmemGenericAllocationHandle mcHandle_;
|
||||
|
||||
Impl();
|
||||
|
||||
IbCtx* getIbContext(Transport ibTransport);
|
||||
IbCtx *getIbContext(Transport ibTransport);
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -18,12 +18,19 @@ MSCCLPP_API_CPP PortChannel::PortChannel(SemaphoreId semaphoreId, std::shared_pt
|
||||
std::shared_ptr<Proxy> proxy, MemoryId dst, MemoryId src)
|
||||
: BasePortChannel(semaphoreId, semaphore, proxy), dst_(dst), src_(src) {}
|
||||
|
||||
MSCCLPP_API_CPP ProxyService::ProxyService(size_t fifoSize)
|
||||
: proxy_(std::make_shared<Proxy>([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); },
|
||||
[&]() { bindThread(); }, fifoSize)) {
|
||||
MSCCLPP_API_CPP ProxyService::ProxyService(int fifoSize) {
|
||||
int cudaDevice;
|
||||
MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice));
|
||||
deviceNumaNode = getDeviceNumaNode(cudaDevice);
|
||||
int deviceNumaNode = getDeviceNumaNode(cudaDevice);
|
||||
auto initFunc = [cudaDevice, deviceNumaNode]() {
|
||||
MSCCLPP_CUDATHROW(cudaSetDevice(cudaDevice));
|
||||
if (deviceNumaNode >= 0) {
|
||||
numaBind(deviceNumaNode);
|
||||
INFO(MSCCLPP_INIT, "NUMA node of ProxyService proxy thread is set to %d", deviceNumaNode);
|
||||
}
|
||||
};
|
||||
auto handlerFunc = [&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); };
|
||||
proxy_ = std::make_shared<Proxy>(handlerFunc, initFunc, fifoSize);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP SemaphoreId ProxyService::buildAndAddSemaphore(Communicator& communicator,
|
||||
@@ -58,49 +65,44 @@ MSCCLPP_API_CPP void ProxyService::startProxy() { proxy_->start(); }
|
||||
|
||||
MSCCLPP_API_CPP void ProxyService::stopProxy() { proxy_->stop(); }
|
||||
|
||||
MSCCLPP_API_CPP void ProxyService::bindThread() {
|
||||
if (deviceNumaNode >= 0) {
|
||||
numaBind(deviceNumaNode);
|
||||
INFO(MSCCLPP_INIT, "NUMA node of ProxyService proxy thread is set to %d", deviceNumaNode);
|
||||
}
|
||||
}
|
||||
|
||||
ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) {
|
||||
ChannelTrigger* trigger = reinterpret_cast<ChannelTrigger*>(&triggerRaw);
|
||||
std::shared_ptr<Host2DeviceSemaphore> semaphore = semaphores_[trigger->fields.semaphoreId];
|
||||
|
||||
auto result = ProxyHandlerResult::Continue;
|
||||
int maxWriteQueueSize = semaphore->connection()->getMaxWriteQueueSize();
|
||||
auto& numRequests = inflightRequests_[semaphore->connection()];
|
||||
|
||||
if (trigger->fields.type & TriggerData) {
|
||||
RegisteredMemory& dst = memories_[trigger->fields.dstMemoryId];
|
||||
RegisteredMemory& src = memories_[trigger->fields.srcMemoryId];
|
||||
semaphore->connection()->write(dst, trigger->fields.dstOffset, src, trigger->fields.srcOffset,
|
||||
trigger->fields.size);
|
||||
inflightRequests[semaphore->connection()]++;
|
||||
numRequests++;
|
||||
}
|
||||
|
||||
if (trigger->fields.type & TriggerFlag) {
|
||||
semaphore->signal();
|
||||
inflightRequests[semaphore->connection()]++;
|
||||
numRequests++;
|
||||
}
|
||||
|
||||
if (trigger->fields.type & TriggerSync ||
|
||||
(maxWriteQueueSize != -1 && inflightRequests[semaphore->connection()] > maxWriteQueueSize)) {
|
||||
if (((trigger->fields.type & TriggerSync) && numRequests > 0) ||
|
||||
(maxWriteQueueSize != -1 && numRequests > maxWriteQueueSize)) {
|
||||
semaphore->connection()->flush();
|
||||
result = ProxyHandlerResult::FlushFifoTailAndContinue;
|
||||
inflightRequests[semaphore->connection()] = 0;
|
||||
numRequests = 0;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP BasePortChannel::DeviceHandle BasePortChannel::deviceHandle() const {
|
||||
return BasePortChannel::DeviceHandle(semaphoreId_, semaphore_->deviceHandle(), proxy_->fifo().deviceHandle());
|
||||
return BasePortChannel::DeviceHandle(semaphoreId_, semaphore_->deviceHandle(), proxy_->fifo()->deviceHandle());
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP PortChannel::DeviceHandle PortChannel::deviceHandle() const {
|
||||
return PortChannel::DeviceHandle(semaphoreId_, semaphore_->deviceHandle(), proxy_->fifo().deviceHandle(), dst_, src_);
|
||||
return PortChannel::DeviceHandle(semaphoreId_, semaphore_->deviceHandle(), proxy_->fifo()->deviceHandle(), dst_,
|
||||
src_);
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
78
src/proxy.cc
78
src/proxy.cc
@@ -4,6 +4,7 @@
|
||||
#include <atomic>
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/gpu_utils.hpp>
|
||||
#include <mscclpp/numa.hpp>
|
||||
#include <mscclpp/proxy.hpp>
|
||||
#include <mscclpp/utils.hpp>
|
||||
#include <thread>
|
||||
@@ -12,53 +13,61 @@
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
const int ProxyStopCheckPeriod = 1000;
|
||||
constexpr int ProxyStopCheckPeriod = 1000;
|
||||
|
||||
// Unless explicitly requested, a flush of the tail to device memory is triggered for every ProxyFlushPeriod.
|
||||
// As long as the FIFO size is large enough, having a stale tail is not a problem.
|
||||
const int ProxyFlushPeriod = 4;
|
||||
constexpr int ProxyFlushPeriod = 4;
|
||||
|
||||
struct Proxy::Impl {
|
||||
ProxyHandler handler;
|
||||
std::function<void()> threadInit;
|
||||
Fifo fifo;
|
||||
std::shared_ptr<Fifo> fifo;
|
||||
std::thread service;
|
||||
std::atomic_bool running;
|
||||
|
||||
Impl(ProxyHandler handler, std::function<void()> threadInit, size_t fifoSize)
|
||||
: handler(handler), threadInit(threadInit), fifo(fifoSize), running(false) {}
|
||||
Impl(ProxyHandler handler, std::function<void()> threadInit, int fifoSize)
|
||||
: handler(handler), threadInit(threadInit), fifo(std::make_shared<Fifo>(fifoSize)), running(false) {}
|
||||
};
|
||||
|
||||
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler, std::function<void()> threadInit, size_t fifoSize) {
|
||||
pimpl = std::make_unique<Impl>(handler, threadInit, fifoSize);
|
||||
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler, std::function<void()> threadInit, int fifoSize) {
|
||||
pimpl_ = std::make_unique<Impl>(handler, threadInit, fifoSize);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler, size_t fifoSize)
|
||||
: Proxy(
|
||||
handler, [] {}, fifoSize) {}
|
||||
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler, int fifoSize) {
|
||||
int cudaDevice;
|
||||
MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice));
|
||||
int deviceNumaNode = getDeviceNumaNode(cudaDevice);
|
||||
auto initFunc = [cudaDevice, deviceNumaNode]() {
|
||||
MSCCLPP_CUDATHROW(cudaSetDevice(cudaDevice));
|
||||
if (deviceNumaNode >= 0) {
|
||||
numaBind(deviceNumaNode);
|
||||
}
|
||||
};
|
||||
pimpl_ = std::make_unique<Impl>(handler, initFunc, fifoSize);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Proxy::~Proxy() {
|
||||
if (pimpl) {
|
||||
if (pimpl_) {
|
||||
stop();
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Proxy::start() {
|
||||
int cudaDevice;
|
||||
MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice));
|
||||
pimpl_->running = true;
|
||||
pimpl_->service = std::thread([this] {
|
||||
// never capture in a proxy thread
|
||||
auto mode = cudaStreamCaptureModeRelaxed;
|
||||
MSCCLPP_CUDATHROW(cudaThreadExchangeStreamCaptureMode(&mode));
|
||||
|
||||
pimpl->running = true;
|
||||
pimpl->service = std::thread([this, cudaDevice] {
|
||||
MSCCLPP_CUDATHROW(cudaSetDevice(cudaDevice));
|
||||
pimpl_->threadInit();
|
||||
|
||||
pimpl->threadInit();
|
||||
|
||||
ProxyHandler handler = this->pimpl->handler;
|
||||
Fifo& fifo = this->pimpl->fifo;
|
||||
std::atomic_bool& running = this->pimpl->running;
|
||||
ProxyHandler handler = this->pimpl_->handler;
|
||||
auto fifo = this->pimpl_->fifo;
|
||||
std::atomic_bool& running = this->pimpl_->running;
|
||||
ProxyTrigger trigger;
|
||||
|
||||
int flushPeriod = std::min(fifo.size(), ProxyFlushPeriod);
|
||||
int flushPeriod = std::min(fifo->size(), ProxyFlushPeriod);
|
||||
|
||||
int runCnt = ProxyStopCheckPeriod;
|
||||
uint64_t flushCnt = 0;
|
||||
@@ -70,21 +79,20 @@ MSCCLPP_API_CPP void Proxy::start() {
|
||||
}
|
||||
}
|
||||
// Poll to see if we are ready to send anything
|
||||
trigger = fifo.poll();
|
||||
trigger = fifo->poll();
|
||||
if (trigger.fst == 0 || trigger.snd == 0) { // TODO: this check is a potential pitfall for custom triggers
|
||||
continue; // there is one in progress
|
||||
}
|
||||
trigger.snd ^= ((uint64_t)1 << (uint64_t)63); // this is where the last bit of snd is reverted.
|
||||
trigger.snd ^= (uint64_t{1} << uint64_t{63}); // this is where the last bit of snd is reverted.
|
||||
|
||||
ProxyHandlerResult result = handler(trigger);
|
||||
|
||||
// Send completion: reset only the high 64 bits
|
||||
fifo.pop();
|
||||
fifo->pop();
|
||||
// Flush the tail to device memory. This is either triggered every flushPeriod to make sure that the fifo can make
|
||||
// progress even if there is no request mscclppSync. However, mscclppSync type is for flush request.
|
||||
if ((++flushCnt % flushPeriod) == 0 || result == ProxyHandlerResult::FlushFifoTailAndContinue) {
|
||||
// TODO: relocate this check: || (trigger.fields.type & mscclppSync)
|
||||
fifo.flushTail();
|
||||
fifo->flushTail();
|
||||
}
|
||||
|
||||
if (result == ProxyHandlerResult::Stop) {
|
||||
@@ -93,23 +101,17 @@ MSCCLPP_API_CPP void Proxy::start() {
|
||||
}
|
||||
|
||||
// make sure the tail is flushed before we shut the proxy
|
||||
fifo.flushTail(/*sync=*/true);
|
||||
// TODO: do these need to run?
|
||||
// bool isP2pProxy = (proxyState->ibContext == nullptr);
|
||||
// if (isP2pProxy) {
|
||||
// cudaStream_t p2pStream = proxyState->p2pStream;
|
||||
// PROXYCUDACHECK(cudaStreamSynchronize(p2pStream));
|
||||
// }
|
||||
fifo->flushTail(/*sync=*/true);
|
||||
});
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Proxy::stop() {
|
||||
pimpl->running = false;
|
||||
if (pimpl->service.joinable()) {
|
||||
pimpl->service.join();
|
||||
pimpl_->running = false;
|
||||
if (pimpl_->service.joinable()) {
|
||||
pimpl_->service.join();
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Fifo& Proxy::fifo() { return pimpl->fifo; }
|
||||
MSCCLPP_API_CPP std::shared_ptr<Fifo> Proxy::fifo() { return pimpl_->fifo; }
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -29,7 +29,7 @@ static detail::UniqueGpuPtr<uint64_t> createGpuSemaphoreId() {
|
||||
|
||||
MSCCLPP_API_CPP Host2DeviceSemaphore::Host2DeviceSemaphore(Communicator& communicator,
|
||||
std::shared_ptr<Connection> connection)
|
||||
: BaseSemaphore(createGpuSemaphoreId(), createGpuSemaphoreId(), std::make_unique<uint64_t>()),
|
||||
: BaseSemaphore(createGpuSemaphoreId(), createGpuSemaphoreId(), detail::gpuCallocHostUnique<uint64_t>()),
|
||||
connection_(connection) {
|
||||
INFO(MSCCLPP_INIT, "Creating a Host2Device semaphore for %s transport from %d to %d",
|
||||
connection->getTransportName().c_str(), communicator.bootstrap()->getRank(),
|
||||
|
||||
Reference in New Issue
Block a user