working -- at least for single node

Commit: 2691784b88
Parent: 113473a116
Author: Saeed Maleki
Date: 2023-05-12 20:21:58 +00:00

5 changed files with 65 additions and 35 deletions

View File

@@ -18,9 +18,15 @@ Communicator::Impl::Impl(std::shared_ptr<BaseBootstrap> bootstrap) : bootstrap_(
   INFO(MSCCLPP_INIT, "Host hash: %lx", hostHash);
   rankToHash_[bootstrap->getRank()] = hostHash;
   bootstrap->allGather(rankToHash_.data(), sizeof(uint64_t));
+  CUDATHROW(cudaStreamCreateWithFlags(&ipcStream_, cudaStreamNonBlocking));
 }

-Communicator::Impl::~Impl() { ibContexts_.clear(); }
+Communicator::Impl::~Impl() {
+  ibContexts_.clear();
+  cudaStreamDestroy(ipcStream_);
+}

 IbCtx* Communicator::Impl::getIbContext(Transport ibTransport) {
   // Find IB context or create it
@@ -34,6 +40,8 @@ IbCtx* Communicator::Impl::getIbContext(Transport ibTransport) {
   }
 }

+cudaStream_t Communicator::Impl::getIpcStream() { return ipcStream_; }
+
 MSCCLPP_API_CPP Communicator::~Communicator() = default;

 MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr<BaseBootstrap> bootstrap)
@@ -95,7 +103,7 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(int rem
          << pimpl->rankToHash_[pimpl->bootstrap_->getRank()] << ")";
       throw mscclpp::Error(ss.str(), ErrorCode::InvalidUsage);
     }
-    auto cudaIpcConn = std::make_shared<CudaIpcConnection>(remoteRank, tag);
+    auto cudaIpcConn = std::make_shared<CudaIpcConnection>(remoteRank, tag, pimpl->getIpcStream());
     conn = cudaIpcConn;
     INFO(MSCCLPP_P2P, "Cuda IPC connection between rank %d(%lx) and remoteRank %d(%lx) created",
          pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()], remoteRank,
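Note: the net effect of this file's change is an ownership move. The communicator now creates one non-blocking CUDA stream in its Impl constructor, destroys it in the destructor, and hands it to every CUDA IPC connection through getIpcStream(). A minimal standalone sketch of that pattern, where OwnerSketch and ConnectionSketch are hypothetical stand-ins for Communicator::Impl and CudaIpcConnection:

    #include <cuda_runtime.h>

    // OwnerSketch / ConnectionSketch are illustrative names, not the mscclpp
    // API: the owner creates and destroys one non-blocking stream; each
    // connection only borrows it.
    struct OwnerSketch {
      cudaStream_t ipcStream_;
      OwnerSketch() { cudaStreamCreateWithFlags(&ipcStream_, cudaStreamNonBlocking); }
      ~OwnerSketch() { cudaStreamDestroy(ipcStream_); }
      cudaStream_t getIpcStream() { return ipcStream_; }
    };

    struct ConnectionSketch {
      cudaStream_t stream_;  // borrowed; never destroyed here
      explicit ConnectionSketch(cudaStream_t stream) : stream_(stream) {}
    };

    int main() {
      OwnerSketch owner;                          // creates the shared stream
      ConnectionSketch c1(owner.getIpcStream());  // both connections share it
      ConnectionSketch c2(owner.getIpcStream());
      return 0;                                   // owner destroys the stream once
    }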

View File

@@ -30,11 +30,10 @@ int ConnectionBase::tag() { return tag_; }

 // CudaIpcConnection

-CudaIpcConnection::CudaIpcConnection(int remoteRank, int tag) : ConnectionBase(remoteRank, tag) {
-  CUDATHROW(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
-}
+CudaIpcConnection::CudaIpcConnection(int remoteRank, int tag, cudaStream_t stream)
+  : ConnectionBase(remoteRank, tag), stream_(stream) {}

-CudaIpcConnection::~CudaIpcConnection() { cudaStreamDestroy(stream); }
+CudaIpcConnection::~CudaIpcConnection() {}

 Transport CudaIpcConnection::transport() { return Transport::CudaIpc; }
@@ -48,14 +47,14 @@ void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, Register
   char* dstPtr = (char*)dst.data();
   char* srcPtr = (char*)src.data();

-  CUDATHROW(cudaMemcpyAsync(dstPtr + dstOffset, srcPtr + srcOffset, size, cudaMemcpyDeviceToDevice, stream));
+  CUDATHROW(cudaMemcpyAsync(dstPtr + dstOffset, srcPtr + srcOffset, size, cudaMemcpyDeviceToDevice, stream_));
   INFO(MSCCLPP_P2P, "CudaIpcConnection write: from %p to %p, size %lu", srcPtr + srcOffset, dstPtr + dstOffset, size);

   // npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)size);
 }

 void CudaIpcConnection::flush() {
-  CUDATHROW(cudaStreamSynchronize(stream));
+  CUDATHROW(cudaStreamSynchronize(stream_));
   // npkitCollectExitEvents(conn, NPKIT_EVENT_DMA_SEND_EXIT);
 }
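With the shared stream, work enqueued by different CudaIpcConnection objects executes in FIFO order on that one stream, so a single cudaStreamSynchronize acts as a flush for all outstanding writes. A small self-contained sketch of that stream semantics (illustrative buffers only, error checking omitted):

    #include <cuda_runtime.h>

    // Operations on one stream run in FIFO order: the copy below is
    // guaranteed to see the memset, and one synchronize completes both.
    int main() {
      cudaStream_t s;
      cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking);
      char *a, *b;
      cudaMalloc(&a, 1024);
      cudaMalloc(&b, 1024);
      cudaMemsetAsync(a, 1, 1024, s);                            // first "write"
      cudaMemcpyAsync(b, a, 1024, cudaMemcpyDeviceToDevice, s);  // ordered after it
      cudaStreamSynchronize(s);                                  // flush(): both done
      cudaFree(a);
      cudaFree(b);
      cudaStreamDestroy(s);
      return 0;
    }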

View File

@@ -1,6 +1,8 @@
 #ifndef MSCCL_COMMUNICATOR_HPP_
 #define MSCCL_COMMUNICATOR_HPP_

+#include <cuda_runtime.h>
+
 #include <memory>
 #include <mscclpp/core.hpp>
 #include <mscclpp/proxy.hpp>
@@ -17,6 +19,7 @@ struct Communicator::Impl {
   std::vector<std::shared_ptr<ConnectionBase>> connections_;
   std::vector<std::shared_ptr<Setuppable>> toSetup_;
   std::unordered_map<Transport, std::unique_ptr<IbCtx>> ibContexts_;
+  cudaStream_t ipcStream_;
   std::shared_ptr<BaseBootstrap> bootstrap_;
   std::vector<uint64_t> rankToHash_;
@@ -25,6 +28,7 @@ struct Communicator::Impl {
   ~Impl();

   IbCtx* getIbContext(Transport ibTransport);
+  cudaStream_t getIpcStream();
 };

 } // namespace mscclpp

View File

@@ -27,10 +27,10 @@ class ConnectionBase : public Connection, public Setuppable {
 };

 class CudaIpcConnection : public ConnectionBase {
-  cudaStream_t stream;
+  cudaStream_t stream_;

 public:
-  CudaIpcConnection(int remoteRank, int tag);
+  CudaIpcConnection(int remoteRank, int tag, cudaStream_t stream);

   ~CudaIpcConnection();

View File

@@ -51,18 +51,22 @@ static double getTime(void)
 }

-__global__ void kernel(int r, int nranks, mscclpp::DeviceProxyFifo fifo, mscclpp::DeviceEpoch::DeviceHandle* handles)
+__global__ void kernel(int r, int nranks, mscclpp::DeviceProxyFifo fifo, mscclpp::DeviceEpoch::DeviceHandle* handles, int handleIndex)
 {
   int tid = threadIdx.x;
   if (tid != r)
     handles[tid].epochIncrement();
   __syncthreads();
+  uint64_t tail;
   if (tid == 0){
     mscclpp::ProxyTrigger trigger;
-    trigger.fst = 1;
-    fifo.push(trigger);
+    trigger.fst = handleIndex;
+    tail = fifo.push(trigger);
   }
   if (tid != r)
     handles[tid].wait();
+  // if (tid == 0)
+  //   while(*(volatile uint64_t*)fifo.tailReplica < tail) {};
 }

 int rankToLocalRank(int rank)
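The kernel now encodes which epoch set to use in trigger.fst (handleIndex is 1 or 2) and keeps the tail returned by fifo.push, so the commented-out loop could spin until the proxy's tail replica catches up. A hedged host-side sketch of that tail-replica handshake, assuming push returns the tail after the new entry and the consumer publishes its progress; FifoSketch is hypothetical, not the mscclpp API:

    #include <atomic>
    #include <cstdint>
    #include <thread>

    struct FifoSketch {
      uint64_t tail = 0;                     // producer-side tail counter
      std::atomic<uint64_t> tailReplica{0};  // consumer publishes progress here
      uint64_t push() { return ++tail; }     // returns the tail after this push
      void waitConsumed(uint64_t t) {        // spin until the entry is consumed
        while (tailReplica.load(std::memory_order_acquire) < t) {}
      }
    };

    int main() {
      FifoSketch fifo;
      uint64_t t = fifo.push();              // like: tail = fifo.push(trigger)
      std::thread consumer([&] {             // stand-in for the host proxy
        fifo.tailReplica.store(t, std::memory_order_release);
      });
      fifo.waitConsumed(t);                  // the commented-out device-side spin
      consumer.join();
      return 0;
    }
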
@@ -121,15 +125,15 @@ public:
   }

   mscclpp::ProxyHandlerResult handleTrigger(mscclpp::ProxyTrigger triggerRaw) {
-    if (triggerRaw.fst == 1) {
+    if (triggerRaw.fst > 0) {
       int dataSizePerRank = dataSize / world_size;
-      for (int r = 0; r < world_size; ++r) {
-        if (r == rank) {
-          continue;
-        }
-        connections[r]->write(remoteMemories[r], rank*dataSizePerRank, localMemory, rank*dataSizePerRank, dataSizePerRank);
-        deviceEpochs[r]->signal();
-        connections[r]->flush();
+      for (int r = 1; r < world_size; ++r) {
+        int nghr = (rank + r) % world_size;
+        connections[nghr]->write(remoteMemories[nghr], rank*dataSizePerRank, localMemory, rank*dataSizePerRank, dataSizePerRank);
+        if (triggerRaw.fst == 1)
+          deviceEpochs1[nghr]->signal();
+        else
+          deviceEpochs2[nghr]->signal();
       }
     }
     return mscclpp::ProxyHandlerResult::FlushFifoTailAndContinue;
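Two details of the rewritten handler: the loop now walks neighbors as (rank + r) % world_size, which skips self without a branch and staggers each rank's write targets, and triggerRaw.fst selects between the two device-epoch sets. The second point is classic double buffering; a tiny runnable sketch of the alternation, with illustrative names only:

    #include <cstdio>

    // Minimal double-buffering sketch (an analogy, not the mscclpp API): two
    // resource sets alternate between consecutive operations, so operation
    // i+1 never reuses synchronization objects still in flight for operation i.
    int main() {
      const char* epochSet[2] = {"deviceEpochs1", "deviceEpochs2"};
      for (int i = 0; i < 4; ++i) {
        int handleIndex = (i % 2) + 1;  // 1, 2, 1, 2, ...
        printf("op %d -> trigger.fst = %d, signals %s\n", i, handleIndex,
               epochSet[handleIndex - 1]);
      }
      return 0;
    }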
@@ -138,7 +142,8 @@ public:
   std::vector<mscclpp::RegisteredMemory> remoteMemories;
   mscclpp::RegisteredMemory localMemory;
   std::vector<std::shared_ptr<mscclpp::HostEpoch>> hostEpochs;
-  std::vector<std::shared_ptr<mscclpp::DeviceEpoch>> deviceEpochs;
+  std::vector<std::shared_ptr<mscclpp::DeviceEpoch>> deviceEpochs1;
+  std::vector<std::shared_ptr<mscclpp::DeviceEpoch>> deviceEpochs2;
   std::vector<std::shared_ptr<mscclpp::Connection>> connections;
   int dataSize;
 };
@@ -156,7 +161,8 @@ void setupProxyService(mscclpp::Communicator& comm, MyProxyService& proxyService
   for (int r = 0; r < world_size; ++r) {
     if (r == rank){
       proxyService.hostEpochs.emplace_back(nullptr);
-      proxyService.deviceEpochs.emplace_back(nullptr);
+      proxyService.deviceEpochs1.emplace_back(nullptr);
+      proxyService.deviceEpochs2.emplace_back(nullptr);
       continue;
     }
     mscclpp::Transport transport;
@@ -172,7 +178,8 @@ void setupProxyService(mscclpp::Communicator& comm, MyProxyService& proxyService
     } else {
       proxyService.hostEpochs.emplace_back(std::make_shared<mscclpp::HostEpoch>(comm, proxyService.connections[r]));
     }
-    proxyService.deviceEpochs.emplace_back(std::make_shared<mscclpp::DeviceEpoch>(comm, proxyService.connections[r]));
+    proxyService.deviceEpochs1.emplace_back(std::make_shared<mscclpp::DeviceEpoch>(comm, proxyService.connections[r]));
+    proxyService.deviceEpochs2.emplace_back(std::make_shared<mscclpp::DeviceEpoch>(comm, proxyService.connections[r]));

     comm.sendMemoryOnSetup(proxyService.localMemory, r, 0);
     remoteMemories[r] = comm.recvMemoryOnSetup(r, 0);
@@ -267,17 +274,26 @@ int main(int argc, char* argv[])
   printf("Testing the correctness of AllGather implementation\n");

   cudaStream_t stream;
   CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));

-  mscclpp::DeviceEpoch::DeviceHandle* deviceHandles;
+  mscclpp::DeviceEpoch::DeviceHandle* deviceHandles1;
+  mscclpp::DeviceEpoch::DeviceHandle* deviceHandles2;

-  CUDACHECK(cudaMalloc(&deviceHandles, sizeof(mscclpp::DeviceEpoch::DeviceHandle) * world_size));
+  CUDACHECK(cudaMalloc(&deviceHandles1, sizeof(mscclpp::DeviceEpoch::DeviceHandle) * world_size));
   for (int i = 0; i < world_size; ++i) {
     if (i == rank)
       continue;
-    auto handle = proxyService.deviceEpochs[i]->deviceHandle();
-    CUDACHECK(cudaMemcpy(&deviceHandles[i], &handle, sizeof(mscclpp::DeviceEpoch::DeviceHandle), cudaMemcpyHostToDevice));
+    auto handle = proxyService.deviceEpochs1[i]->deviceHandle();
+    CUDACHECK(cudaMemcpy(&deviceHandles1[i], &handle, sizeof(mscclpp::DeviceEpoch::DeviceHandle), cudaMemcpyHostToDevice));
   }

-  kernel<<<1, world_size, 0, stream>>>(rank, world_size, fifo, deviceHandles);
+  CUDACHECK(cudaMalloc(&deviceHandles2, sizeof(mscclpp::DeviceEpoch::DeviceHandle) * world_size));
+  for (int i = 0; i < world_size; ++i) {
+    if (i == rank)
+      continue;
+    auto handle = proxyService.deviceEpochs2[i]->deviceHandle();
+    CUDACHECK(cudaMemcpy(&deviceHandles2[i], &handle, sizeof(mscclpp::DeviceEpoch::DeviceHandle), cudaMemcpyHostToDevice));
+  }
+
+  kernel<<<1, world_size, 0, stream>>>(rank, world_size, fifo, deviceHandles1, 1);
   CUDACHECK(cudaStreamSynchronize(stream));

   CUDACHECK(cudaMemcpy(data_h, data_d, dataSize, cudaMemcpyDeviceToHost));
@@ -302,13 +318,14 @@ int main(int argc, char* argv[])
   bootstrap->barrier();
   t0 = getTime();
   for (int i = 0; i < iterwithoutcudagraph; ++i) {
-    kernel<<<1, world_size, 0, stream>>>(rank, world_size, fifo, deviceHandles);
+    kernel<<<1, world_size, 0, stream>>>(rank, world_size, fifo, deviceHandles1, 1);
+    kernel<<<1, world_size, 0, stream>>>(rank, world_size, fifo, deviceHandles2, 2);
   }
   CUDACHECK(cudaStreamSynchronize(stream));
   bootstrap->barrier();
   t1 = getTime();
   ms = (t1 - t0) * 1000.0;
-  time_in_us = ms * 1000. / (float)iterwithoutcudagraph;
+  time_in_us = ms * 1000. / (float)iterwithoutcudagraph / 2;
   printf("No Graph %d report: size %lu time: %f us/iter algBW %f GBps\n", rank, dataSize, time_in_us,
          (double)(dataSize) / 1e9 / (time_in_us / 1e6));
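The extra "/ 2" accounts for the two kernel launches now issued per loop iteration (one per epoch set), so time_in_us is per AllGather rather than per iteration. A worked example of the arithmetic, with illustrative numbers and the assumption that getTime() returns seconds:

    #include <cstdio>

    int main() {
      double t0 = 0.0, t1 = 0.010;          // pretend the loop took 10 ms
      int iters = 100;                      // iterwithoutcudagraph
      double ms = (t1 - t0) * 1000.0;       // 10 ms total
      double us = ms * 1000. / iters / 2;   // two kernel launches per iteration
      printf("%f us per AllGather\n", us);  // -> 50 us
      return 0;
    }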
@@ -320,7 +337,8 @@ int main(int argc, char* argv[])
   cudaGraphExec_t instance;
   cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
   for (int i = 0; i < cudagraphiter; ++i) {
-    kernel<<<1, world_size, 0, stream>>>(rank, world_size, fifo, deviceHandles);
+    kernel<<<1, world_size, 0, stream>>>(rank, world_size, fifo, deviceHandles1, 1);
+    kernel<<<1, world_size, 0, stream>>>(rank, world_size, fifo, deviceHandles2, 2);
   }
   cudaStreamEndCapture(stream, &graph);
   cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
@@ -348,9 +366,10 @@ int main(int argc, char* argv[])
   t1 = getTime();
   ms = (t1 - t0) * 1000.0;
-  time_in_us = ms * 1000. / (float)cudagraphlaunch / (float)cudagraphiter;
-  printf("Rank %d report: size %lu time: %f us/iter algBW %f GBps\n", rank, dataSize, time_in_us,
-         (double)(dataSize) / 1e9 / (time_in_us / 1e6));
+  time_in_us = ms * 1000. / (float)cudagraphlaunch / (float)cudagraphiter / 2;
+  if (rank == 0)
+    printf("Rank %d report: size %lu time: %f us/iter algBW %f GBps\n", rank, dataSize, time_in_us,
+           (double)(dataSize) / 1e9 / (time_in_us / 1e6));

   bootstrap->barrier();
   if (rank == 0)