Add exception class for mscclpp (#67)

Add exception class for mscclpp
This commit is contained in:
Binyang2014
2023-05-06 16:27:25 +08:00
committed by GitHub
parent 669c67b3de
commit 8650dbaff8
22 changed files with 381 additions and 211 deletions

View File

@@ -121,7 +121,7 @@ LDFLAGS := $(NVLDFLAGS) $(GDRCOPY_LDFLAGS) -libverbs -lnuma
LIBSRCS := $(addprefix src/,debug.cc utils.cc init.cc proxy.cc ib.cc config.cc)
LIBSRCS += $(addprefix src/bootstrap/,bootstrap.cc socket.cc)
LIBSRCS += $(addprefix src/,communicator.cc connection.cc registered_memory.cc)
LIBSRCS += $(addprefix src/,epoch.cc proxy_cpp.cc fifo.cc channel.cc)
LIBSRCS += $(addprefix src/,epoch.cc proxy_cpp.cc fifo.cc channel.cc errors.cc)
ifneq ($(NPKIT), 0)
LIBSRCS += $(addprefix src/misc/,npkit.cc)
endif
@@ -135,7 +135,7 @@ HEADERS := $(wildcard src/include/*.h)
CPPSOURCES := $(shell find ./ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)' -not -path "./build/*" -not -path "./python/*")
PYTHONCPPSOURCES := $(shell find ./python/src/ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)')
INCEXPORTS := mscclpp.h mscclppfifo.h mscclpp.hpp mscclppfifo.hpp epoch.hpp
INCEXPORTS := mscclpp.h mscclppfifo.h mscclpp.hpp mscclppfifo.hpp epoch.hpp errors.hpp
INCTARGETS := $(INCEXPORTS:%=$(BUILDDIR)/$(INCDIR)/%)
LIBNAME := libmscclpp.so

View File

@@ -194,13 +194,15 @@ void Bootstrap::Impl::getRemoteAddresses(mscclppSocket* listenSock, std::vector<
MSCCLPPTHROW(mscclppSocketClose(&sock));
if (this->nRanks_ != info.nRanks) {
throw std::runtime_error("Bootstrap Root : mismatch in rank count from procs " + std::to_string(this->nRanks_) +
" : " + std::to_string(info.nRanks));
throw mscclpp::Error("Bootstrap Root : mismatch in rank count from procs " + std::to_string(this->nRanks_) + " : " +
std::to_string(info.nRanks),
mscclppInternalError);
}
if (std::memcmp(&zero, &rankAddressesRoot[info.rank], sizeof(mscclppSocketAddress)) != 0) {
throw std::runtime_error("Bootstrap Root : rank " + std::to_string(info.rank) + " of " +
std::to_string(this->nRanks_) + " has already checked in");
throw mscclpp::Error("Bootstrap Root : rank " + std::to_string(info.rank) + " of " + std::to_string(this->nRanks_) +
" has already checked in",
mscclppInternalError);
}
// Save the connection handle for that rank
@@ -269,16 +271,17 @@ void Bootstrap::Impl::netInit(std::string ipPortPair)
if (!ipPortPair.empty()) {
mscclppSocketAddress remoteAddr;
if (mscclppSocketGetAddrFromString(&remoteAddr, ipPortPair.c_str()) != mscclppSuccess) {
throw std::runtime_error(
"Invalid ipPortPair, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>");
throw mscclpp::Error(
"Invalid ipPortPair, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>",
mscclppInvalidArgument);
}
if (mscclppFindInterfaceMatchSubnet(netIfName_, &netIfAddr_, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
throw std::runtime_error("NET/Socket : No usable listening interface found");
throw mscclpp::Error("NET/Socket : No usable listening interface found", mscclppInternalError);
}
} else {
int ret = mscclppFindInterfaces(netIfName_, &netIfAddr_, MAX_IF_NAME_SIZE, 1);
if (ret <= 0) {
throw std::runtime_error("Bootstrap : no socket interface found");
throw mscclpp::Error("Bootstrap : no socket interface found", mscclppInternalError);
}
}
@@ -390,8 +393,9 @@ void Bootstrap::Impl::netRecv(mscclppSocket* sock, void* data, int size)
int recvSize;
MSCCLPPTHROW(mscclppSocketRecv(sock, &recvSize, sizeof(int)));
if (recvSize > size) {
throw std::runtime_error("Message truncated : received " + std::to_string(recvSize) + " bytes instead of " +
std::to_string(size));
throw mscclpp::Error("Message truncated : received " + std::to_string(recvSize) + " bytes instead of " +
std::to_string(size),
mscclppInternalError);
}
MSCCLPPTHROW(mscclppSocketRecv(sock, data, std::min(recvSize, size)));
}
@@ -1058,4 +1062,4 @@ mscclppResult_t bootstrapAbort(void* commState)
free(state->peerProxyAddresses);
free(state);
return mscclppSuccess;
}
}

View File

@@ -1,14 +1,16 @@
#include "channel.hpp"
#include "utils.h"
#include "checks.hpp"
#include "api.h"
#include "checks.hpp"
#include "debug.h"
#include "utils.h"
namespace mscclpp {
namespace channel {
MSCCLPP_API_CPP DeviceChannelService::DeviceChannelService(Communicator& communicator) : communicator_(communicator),
proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, [&]() { bindThread(); }) {
MSCCLPP_API_CPP DeviceChannelService::DeviceChannelService(Communicator& communicator)
: communicator_(communicator),
proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, [&]() { bindThread(); })
{
int cudaDevice;
CUDATHROW(cudaGetDevice(&cudaDevice));
MSCCLPPTHROW(getDeviceNumaNode(cudaDevice, &deviceNumaNode));
@@ -23,4 +25,4 @@ MSCCLPP_API_CPP void DeviceChannelService::bindThread()
}
} // namespace channel
} // namespace mscclpp
} // namespace mscclpp

View File

@@ -59,8 +59,9 @@ MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t
struct MemorySender : public Setuppable
{
MemorySender(RegisteredMemory memory, int remoteRank, int tag)
: memory_(memory), remoteRank_(remoteRank), tag_(tag) {}
MemorySender(RegisteredMemory memory, int remoteRank, int tag) : memory_(memory), remoteRank_(remoteRank), tag_(tag)
{
}
void beginSetup(std::shared_ptr<BaseBootstrap> bootstrap) override
{
@@ -79,8 +80,9 @@ MSCCLPP_API_CPP void Communicator::sendMemoryOnSetup(RegisteredMemory memory, in
struct MemoryReceiver : public Setuppable
{
MemoryReceiver(int remoteRank, int tag)
: remoteRank_(remoteRank), tag_(tag) {}
MemoryReceiver(int remoteRank, int tag) : remoteRank_(remoteRank), tag_(tag)
{
}
void endSetup(std::shared_ptr<BaseBootstrap> bootstrap) override
{
@@ -112,7 +114,7 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(int rem
<< pimpl->rankToHash_[pimpl->bootstrap_->getRank()] << ")"
<< " != " << pimpl->bootstrap_->getRank() << "(" << std::hex
<< pimpl->rankToHash_[pimpl->bootstrap_->getRank()] << ")";
throw std::runtime_error(ss.str());
throw mscclpp::Error(ss.str(), mscclppInternalError);
}
auto cudaIpcConn = std::make_shared<CudaIpcConnection>(remoteRank, tag);
conn = cudaIpcConn;
@@ -126,7 +128,7 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(int rem
pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()],
getIBDeviceName(transport).c_str(), remoteRank, pimpl->rankToHash_[remoteRank]);
} else {
throw std::runtime_error("Unsupported transport");
throw mscclpp::Error("Unsupported transport", mscclppInvalidArgument);
}
pimpl->connections_.push_back(conn);
addSetup(conn);

View File

@@ -1,17 +1,17 @@
#include <algorithm>
#include "connection.hpp"
#include "checks.hpp"
#include "infiniband/verbs.h"
#include "npkit/npkit.h"
#include "registered_memory.hpp"
#include "utils.hpp"
#include <algorithm>
namespace mscclpp {
void validateTransport(RegisteredMemory mem, Transport transport)
{
if (!mem.transports().has(transport)) {
throw std::runtime_error("mem does not support transport");
throw Error("RegisteredMemory does not support transport", mscclppInvalidArgument);
}
}
@@ -24,11 +24,19 @@ std::shared_ptr<RegisteredMemory::Impl> Connection::getRegisteredMemoryImpl(Regi
// ConnectionBase
ConnectionBase::ConnectionBase(int remoteRank, int tag) : remoteRank_(remoteRank), tag_(tag) {}
ConnectionBase::ConnectionBase(int remoteRank, int tag) : remoteRank_(remoteRank), tag_(tag)
{
}
int ConnectionBase::remoteRank() { return remoteRank_; }
int ConnectionBase::remoteRank()
{
return remoteRank_;
}
int ConnectionBase::tag() { return tag_; }
int ConnectionBase::tag()
{
return tag_;
}
// CudaIpcConnection
@@ -99,11 +107,11 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem
auto dstTransportInfo = getRegisteredMemoryImpl(dst)->getTransportInfo(remoteTransport());
if (dstTransportInfo.ibLocal) {
throw std::runtime_error("dst is local, which is not supported");
throw Error("dst is local, which is not supported", mscclppInvalidArgument);
}
auto srcTransportInfo = getRegisteredMemoryImpl(src)->getTransportInfo(transport());
if (!srcTransportInfo.ibLocal) {
throw std::runtime_error("src is remote, which is not supported");
throw Error("src is remote, which is not supported", mscclppInvalidArgument);
}
auto dstMrInfo = dstTransportInfo.ibMrInfo;
@@ -113,7 +121,8 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem
/*signaled=*/true);
numSignaledSends++;
qp->postSend();
INFO(MSCCLPP_NET, "IBConnection write: from %p to %p, size %lu", (uint8_t*)srcMr->getBuff() + srcOffset, (uint8_t*)dstMrInfo.addr + dstOffset, size);
INFO(MSCCLPP_NET, "IBConnection write: from %p to %p, size %lu", (uint8_t*)srcMr->getBuff() + srcOffset,
(uint8_t*)dstMrInfo.addr + dstOffset, size);
// npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)size);
}
@@ -123,16 +132,19 @@ void IBConnection::flush()
while (numSignaledSends) {
int wcNum = qp->pollCq();
if (wcNum < 0) {
throw std::runtime_error("pollCq failed: error no " + std::to_string(errno));
throw mscclpp::IbError("pollCq failed: error no " + std::to_string(errno), errno);
}
auto elapsed = timer.elapsed();
if (elapsed > MSCCLPP_POLLING_WAIT)
throw std::runtime_error("pollCq is stuck: waited for " + std::to_string(elapsed) + " seconds. Expected " + std::to_string(numSignaledSends) + " signals");
if (elapsed > MSCCLPP_POLLING_WAIT) {
throw Error("pollCq is stuck: waited for " + std::to_string(elapsed) + " seconds. Expected " +
std::to_string(numSignaledSends) + " signals",
mscclppInternalError);
}
for (int i = 0; i < wcNum; ++i) {
const struct ibv_wc* wc = reinterpret_cast<const struct ibv_wc*>(qp->getWc(i));
if (wc->status != IBV_WC_SUCCESS) {
throw std::runtime_error("pollCq failed: status " + std::to_string(wc->status));
throw mscclpp::IbError("pollCq failed: status " + std::to_string(wc->status), wc->status);
}
if (wc->opcode == IBV_WC_RDMA_WRITE) {
numSignaledSends--;

View File

@@ -1,26 +1,32 @@
#include "epoch.hpp"
#include "checks.hpp"
#include "alloc.h"
#include "api.h"
#include "checks.hpp"
namespace mscclpp {
MSCCLPP_API_CPP Epoch::Epoch(Communicator& communicator, std::shared_ptr<Connection> connection) : connection_(connection) {
MSCCLPP_API_CPP Epoch::Epoch(Communicator& communicator, std::shared_ptr<Connection> connection)
: connection_(connection)
{
MSCCLPPTHROW(mscclppCudaCalloc(&device_.epochIds_, 1));
MSCCLPPTHROW(mscclppCudaCalloc(&device_.expectedInboundEpochId_, 1));
localEpochIdsRegMem_ = communicator.registerMemory(device_.epochIds_, sizeof(device_.epochIds_), connection->transport());
localEpochIdsRegMem_ =
communicator.registerMemory(device_.epochIds_, sizeof(device_.epochIds_), connection->transport());
communicator.sendMemoryOnSetup(localEpochIdsRegMem_, connection->remoteRank(), connection->tag());
remoteEpochIdsRegMem_ = communicator.recvMemoryOnSetup(connection->remoteRank(), connection->tag());
}
MSCCLPP_API_CPP Epoch::~Epoch() {
MSCCLPP_API_CPP Epoch::~Epoch()
{
mscclppCudaFree(device_.epochIds_);
mscclppCudaFree(device_.expectedInboundEpochId_);
}
MSCCLPP_API_CPP void Epoch::signal() {
connection_->write(remoteEpochIdsRegMem_.get(), offsetof(EpochIds, inboundReplica_), localEpochIdsRegMem_, offsetof(EpochIds, outbound_), sizeof(device_.epochIds_));
MSCCLPP_API_CPP void Epoch::signal()
{
connection_->write(remoteEpochIdsRegMem_.get(), offsetof(EpochIds, inboundReplica_), localEpochIdsRegMem_,
offsetof(EpochIds, outbound_), sizeof(device_.epochIds_));
}
} // namespace mscclpp

30
src/errors.cc Normal file
View File

@@ -0,0 +1,30 @@
#include "errors.hpp"
namespace mscclpp {
BaseError::BaseError(std::string message, int errorCode) : std::runtime_error(message), errorCode_(errorCode)
{
}
int BaseError::getErrorCode() const
{
return errorCode_;
}
Error::Error(std::string message, int errorCode) : BaseError(message, errorCode)
{
}
CudaError::CudaError(std::string message, int errorCode) : BaseError(message, errorCode)
{
}
CuError::CuError(std::string message, int errorCode) : BaseError(message, errorCode)
{
}
IbError::IbError(std::string message, int errorCode) : BaseError(message, errorCode)
{
}
}; // namespace mscclpp

View File

@@ -1,7 +1,7 @@
#include "alloc.h"
#include "api.h"
#include "checks.hpp"
#include "mscclppfifo.hpp"
#include "api.h"
#include <cuda_runtime.h>
#include <emmintrin.h>
#include <stdexcept>

View File

@@ -6,12 +6,12 @@
#include <unistd.h>
#include "alloc.h"
#include "api.h"
#include "checks.hpp"
#include "comm.h"
#include "debug.h"
#include "ib.hpp"
#include "mscclpp.hpp"
#include "api.h"
#include <infiniband/verbs.h>
#include <string>
@@ -20,7 +20,7 @@ namespace mscclpp {
IbMr::IbMr(void* pd, void* buff, std::size_t size) : buff(buff)
{
if (size == 0) {
throw std::runtime_error("invalid size: " + std::to_string(size));
throw std::invalid_argument("invalid size: " + std::to_string(size));
}
static __thread uintptr_t pageSize = 0;
if (pageSize == 0) {
@@ -35,7 +35,7 @@ IbMr::IbMr(void* pd, void* buff, std::size_t size) : buff(buff)
if (_mr == nullptr) {
std::stringstream err;
err << "ibv_reg_mr failed (errno " << errno << ")";
throw std::runtime_error(err.str());
throw mscclpp::IbError(err.str(), errno);
}
this->mr = _mr;
this->size = pages * pageSize;
@@ -73,7 +73,7 @@ IbQp::IbQp(void* ctx, void* pd, int port)
if (this->cq == nullptr) {
std::stringstream err;
err << "ibv_create_cq failed (errno " << errno << ")";
throw std::runtime_error(err.str());
throw mscclpp::IbError(err.str(), errno);
}
struct ibv_qp_init_attr qpInitAttr;
@@ -92,14 +92,14 @@ IbQp::IbQp(void* ctx, void* pd, int port)
if (_qp == nullptr) {
std::stringstream err;
err << "ibv_create_qp failed (errno " << errno << ")";
throw std::runtime_error(err.str());
throw mscclpp::IbError(err.str(), errno);
}
struct ibv_port_attr portAttr;
if (ibv_query_port(_ctx, port, &portAttr) != 0) {
std::stringstream err;
err << "ibv_query_port failed (errno " << errno << ")";
throw std::runtime_error(err.str());
throw mscclpp::IbError(err.str(), errno);
}
this->info.lid = portAttr.lid;
this->info.port = port;
@@ -111,7 +111,7 @@ IbQp::IbQp(void* ctx, void* pd, int port)
if (ibv_query_gid(_ctx, port, 0, &gid) != 0) {
std::stringstream err;
err << "ibv_query_gid failed (errno " << errno << ")";
throw std::runtime_error(err.str());
throw mscclpp::IbError(err.str(), errno);
}
this->info.spn = gid.global.subnet_prefix;
}
@@ -125,7 +125,7 @@ IbQp::IbQp(void* ctx, void* pd, int port)
if (ibv_modify_qp(_qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) {
std::stringstream err;
err << "ibv_modify_qp failed (errno " << errno << ")";
throw std::runtime_error(err.str());
throw mscclpp::IbError(err.str(), errno);
}
this->qp = _qp;
this->wrn = 0;
@@ -174,7 +174,7 @@ void IbQp::rtr(const IbQpInfo& info)
if (ret != 0) {
std::stringstream err;
err << "ibv_modify_qp failed (errno " << errno << ")";
throw std::runtime_error(err.str());
throw mscclpp::IbError(err.str(), errno);
}
}
@@ -194,7 +194,7 @@ void IbQp::rts()
if (ret != 0) {
std::stringstream err;
err << "ibv_modify_qp failed (errno " << errno << ")";
throw std::runtime_error(err.str());
throw mscclpp::IbError(err.str(), errno);
}
}
@@ -249,7 +249,7 @@ void IbQp::postSend()
if (ret != 0) {
std::stringstream err;
err << "ibv_post_send failed (errno " << errno << ")";
throw std::runtime_error(err.str());
throw mscclpp::IbError(err.str(), errno);
}
this->wrn = 0;
}
@@ -265,7 +265,7 @@ void IbQp::postRecv(uint64_t wrId)
if (ret != 0) {
std::stringstream err;
err << "ibv_post_recv failed (errno " << errno << ")";
throw std::runtime_error(err.str());
throw mscclpp::IbError(err.str(), errno);
}
}
@@ -299,13 +299,13 @@ IbCtx::IbCtx(const std::string& devName) : devName(devName)
if (this->ctx == nullptr) {
std::stringstream err;
err << "ibv_open_device failed (errno " << errno << ", device name << " << devName << ")";
throw std::runtime_error(err.str());
throw mscclpp::IbError(err.str(), errno);
}
this->pd = ibv_alloc_pd(reinterpret_cast<struct ibv_context*>(this->ctx));
if (this->pd == nullptr) {
std::stringstream err;
err << "ibv_alloc_pd failed (errno " << errno << ")";
throw std::runtime_error(err.str());
throw mscclpp::IbError(err.str(), errno);
}
}
@@ -327,7 +327,7 @@ bool IbCtx::isPortUsable(int port) const
if (ibv_query_port(reinterpret_cast<struct ibv_context*>(this->ctx), port, &portAttr) != 0) {
std::stringstream err;
err << "ibv_query_port failed (errno " << errno << ", port << " << port << ")";
throw std::runtime_error(err.str());
throw mscclpp::IbError(err.str(), errno);
}
return portAttr.state == IBV_PORT_ACTIVE &&
(portAttr.link_layer == IBV_LINK_LAYER_ETHERNET || portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND);
@@ -339,7 +339,7 @@ int IbCtx::getAnyActivePort() const
if (ibv_query_device(reinterpret_cast<struct ibv_context*>(this->ctx), &devAttr) != 0) {
std::stringstream err;
err << "ibv_query_device failed (errno " << errno << ")";
throw std::runtime_error(err.str());
throw mscclpp::IbError(err.str(), errno);
}
for (uint8_t port = 1; port <= devAttr.phys_port_cnt; ++port) {
if (this->isPortUsable(port)) {
@@ -354,10 +354,10 @@ IbQp* IbCtx::createQp(int port /*=-1*/)
if (port == -1) {
port = this->getAnyActivePort();
if (port == -1) {
throw std::runtime_error("No active port found");
throw mscclpp::Error("No active port found", mscclppInternalError);
}
} else if (!this->isPortUsable(port)) {
throw std::runtime_error("invalid IB port: " + std::to_string(port));
throw mscclpp::Error("invalid IB port: " + std::to_string(port), mscclppInternalError);
}
qps.emplace_back(new IbQp(this->ctx, this->pd, port));
return qps.back().get();
@@ -412,10 +412,10 @@ MSCCLPP_API_CPP std::string getIBDeviceName(Transport ibTransport)
ibTransportIndex = 7;
break;
default:
throw std::runtime_error("Not an IB transport");
throw std::invalid_argument("Not an IB transport");
}
if (ibTransportIndex >= num) {
throw std::runtime_error("IB transport out of range");
throw std::out_of_range("IB transport out of range");
}
return devices[ibTransportIndex]->name;
}
@@ -444,11 +444,11 @@ MSCCLPP_API_CPP Transport getIBTransportByDeviceName(const std::string& ibDevice
case 7:
return Transport::IB7;
default:
throw std::runtime_error("IB device index out of range");
throw std::out_of_range("IB device index out of range");
}
}
}
throw std::runtime_error("IB device not found");
throw std::invalid_argument("IB device not found");
}
} // namespace mscclpp

View File

@@ -3,8 +3,8 @@
#include "epoch.hpp"
#include "mscclpp.hpp"
#include "proxy.hpp"
#include "mscclppfifo.hpp"
#include "proxy.hpp"
#include "utils.hpp"
namespace mscclpp {
@@ -15,10 +15,16 @@ class Channel
{
public:
Channel(Communicator& communicator, std::shared_ptr<Connection> connection)
: connection_(connection), epoch_(std::make_shared<Epoch>(communicator, connection)) {};
: connection_(connection), epoch_(std::make_shared<Epoch>(communicator, connection)){};
Connection& connection() { return *connection_; }
Epoch& epoch() { return *epoch_; }
Connection& connection()
{
return *connection_;
}
Epoch& epoch()
{
return *epoch_;
}
private:
std::shared_ptr<Connection> connection_;
@@ -69,8 +75,8 @@ union ChannelTrigger {
__device__ ChannelTrigger(ProxyTrigger value) : value(value)
{
}
__device__ ChannelTrigger(TriggerType type, MemoryId dst, uint64_t dstOffset, MemoryId src,
uint64_t srcOffset, uint64_t size, int connectionId)
__device__ ChannelTrigger(TriggerType type, MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
uint64_t size, int connectionId)
{
value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + size);
value.snd = ((((((((connectionId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_REGMEM_HANDLE) + dst)
@@ -86,15 +92,17 @@ struct DeviceChannel
{
DeviceChannel() = default;
DeviceChannel(ChannelId channelId, DeviceEpoch epoch, DeviceProxyFifo fifo) : channelId_(channelId), epoch_(epoch), fifo_(fifo) {}
DeviceChannel(ChannelId channelId, DeviceEpoch epoch, DeviceProxyFifo fifo)
: channelId_(channelId), epoch_(epoch), fifo_(fifo)
{
}
DeviceChannel(const DeviceChannel& other) = default;
DeviceChannel& operator=(DeviceChannel& other) = default;
#ifdef __CUDACC__
__forceinline__ __device__ void put(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
uint64_t size)
__forceinline__ __device__ void put(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, uint64_t size)
{
fifo_.push(ChannelTrigger(TriggerData, dst, dstOffset, src, srcOffset, size, channelId_).value);
}
@@ -110,13 +118,11 @@ struct DeviceChannel
fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, channelId_).value);
}
__forceinline__ __device__ void putWithSignal(MemoryId dst, uint64_t dstOffset, MemoryId src,
uint64_t srcOffset, uint64_t size)
__forceinline__ __device__ void putWithSignal(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
uint64_t size)
{
epochIncrement();
fifo_.push(
ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, size, channelId_)
.value);
fifo_.push(ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, size, channelId_).value);
}
__forceinline__ __device__ void putWithSignal(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size)
@@ -128,16 +134,14 @@ struct DeviceChannel
uint64_t srcOffset, uint64_t size)
{
epochIncrement();
uint64_t curFifoHead = fifo_.push(ChannelTrigger(TriggerData | TriggerFlag | TriggerSync, dst,
dstOffset, src, srcOffset, size, channelId_)
.value);
uint64_t curFifoHead = fifo_.push(
ChannelTrigger(TriggerData | TriggerFlag | TriggerSync, dst, dstOffset, src, srcOffset, size, channelId_).value);
while (*(volatile uint64_t*)&fifo_.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
*(volatile uint64_t*)fifo_.tailReplica <= curFifoHead)
;
}
__forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, MemoryId src, uint64_t offset,
uint64_t size)
__forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size)
{
putWithSignalAndFlush(dst, offset, src, offset, size);
}
@@ -176,25 +180,40 @@ class DeviceChannelService;
inline ProxyHandler makeChannelProxyHandler(DeviceChannelService& channelService);
class DeviceChannelService {
class DeviceChannelService
{
public:
DeviceChannelService(Communicator& communicator);
ChannelId addChannel(std::shared_ptr<Connection> connection) {
ChannelId addChannel(std::shared_ptr<Connection> connection)
{
channels_.push_back(Channel(communicator_, connection));
return channels_.size() - 1;
}
MemoryId addMemory(RegisteredMemory memory) {
MemoryId addMemory(RegisteredMemory memory)
{
memories_.push_back(memory);
return memories_.size() - 1;
}
Channel channel(ChannelId id) { return channels_[id]; }
DeviceChannel deviceChannel(ChannelId id) { return DeviceChannel(id, channels_[id].epoch().deviceEpoch(), proxy_.fifo().deviceFifo()); }
Channel channel(ChannelId id)
{
return channels_[id];
}
DeviceChannel deviceChannel(ChannelId id)
{
return DeviceChannel(id, channels_[id].epoch().deviceEpoch(), proxy_.fifo().deviceFifo());
}
void startProxy() { proxy_.start(); }
void stopProxy() { proxy_.stop(); }
void startProxy()
{
proxy_.start();
}
void stopProxy()
{
proxy_.stop();
}
private:
Communicator& communicator_;
@@ -205,7 +224,8 @@ private:
void bindThread();
ProxyHandlerResult handleTrigger(ProxyTrigger triggerRaw) {
ProxyHandlerResult handleTrigger(ProxyTrigger triggerRaw)
{
ChannelTrigger* trigger = reinterpret_cast<ChannelTrigger*>(&triggerRaw);
Channel& channel = channels_[trigger->fields.chanId];
@@ -234,7 +254,9 @@ struct SimpleDeviceChannel
{
SimpleDeviceChannel() = default;
SimpleDeviceChannel(DeviceChannel devChan, MemoryId dst, MemoryId src) : devChan_(devChan), dst_(dst), src_(src) {}
SimpleDeviceChannel(DeviceChannel devChan, MemoryId dst, MemoryId src) : devChan_(devChan), dst_(dst), src_(src)
{
}
SimpleDeviceChannel(const SimpleDeviceChannel& other) = default;

View File

@@ -8,6 +8,8 @@
#define MSCCLPP_CHECKS_HPP_
#include "debug.h"
#include "errors.hpp"
#include <cuda.h>
#include <cuda_runtime.h>
@@ -15,7 +17,8 @@
do { \
mscclppResult_t res = call; \
if (res != mscclppSuccess && res != mscclppInProgress) { \
throw std::runtime_error(std::string("Call to " #call " failed with error code ") + mscclppGetErrorString(res)); \
throw mscclpp::Error(std::string("Call to " #call " failed with error code ") + mscclppGetErrorString(res), \
res); \
} \
} while (false)
@@ -23,7 +26,7 @@
do { \
cudaError_t err = cmd; \
if (err != cudaSuccess) { \
throw std::runtime_error(std::string("Cuda failure '") + cudaGetErrorString(err) + "'"); \
throw mscclpp::CudaError(std::string("Cuda failure '") + cudaGetErrorString(err) + "'", err); \
} \
} while (false)
@@ -33,7 +36,7 @@
if (err != CUDA_SUCCESS) { \
const char* errStr; \
cuGetErrorString(err, &errStr); \
throw std::runtime_error(std::string("Cu failure '") + std::string(errStr) + "'"); \
throw mscclpp::CuError(std::string("Cu failure '") + std::string(errStr) + "'", err); \
} \
} while (false)

View File

@@ -17,6 +17,7 @@ class ConnectionBase : public Connection, public Setuppable
{
int remoteRank_;
int tag_;
public:
ConnectionBase(int remoteRank, int tag);

View File

@@ -17,7 +17,8 @@ struct DeviceEpoch
__forceinline__ __device__ void wait()
{
(*expectedInboundEpochId_) += 1;
while (*(volatile uint64_t*)&(epochIds_->inboundReplica_) < (*expectedInboundEpochId_));
while (*(volatile uint64_t*)&(epochIds_->inboundReplica_) < (*expectedInboundEpochId_))
;
}
__forceinline__ __device__ void epochIncrement()
@@ -44,9 +45,12 @@ public:
void signal();
DeviceEpoch deviceEpoch() { return device_; }
DeviceEpoch deviceEpoch()
{
return device_;
}
};
} // namespace mscclpp
#endif // MSCCLPP_EPOCH_HPP_
#endif // MSCCLPP_EPOCH_HPP_

46
src/include/errors.hpp Normal file
View File

@@ -0,0 +1,46 @@
#ifndef MSCCLPP_ERRORS_HPP_
#define MSCCLPP_ERRORS_HPP_
#include <stdexcept>
#include <string>
namespace mscclpp {

/// Base class of the mscclpp exception hierarchy.
///
/// Extends std::runtime_error with an integer error code so that a handler can recover the
/// numeric failure reason in addition to the what() message. The meaning of the code depends
/// on the concrete subclass that was thrown.
class BaseError : public std::runtime_error
{
public:
  /// @param message    Human-readable description; becomes the what() text.
  /// @param errorCode  Numeric code describing the failure.
  BaseError(std::string message, int errorCode);
  ~BaseError() override = default;

  /// @return The error code passed to the constructor.
  int getErrorCode() const;

private:
  int errorCode_;
};

/// Failure reported by mscclpp itself; thrown with an mscclppResult_t value
/// (e.g. mscclppInternalError, mscclppInvalidArgument) as the error code.
class Error : public BaseError
{
public:
  Error(std::string message, int errorCode);
  ~Error() override = default;
};

/// Failure of a CUDA runtime call; thrown by CUDATHROW with the cudaError_t value as the error code.
class CudaError : public BaseError
{
public:
  CudaError(std::string message, int errorCode);
  ~CudaError() override = default;
};

/// Failure of a CUDA driver (cu*) call; thrown with the driver result value as the error code.
class CuError : public BaseError
{
public:
  CuError(std::string message, int errorCode);
  ~CuError() override = default;
};

/// Failure of an InfiniBand verbs operation; thrown with either errno or the
/// work-completion status as the error code, depending on the call site.
class IbError : public BaseError
{
public:
  IbError(std::string message, int errorCode);
  ~IbError() override = default;
};

} // namespace mscclpp
#endif // MSCCLPP_ERRORS_HPP_

View File

@@ -7,10 +7,10 @@
#define MSCCLPP_VERSION (MSCCLPP_MAJOR * 10000 + MSCCLPP_MINOR * 100 + MSCCLPP_PATCH)
#include <bitset>
#include <future>
#include <memory>
#include <string>
#include <vector>
#include <future>
namespace mscclpp {
@@ -37,14 +37,14 @@ public:
{
size_t size = data.size();
send((void*)&size, sizeof(size_t), peer, tag);
send((void*)data.data(), data.size(), peer, tag+1);
send((void*)data.data(), data.size(), peer, tag + 1);
}
void recv(std::vector<char>& data, int peer, int tag)
{
size_t size;
recv((void*)&size, sizeof(size_t), peer, tag);
data.resize(size);
recv((void*)data.data(), data.size(), peer, tag+1);
recv((void*)data.data(), data.size(), peer, tag + 1);
}
};
@@ -239,7 +239,8 @@ class Connection;
class RegisteredMemory
{
struct Impl;
// A shared_ptr is used since RegisteredMemory is functionally immutable, although internally some state is populated lazily.
// A shared_ptr is used since RegisteredMemory is functionally immutable, although internally some state is populated
// lazily.
std::shared_ptr<Impl> pimpl;
public:
@@ -281,17 +282,23 @@ protected:
struct Setuppable
{
virtual void beginSetup(std::shared_ptr<BaseBootstrap>) {}
virtual void endSetup(std::shared_ptr<BaseBootstrap>) {}
virtual void beginSetup(std::shared_ptr<BaseBootstrap>)
{
}
virtual void endSetup(std::shared_ptr<BaseBootstrap>)
{
}
};
template<typename T>
class NonblockingFuture
template <typename T> class NonblockingFuture
{
std::shared_future<T> future;
public:
NonblockingFuture() = default;
NonblockingFuture(std::shared_future<T>&& future) : future(std::move(future)) {}
NonblockingFuture(std::shared_future<T>&& future) : future(std::move(future))
{
}
NonblockingFuture(const NonblockingFuture&) = default;
bool ready() const
@@ -331,7 +338,7 @@ public:
* Returns: a handle to the buffer
*/
RegisteredMemory registerMemory(void* ptr, size_t size, TransportFlags transports);
void sendMemoryOnSetup(RegisteredMemory memory, int remoteRank, int tag);
NonblockingFuture<RegisteredMemory> recvMemoryOnSetup(int remoteRank, int tag);
@@ -363,7 +370,6 @@ public:
private:
std::unique_ptr<Impl> pimpl;
};
} // namespace mscclpp
namespace std {

View File

@@ -2,6 +2,7 @@
#define MSCCLPP_REGISTERED_MEMORY_HPP_
#include "communicator.hpp"
#include "errors.hpp"
#include "ib.hpp"
#include "mscclpp.h"
#include "mscclpp.hpp"
@@ -16,11 +17,13 @@ struct TransportInfo
// TODO: rewrite this using std::variant or something
bool ibLocal;
union {
struct {
struct
{
cudaIpcMemHandle_t cudaIpcBaseHandle;
size_t cudaIpcOffsetFromBase;
};
struct {
struct
{
const IbMr* ibMr;
IbMrInfo ibMrInfo;
};
@@ -46,7 +49,7 @@ struct RegisteredMemory::Impl
return entry;
}
}
throw std::runtime_error("Transport data not found");
throw Error("Transport data not found", mscclppInternalError);
}
};

View File

@@ -8,45 +8,45 @@ namespace mscclpp {
struct Timer
{
std::chrono::steady_clock::time_point start;
Timer()
{
start = std::chrono::steady_clock::now();
}
int64_t elapsed()
{
auto end = std::chrono::steady_clock::now();
return std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
}
void reset()
{
start = std::chrono::steady_clock::now();
}
void print(const char* name)
{
auto end = std::chrono::steady_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
printf("%s: %ld us\n", name, elapsed);
}
std::chrono::steady_clock::time_point start;
Timer()
{
start = std::chrono::steady_clock::now();
}
int64_t elapsed()
{
auto end = std::chrono::steady_clock::now();
return std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
}
void reset()
{
start = std::chrono::steady_clock::now();
}
void print(const char* name)
{
auto end = std::chrono::steady_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
printf("%s: %ld us\n", name, elapsed);
}
};
struct ScopedTimer
{
Timer timer;
const char* name;
ScopedTimer(const char* name) : name(name)
{
}
~ScopedTimer()
{
timer.print(name);
}
Timer timer;
const char* name;
ScopedTimer(const char* name) : name(name)
{
}
~ScopedTimer()
{
timer.print(name);
}
};
} // namespace mscclpp

View File

@@ -1,6 +1,6 @@
#include "proxy.hpp"
#include "api.h"
#include "mscclpp.hpp"
#include "proxy.hpp"
#include "utils.h"
#include "utils.hpp"
#include <atomic>
@@ -20,7 +20,8 @@ struct Proxy::Impl
std::thread service;
std::atomic_bool running;
Impl(ProxyHandler handler, std::function<void()> threadInit) : handler(handler), threadInit(threadInit), running(false)
Impl(ProxyHandler handler, std::function<void()> threadInit)
: handler(handler), threadInit(threadInit), running(false)
{
}
};
@@ -45,7 +46,6 @@ MSCCLPP_API_CPP void Proxy::start()
{
pimpl->running = true;
pimpl->service = std::thread([this] {
pimpl->threadInit();
ProxyHandler handler = this->pimpl->handler;
@@ -109,4 +109,4 @@ MSCCLPP_API_CPP HostProxyFifo& Proxy::fifo()
return pimpl->fifo;
}
} // namespace mscclpp
} // namespace mscclpp

View File

@@ -88,7 +88,7 @@ MSCCLPP_API_CPP std::vector<char> RegisteredMemory::serialize()
std::copy_n(reinterpret_cast<char*>(&pimpl->hostHash), sizeof(pimpl->hostHash), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&pimpl->transports), sizeof(pimpl->transports), std::back_inserter(result));
if (pimpl->transportInfos.size() > std::numeric_limits<int8_t>::max()) {
throw std::runtime_error("Too many transport info entries");
throw mscclpp::Error("Too many transport info entries", mscclppInternalError);
}
int8_t transportCount = pimpl->transportInfos.size();
std::copy_n(reinterpret_cast<char*>(&transportCount), sizeof(transportCount), std::back_inserter(result));
@@ -102,7 +102,7 @@ MSCCLPP_API_CPP std::vector<char> RegisteredMemory::serialize()
} else if (AllIBTransports.has(entry.transport)) {
std::copy_n(reinterpret_cast<char*>(&entry.ibMrInfo), sizeof(entry.ibMrInfo), std::back_inserter(result));
} else {
throw std::runtime_error("Unknown transport");
throw mscclpp::Error("Unknown transport", mscclppInternalError);
}
}
return result;
@@ -132,21 +132,23 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization)
std::copy_n(it, sizeof(transportInfo.transport), reinterpret_cast<char*>(&transportInfo.transport));
it += sizeof(transportInfo.transport);
if (transportInfo.transport == Transport::CudaIpc) {
std::copy_n(it, sizeof(transportInfo.cudaIpcBaseHandle), reinterpret_cast<char*>(&transportInfo.cudaIpcBaseHandle));
std::copy_n(it, sizeof(transportInfo.cudaIpcBaseHandle),
reinterpret_cast<char*>(&transportInfo.cudaIpcBaseHandle));
it += sizeof(transportInfo.cudaIpcBaseHandle);
std::copy_n(it, sizeof(transportInfo.cudaIpcOffsetFromBase), reinterpret_cast<char*>(&transportInfo.cudaIpcOffsetFromBase));
std::copy_n(it, sizeof(transportInfo.cudaIpcOffsetFromBase),
reinterpret_cast<char*>(&transportInfo.cudaIpcOffsetFromBase));
it += sizeof(transportInfo.cudaIpcOffsetFromBase);
} else if (AllIBTransports.has(transportInfo.transport)) {
std::copy_n(it, sizeof(transportInfo.ibMrInfo), reinterpret_cast<char*>(&transportInfo.ibMrInfo));
it += sizeof(transportInfo.ibMrInfo);
transportInfo.ibLocal = false;
} else {
throw std::runtime_error("Unknown transport");
throw mscclpp::Error("Unknown transport", mscclppInternalError);
}
this->transportInfos.push_back(transportInfo);
}
// The read cursor must stop exactly at the end of the input. If it does not,
// the buffer does not match the expected layout. This is the reading
// (deserializing) side, so the message says so.
if (it != serialization.end()) {
  throw mscclpp::Error("Deserialization failed", mscclppInternalError);
}
if (transports.has(Transport::CudaIpc)) {

View File

@@ -6,10 +6,10 @@
#include "utils.h"
#include <memory>
#include <numa.h>
#include <stdlib.h>
#include <string>
#include <memory>
// Get current Compute Capability
// int mscclppCudaCompCap() {

View File

@@ -1,5 +1,6 @@
#include "mscclpp.h"
#include "mscclpp.hpp"
#include "channel.hpp"
#ifdef MSCCLPP_USE_MPI_FOR_TESTS
@@ -71,8 +72,8 @@ __device__ void allgather0(mscclpp::channel::SimpleDeviceChannel devChan, int ra
devChan.wait();
}
__device__ void localAllGather(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int world_size, int nranksPerNode,
int remoteRank, uint64_t offset, uint64_t size)
__device__ void localAllGather(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int world_size,
int nranksPerNode, int remoteRank, uint64_t offset, uint64_t size)
{
// this allgather algorithm works as follows:
// Step 1: GPU rank i sends data to GPU rank (i+1) % nranksPerNode
@@ -131,7 +132,7 @@ __device__ void allgather2(mscclpp::channel::SimpleDeviceChannel devChan, int ra
// opposite side
if ((threadIdx.x % 32) == 0)
devChan.putWithSignal(rank * nelemsPerGPU * sizeof(int),
(nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int));
(nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int));
if ((threadIdx.x % 32) == 0)
devChan.wait();
}
@@ -150,9 +151,8 @@ __device__ void allgather2(mscclpp::channel::SimpleDeviceChannel devChan, int ra
if (remoteRank % nranksPerNode == rank % nranksPerNode) {
// opposite side
if ((threadIdx.x % 32) == 0)
devChan.putWithSignal((rank * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) *
sizeof(int),
nelemsPerGPU / pipelineSize * sizeof(int));
devChan.putWithSignal((rank * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * sizeof(int),
nelemsPerGPU / pipelineSize * sizeof(int));
if ((threadIdx.x % 32) == 0)
devChan.wait();
}
@@ -226,7 +226,8 @@ void initializeAndAllocateAllGatherData(int rank, int world_size, size_t dataSiz
CUDACHECK(cudaMemcpy(*data_d, *data_h, dataSize, cudaMemcpyHostToDevice));
}
void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& comm, mscclpp::channel::DeviceChannelService& channelService, int* data_d, size_t dataSize)
void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& comm,
mscclpp::channel::DeviceChannelService& channelService, int* data_d, size_t dataSize)
{
int thisNode = rankToNode(rank);
int cudaNum = rankToLocalRank(rank);
@@ -258,12 +259,13 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co
std::vector<mscclpp::channel::SimpleDeviceChannel> devChannels;
for (size_t i = 0; i < channelIds.size(); ++i) {
devChannels.push_back(mscclpp::channel::SimpleDeviceChannel(channelService.deviceChannel(channelIds[i]),
channelService.addMemory(remoteMemories[i].get()), channelService.addMemory(localMemories[i])));
channelService.addMemory(remoteMemories[i].get()),
channelService.addMemory(localMemories[i])));
}
assert(devChannels.size() < sizeof(constDevChans) / sizeof(mscclpp::channel::SimpleDeviceChannel));
CUDACHECK(
cudaMemcpyToSymbol(constDevChans, devChannels.data(), sizeof(mscclpp::channel::SimpleDeviceChannel) * devChannels.size()));
CUDACHECK(cudaMemcpyToSymbol(constDevChans, devChannels.data(),
sizeof(mscclpp::channel::SimpleDeviceChannel) * devChannels.size()));
}
void printUsage(const char* prog, bool isMpi)

View File

@@ -1,5 +1,5 @@
#include "mscclpp.hpp"
#include "epoch.hpp"
#include "mscclpp.hpp"
#include <cassert>
#include <cuda_runtime.h>
@@ -24,26 +24,33 @@ mscclpp::Transport findIb(int localRank)
return IBs[localRank];
}
// Register a local device buffer and exchange registrations with every other rank.
//
// Steps, as performed below:
//   1. Register `devicePtr` for CudaIpc combined with `myIbDevice`.
//   2. For each peer (every rank except `rank`), queue a send of the local
//      registration and a receive of the peer's, both with 0 as the last argument.
//   3. Call communicator.setup().
//   4. Resolve each received future into `remoteMemory`.
//
// @param communicator      Communicator used for registration and exchange.
// @param rank              This process's rank; skipped in the peer loops.
// @param worldSize         Number of ranks; peers are 0..worldSize-1 except `rank`.
// @param devicePtr         Start of the buffer to register.
// @param deviceBufferSize  Size passed to registerMemory.
// @param myIbDevice        Transport OR-ed with CudaIpc for the registration.
// @param localMemory       [out] Receives the local registration.
// @param remoteMemory      [out] Receives one entry per peer, keyed by peer rank.
void register_all_memories(mscclpp::Communicator& communicator, int rank, int worldSize, void* devicePtr,
                           size_t deviceBufferSize, mscclpp::Transport myIbDevice,
                           mscclpp::RegisteredMemory& localMemory,
                           std::unordered_map<int, mscclpp::RegisteredMemory>& remoteMemory)
{
  localMemory = communicator.registerMemory(devicePtr, deviceBufferSize, mscclpp::Transport::CudaIpc | myIbDevice);
  std::unordered_map<int, mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> futureRemoteMemory;
  for (int i = 0; i < worldSize; i++) {
    if (i != rank) {
      communicator.sendMemoryOnSetup(localMemory, i, 0);
      futureRemoteMemory[i] = communicator.recvMemoryOnSetup(i, 0);
    }
  }
  communicator.setup();
  // The futures are read only after setup() has been called.
  for (int i = 0; i < worldSize; i++) {
    if (i != rank) {
      remoteMemory[i] = futureRemoteMemory[i].get();
    }
  }
}
void make_connections(mscclpp::Communicator& communicator, int rank, int worldSize, int nRanksPerNode, mscclpp::Transport myIbDevice, std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections){
void make_connections(mscclpp::Communicator& communicator, int rank, int worldSize, int nRanksPerNode,
mscclpp::Transport myIbDevice,
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections)
{
for (int i = 0; i < worldSize; i++) {
if (i != rank){
if (i != rank) {
if (i / nRanksPerNode == rank / nRanksPerNode) {
connections[i] = communicator.connectOnSetup(i, 0, mscclpp::Transport::CudaIpc);
} else {
@@ -54,35 +61,40 @@ void make_connections(mscclpp::Communicator& communicator, int rank, int worldSi
communicator.setup();
}
void write_remote(int rank, int worldSize, std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
std::unordered_map<int, mscclpp::RegisteredMemory>& remoteRegisteredMemories, mscclpp::RegisteredMemory& registeredMemory, int dataCountPerRank){
void write_remote(int rank, int worldSize, std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
std::unordered_map<int, mscclpp::RegisteredMemory>& remoteRegisteredMemories,
mscclpp::RegisteredMemory& registeredMemory, int dataCountPerRank)
{
for (int i = 0; i < worldSize; i++) {
if (i != rank) {
auto& conn = connections.at(i);
auto& peerMemory = remoteRegisteredMemories.at(i);
conn->write(peerMemory, rank * dataCountPerRank * sizeof(int), registeredMemory, rank * dataCountPerRank*sizeof(int), dataCountPerRank*sizeof(int));
conn->write(peerMemory, rank * dataCountPerRank * sizeof(int), registeredMemory,
rank * dataCountPerRank * sizeof(int), dataCountPerRank * sizeof(int));
conn->flush();
}
}
}
void device_buffer_init(int rank, int worldSize, int dataCount, std::vector<int*>& devicePtr){
for (int n = 0; n < (int)devicePtr.size(); n++){
void device_buffer_init(int rank, int worldSize, int dataCount, std::vector<int*>& devicePtr)
{
for (int n = 0; n < (int)devicePtr.size(); n++) {
std::vector<int> hostBuffer(dataCount, 0);
for (int i = 0; i < dataCount; i++) {
hostBuffer[i] = rank + n * worldSize;
}
CUDATHROW(cudaMemcpy(devicePtr[n], hostBuffer.data(), dataCount*sizeof(int), cudaMemcpyHostToDevice));
CUDATHROW(cudaMemcpy(devicePtr[n], hostBuffer.data(), dataCount * sizeof(int), cudaMemcpyHostToDevice));
}
CUDATHROW(cudaDeviceSynchronize());
}
bool test_device_buffer_write_correctness(int worldSize, int dataCount, std::vector<int*>& devicePtr){
for (int n = 0; n < (int)devicePtr.size(); n++){
bool test_device_buffer_write_correctness(int worldSize, int dataCount, std::vector<int*>& devicePtr)
{
for (int n = 0; n < (int)devicePtr.size(); n++) {
std::vector<int> hostBuffer(dataCount, 0);
CUDATHROW(cudaMemcpy(hostBuffer.data(), devicePtr[n], dataCount*sizeof(int), cudaMemcpyDeviceToHost));
CUDATHROW(cudaMemcpy(hostBuffer.data(), devicePtr[n], dataCount * sizeof(int), cudaMemcpyDeviceToHost));
for (int i = 0; i < worldSize; i++) {
for (int j = i*dataCount/worldSize; j < (i+1)*dataCount/worldSize; j++) {
for (int j = i * dataCount / worldSize; j < (i + 1) * dataCount / worldSize; j++) {
if (hostBuffer[j] != i + n * worldSize) {
return false;
}
@@ -92,8 +104,11 @@ bool test_device_buffer_write_correctness(int worldSize, int dataCount, std::vec
return true;
}
void test_write(int rank, int worldSize, int deviceBufferSize, std::shared_ptr<mscclpp::BaseBootstrap> bootstrap, std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
std::vector<std::unordered_map<int, mscclpp::RegisteredMemory>>& remoteMemory, std::vector<mscclpp::RegisteredMemory>& localMemory, std::vector<int*>& devicePtr, int numBuffers){
void test_write(int rank, int worldSize, int deviceBufferSize, std::shared_ptr<mscclpp::BaseBootstrap> bootstrap,
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
std::vector<std::unordered_map<int, mscclpp::RegisteredMemory>>& remoteMemory,
std::vector<mscclpp::RegisteredMemory>& localMemory, std::vector<int*>& devicePtr, int numBuffers)
{
assert((deviceBufferSize / sizeof(int)) % worldSize == 0);
size_t dataCount = deviceBufferSize / sizeof(int);
@@ -102,8 +117,8 @@ void test_write(int rank, int worldSize, int deviceBufferSize, std::shared_ptr<m
bootstrap->barrier();
if (bootstrap->getRank() == 0)
std::cout << "CUDA memory initialization passed" << std::endl;
for (int n = 0; n < numBuffers; n++){
for (int n = 0; n < numBuffers; n++) {
write_remote(rank, worldSize, connections, remoteMemory[n], localMemory[n], dataCount / worldSize);
}
bootstrap->barrier();
@@ -116,7 +131,7 @@ void test_write(int rank, int worldSize, int deviceBufferSize, std::shared_ptr<m
do {
ready = test_device_buffer_write_correctness(worldSize, dataCount, devicePtr);
niter++;
if (niter == 10000){
if (niter == 10000) {
throw std::runtime_error("Polling is stuck.");
}
} while (!ready);
@@ -126,22 +141,29 @@ void test_write(int rank, int worldSize, int deviceBufferSize, std::shared_ptr<m
std::cout << "Polling for " << std::to_string(numBuffers) << " buffers passed" << std::endl;
}
// Kernel: thread t calls epochIncrement() on deviceEpochs[t] for every
// t < worldSize other than `rank`.
//
// Only threadIdx.x is used as the index, so the block must have at least
// worldSize threads for every peer to be covered.
//
// @param deviceEpochs  Array indexed by peer rank; the entry at `rank` is never touched.
// @param rank          This process's rank (skipped).
// @param worldSize     Number of ranks; threads with a larger index do nothing.
__global__ void increament_epochs(mscclpp::DeviceEpoch* deviceEpochs, int rank, int worldSize)
{
  int tid = threadIdx.x;
  if (tid != rank && tid < worldSize) {
    deviceEpochs[tid].epochIncrement();
  }
}
// Kernel: thread t calls wait() on deviceEpochs[t] for every t < worldSize
// other than `rank`.
//
// Only threadIdx.x is used as the index, so the block must have at least
// worldSize threads for every peer to be covered.
//
// @param deviceEpochs  Array indexed by peer rank; the entry at `rank` is never touched.
// @param rank          This process's rank (skipped).
// @param worldSize     Number of ranks; threads with a larger index do nothing.
__global__ void wait_epochs(mscclpp::DeviceEpoch* deviceEpochs, int rank, int worldSize)
{
  int tid = threadIdx.x;
  if (tid != rank && tid < worldSize) {
    deviceEpochs[tid].wait();
  }
}
void test_write_with_epochs(int rank, int worldSize, int deviceBufferSize, std::shared_ptr<mscclpp::BaseBootstrap> bootstrap, std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
std::vector<std::unordered_map<int, mscclpp::RegisteredMemory>>& remoteMemory, std::vector<mscclpp::RegisteredMemory>& localMemory, std::vector<int*>& devicePtr, std::unordered_map<int, std::shared_ptr<mscclpp::Epoch>> epochs, int numBuffers){
void test_write_with_epochs(int rank, int worldSize, int deviceBufferSize,
std::shared_ptr<mscclpp::BaseBootstrap> bootstrap,
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
std::vector<std::unordered_map<int, mscclpp::RegisteredMemory>>& remoteMemory,
std::vector<mscclpp::RegisteredMemory>& localMemory, std::vector<int*>& devicePtr,
std::unordered_map<int, std::shared_ptr<mscclpp::Epoch>> epochs, int numBuffers)
{
assert((deviceBufferSize / sizeof(int)) % worldSize == 0);
size_t dataCount = deviceBufferSize / sizeof(int);
@@ -153,8 +175,8 @@ void test_write_with_epochs(int rank, int worldSize, int deviceBufferSize, std::
mscclpp::DeviceEpoch* deviceEpochs;
CUDATHROW(cudaMalloc(&deviceEpochs, sizeof(mscclpp::DeviceEpoch) * worldSize));
for (int i = 0; i < worldSize; i++){
if (i != rank){
for (int i = 0; i < worldSize; i++) {
if (i != rank) {
mscclpp::DeviceEpoch deviceEpoch = epochs[i]->deviceEpoch();
CUDATHROW(cudaMemcpy(&deviceEpochs[i], &deviceEpoch, sizeof(mscclpp::DeviceEpoch), cudaMemcpyHostToDevice));
}
@@ -165,16 +187,15 @@ void test_write_with_epochs(int rank, int worldSize, int deviceBufferSize, std::
if (bootstrap->getRank() == 0)
std::cout << "CUDA device epochs are created" << std::endl;
for (int n = 0; n < numBuffers; n++){
for (int n = 0; n < numBuffers; n++) {
write_remote(rank, worldSize, connections, remoteMemory[n], localMemory[n], dataCount / worldSize);
}
increament_epochs<<<1, worldSize>>>(deviceEpochs, rank, worldSize);
CUDATHROW(cudaDeviceSynchronize());
for (int i = 0; i < worldSize; i++){
if (i != rank){
for (int i = 0; i < worldSize; i++) {
if (i != rank) {
epochs[i]->signal();
}
}
@@ -182,13 +203,14 @@ void test_write_with_epochs(int rank, int worldSize, int deviceBufferSize, std::
wait_epochs<<<1, worldSize>>>(deviceEpochs, rank, worldSize);
CUDATHROW(cudaDeviceSynchronize());
if (!test_device_buffer_write_correctness(worldSize, dataCount, devicePtr)){
if (!test_device_buffer_write_correctness(worldSize, dataCount, devicePtr)) {
throw std::runtime_error("unexpected result.");
}
bootstrap->barrier();
if (bootstrap->getRank() == 0)
std::cout << "--- Testing writes with singal for " << std::to_string(numBuffers) << " buffers passed ---" << std::endl;
std::cout << "--- Testing writes with singal for " << std::to_string(numBuffers) << " buffers passed ---"
<< std::endl;
}
void test_communicator(int rank, int worldSize, int nranksPerNode)
@@ -213,8 +235,8 @@ void test_communicator(int rank, int worldSize, int nranksPerNode)
int numBuffers = 10;
std::vector<int*> devicePtr(numBuffers);
int deviceBufferSize = 1024*1024;
int deviceBufferSize = 1024 * 1024;
std::vector<mscclpp::RegisteredMemory> localMemory(numBuffers);
std::vector<std::unordered_map<int, mscclpp::RegisteredMemory>> remoteMemory(numBuffers);
@@ -222,13 +244,15 @@ void test_communicator(int rank, int worldSize, int nranksPerNode)
if (n % 100 == 0)
std::cout << "Registering memory for " << std::to_string(n) << " buffers" << std::endl;
CUDATHROW(cudaMalloc(&devicePtr[n], deviceBufferSize));
register_all_memories(communicator, rank, worldSize, devicePtr[n], deviceBufferSize, myIbDevice, localMemory[n], remoteMemory[n]);
register_all_memories(communicator, rank, worldSize, devicePtr[n], deviceBufferSize, myIbDevice, localMemory[n],
remoteMemory[n]);
}
bootstrap->barrier();
if (bootstrap->getRank() == 0)
std::cout << "Memory registration for " << std::to_string(numBuffers) << " buffers passed" << std::endl;
test_write(rank, worldSize, deviceBufferSize, bootstrap, connections, remoteMemory, localMemory, devicePtr, numBuffers);
test_write(rank, worldSize, deviceBufferSize, bootstrap, connections, remoteMemory, localMemory, devicePtr,
numBuffers);
if (bootstrap->getRank() == 0)
std::cout << "--- Testing vanialla writes passed ---" << std::endl;
@@ -242,12 +266,13 @@ void test_communicator(int rank, int worldSize, int nranksPerNode)
if (bootstrap->getRank() == 0)
std::cout << "Epochs are created" << std::endl;
test_write_with_epochs(rank, worldSize, deviceBufferSize, bootstrap, connections, remoteMemory, localMemory, devicePtr, epochs, numBuffers);
test_write_with_epochs(rank, worldSize, deviceBufferSize, bootstrap, connections, remoteMemory, localMemory,
devicePtr, epochs, numBuffers);
if (bootstrap->getRank() == 0)
std::cout << "--- MSCCLPP::Communicator tests passed! ---" << std::endl;
for (int n = 0; n < numBuffers; n++){
for (int n = 0; n < numBuffers; n++) {
CUDATHROW(cudaFree(devicePtr[n]));
}
}
@@ -269,4 +294,4 @@ int main(int argc, char** argv)
MPI_Finalize();
return 0;
}
}