diff --git a/Makefile b/Makefile index 2b80afb5..74d2c475 100644 --- a/Makefile +++ b/Makefile @@ -121,7 +121,7 @@ LDFLAGS := $(NVLDFLAGS) $(GDRCOPY_LDFLAGS) -libverbs -lnuma LIBSRCS := $(addprefix src/,debug.cc utils.cc init.cc proxy.cc ib.cc config.cc) LIBSRCS += $(addprefix src/bootstrap/,bootstrap.cc socket.cc) LIBSRCS += $(addprefix src/,communicator.cc connection.cc registered_memory.cc) -LIBSRCS += $(addprefix src/,epoch.cc proxy_cpp.cc fifo.cc channel.cc) +LIBSRCS += $(addprefix src/,epoch.cc proxy_cpp.cc fifo.cc channel.cc errors.cc) ifneq ($(NPKIT), 0) LIBSRCS += $(addprefix src/misc/,npkit.cc) endif @@ -135,7 +135,7 @@ HEADERS := $(wildcard src/include/*.h) CPPSOURCES := $(shell find ./ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)' -not -path "./build/*" -not -path "./python/*") PYTHONCPPSOURCES := $(shell find ./python/src/ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)') -INCEXPORTS := mscclpp.h mscclppfifo.h mscclpp.hpp mscclppfifo.hpp epoch.hpp +INCEXPORTS := mscclpp.h mscclppfifo.h mscclpp.hpp mscclppfifo.hpp epoch.hpp errors.hpp INCTARGETS := $(INCEXPORTS:%=$(BUILDDIR)/$(INCDIR)/%) LIBNAME := libmscclpp.so diff --git a/src/bootstrap/bootstrap.cc b/src/bootstrap/bootstrap.cc index 75225799..50227234 100644 --- a/src/bootstrap/bootstrap.cc +++ b/src/bootstrap/bootstrap.cc @@ -194,13 +194,15 @@ void Bootstrap::Impl::getRemoteAddresses(mscclppSocket* listenSock, std::vector< MSCCLPPTHROW(mscclppSocketClose(&sock)); if (this->nRanks_ != info.nRanks) { - throw std::runtime_error("Bootstrap Root : mismatch in rank count from procs " + std::to_string(this->nRanks_) + - " : " + std::to_string(info.nRanks)); + throw mscclpp::Error("Bootstrap Root : mismatch in rank count from procs " + std::to_string(this->nRanks_) + " : " + + std::to_string(info.nRanks), + mscclppInternalError); } if (std::memcmp(&zero, &rankAddressesRoot[info.rank], sizeof(mscclppSocketAddress)) != 0) { - throw std::runtime_error("Bootstrap Root : rank " + std::to_string(info.rank) + " of " + - std::to_string(this->nRanks_) + " has already checked in"); + throw mscclpp::Error("Bootstrap Root : rank " + std::to_string(info.rank) + " of " + std::to_string(this->nRanks_) + + " has already checked in", + mscclppInternalError); } // Save the connection handle for that rank @@ -269,16 +271,17 @@ void Bootstrap::Impl::netInit(std::string ipPortPair) if (!ipPortPair.empty()) { mscclppSocketAddress remoteAddr; if (mscclppSocketGetAddrFromString(&remoteAddr, ipPortPair.c_str()) != mscclppSuccess) { - throw std::runtime_error( - "Invalid ipPortPair, please use format: : or []: or :"); + throw mscclpp::Error( + "Invalid ipPortPair, please use format: : or []: or :", + mscclppInvalidArgument); } if (mscclppFindInterfaceMatchSubnet(netIfName_, &netIfAddr_, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) { - throw std::runtime_error("NET/Socket : No usable listening interface found"); + throw mscclpp::Error("NET/Socket : No usable listening interface found", mscclppInternalError); } } else { int ret = mscclppFindInterfaces(netIfName_, &netIfAddr_, MAX_IF_NAME_SIZE, 1); if (ret <= 0) { - throw std::runtime_error("Bootstrap : no socket interface found"); + throw mscclpp::Error("Bootstrap : no socket interface found", mscclppInternalError); } } @@ -390,8 +393,9 @@ void Bootstrap::Impl::netRecv(mscclppSocket* sock, void* data, int size) int recvSize; MSCCLPPTHROW(mscclppSocketRecv(sock, &recvSize, sizeof(int))); if (recvSize > size) { - throw std::runtime_error("Message truncated : received " + std::to_string(recvSize) + " bytes instead of " + - std::to_string(size)); + throw mscclpp::Error("Message truncated : received " + std::to_string(recvSize) + " bytes instead of " + + std::to_string(size), + mscclppInternalError); } MSCCLPPTHROW(mscclppSocketRecv(sock, data, std::min(recvSize, size))); } @@ -1058,4 +1062,4 @@ mscclppResult_t bootstrapAbort(void* commState) free(state->peerProxyAddresses); free(state); return mscclppSuccess; -} \ No newline at end of file +} diff --git a/src/channel.cc b/src/channel.cc index 42572390..33b679c2 100644 --- a/src/channel.cc +++ b/src/channel.cc @@ -1,14 +1,16 @@ #include "channel.hpp" -#include "utils.h" -#include "checks.hpp" #include "api.h" +#include "checks.hpp" #include "debug.h" +#include "utils.h" namespace mscclpp { namespace channel { -MSCCLPP_API_CPP DeviceChannelService::DeviceChannelService(Communicator& communicator) : communicator_(communicator), - proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, [&]() { bindThread(); }) { +MSCCLPP_API_CPP DeviceChannelService::DeviceChannelService(Communicator& communicator) + : communicator_(communicator), + proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, [&]() { bindThread(); }) +{ int cudaDevice; CUDATHROW(cudaGetDevice(&cudaDevice)); MSCCLPPTHROW(getDeviceNumaNode(cudaDevice, &deviceNumaNode)); @@ -23,4 +25,4 @@ MSCCLPP_API_CPP void DeviceChannelService::bindThread() } } // namespace channel -} // namespace mscclpp \ No newline at end of file +} // namespace mscclpp diff --git a/src/communicator.cc b/src/communicator.cc index 1fd64132..9b28cbb4 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -59,8 +59,9 @@ MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t struct MemorySender : public Setuppable { - MemorySender(RegisteredMemory memory, int remoteRank, int tag) - : memory_(memory), remoteRank_(remoteRank), tag_(tag) {} + MemorySender(RegisteredMemory memory, int remoteRank, int tag) : memory_(memory), remoteRank_(remoteRank), tag_(tag) + { + } void beginSetup(std::shared_ptr bootstrap) override { @@ -79,8 +80,9 @@ MSCCLPP_API_CPP void Communicator::sendMemoryOnSetup(RegisteredMemory memory, in struct MemoryReceiver : public Setuppable { - MemoryReceiver(int remoteRank, int tag) - : remoteRank_(remoteRank), tag_(tag) {} + MemoryReceiver(int remoteRank, int tag) : remoteRank_(remoteRank), tag_(tag) + { + } void endSetup(std::shared_ptr bootstrap) override { @@ -112,7 +114,7 @@ MSCCLPP_API_CPP std::shared_ptr Communicator::connectOnSetup(int rem << pimpl->rankToHash_[pimpl->bootstrap_->getRank()] << ")" << " != " << pimpl->bootstrap_->getRank() << "(" << std::hex << pimpl->rankToHash_[pimpl->bootstrap_->getRank()] << ")"; - throw std::runtime_error(ss.str()); + throw mscclpp::Error(ss.str(), mscclppInternalError); } auto cudaIpcConn = std::make_shared(remoteRank, tag); conn = cudaIpcConn; @@ -126,7 +128,7 @@ MSCCLPP_API_CPP std::shared_ptr Communicator::connectOnSetup(int rem pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()], getIBDeviceName(transport).c_str(), remoteRank, pimpl->rankToHash_[remoteRank]); } else { - throw std::runtime_error("Unsupported transport"); + throw mscclpp::Error("Unsupported transport", mscclppInvalidArgument); } pimpl->connections_.push_back(conn); addSetup(conn); diff --git a/src/connection.cc b/src/connection.cc index 6a657e02..60ff2291 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -1,17 +1,17 @@ -#include #include "connection.hpp" #include "checks.hpp" #include "infiniband/verbs.h" #include "npkit/npkit.h" #include "registered_memory.hpp" #include "utils.hpp" +#include namespace mscclpp { void validateTransport(RegisteredMemory mem, Transport transport) { if (!mem.transports().has(transport)) { - throw std::runtime_error("mem does not support transport"); + throw Error("RegisteredMemory does not support transport", mscclppInvalidArgument); } } @@ -24,11 +24,19 @@ std::shared_ptr Connection::getRegisteredMemoryImpl(Regi // ConnectionBase -ConnectionBase::ConnectionBase(int remoteRank, int tag) : remoteRank_(remoteRank), tag_(tag) {} +ConnectionBase::ConnectionBase(int remoteRank, int tag) : remoteRank_(remoteRank), tag_(tag) +{ +} -int ConnectionBase::remoteRank() { return remoteRank_; } +int ConnectionBase::remoteRank() +{ + return remoteRank_; +} -int ConnectionBase::tag() { return tag_; } +int ConnectionBase::tag() +{ + return tag_; +} // CudaIpcConnection @@ -99,11 +107,11 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem auto dstTransportInfo = getRegisteredMemoryImpl(dst)->getTransportInfo(remoteTransport()); if (dstTransportInfo.ibLocal) { - throw std::runtime_error("dst is local, which is not supported"); + throw Error("dst is local, which is not supported", mscclppInvalidArgument); } auto srcTransportInfo = getRegisteredMemoryImpl(src)->getTransportInfo(transport()); if (!srcTransportInfo.ibLocal) { - throw std::runtime_error("src is remote, which is not supported"); + throw Error("src is remote, which is not supported", mscclppInvalidArgument); } auto dstMrInfo = dstTransportInfo.ibMrInfo; @@ -113,7 +121,8 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem /*signaled=*/true); numSignaledSends++; qp->postSend(); - INFO(MSCCLPP_NET, "IBConnection write: from %p to %p, size %lu", (uint8_t*)srcMr->getBuff() + srcOffset, (uint8_t*)dstMrInfo.addr + dstOffset, size); + INFO(MSCCLPP_NET, "IBConnection write: from %p to %p, size %lu", (uint8_t*)srcMr->getBuff() + srcOffset, + (uint8_t*)dstMrInfo.addr + dstOffset, size); // npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)size); } @@ -123,16 +132,19 @@ void IBConnection::flush() while (numSignaledSends) { int wcNum = qp->pollCq(); if (wcNum < 0) { - throw std::runtime_error("pollCq failed: error no " + std::to_string(errno)); + throw mscclpp::IbError("pollCq failed: error no " + std::to_string(errno), errno); } auto elapsed = timer.elapsed(); - if (elapsed > MSCCLPP_POLLING_WAIT) - throw std::runtime_error("pollCq is stuck: waited for " + std::to_string(elapsed) + " seconds. Expected " + std::to_string(numSignaledSends) + " signals"); + if (elapsed > MSCCLPP_POLLING_WAIT) { + throw Error("pollCq is stuck: waited for " + std::to_string(elapsed) + " seconds. Expected " + + std::to_string(numSignaledSends) + " signals", + mscclppInternalError); + } for (int i = 0; i < wcNum; ++i) { const struct ibv_wc* wc = reinterpret_cast(qp->getWc(i)); if (wc->status != IBV_WC_SUCCESS) { - throw std::runtime_error("pollCq failed: status " + std::to_string(wc->status)); + throw mscclpp::IbError("pollCq failed: status " + std::to_string(wc->status), wc->status); } if (wc->opcode == IBV_WC_RDMA_WRITE) { numSignaledSends--; diff --git a/src/epoch.cc b/src/epoch.cc index 9263fd1c..2e3a5166 100644 --- a/src/epoch.cc +++ b/src/epoch.cc @@ -1,26 +1,32 @@ #include "epoch.hpp" -#include "checks.hpp" #include "alloc.h" #include "api.h" +#include "checks.hpp" namespace mscclpp { -MSCCLPP_API_CPP Epoch::Epoch(Communicator& communicator, std::shared_ptr connection) : connection_(connection) { +MSCCLPP_API_CPP Epoch::Epoch(Communicator& communicator, std::shared_ptr connection) + : connection_(connection) +{ MSCCLPPTHROW(mscclppCudaCalloc(&device_.epochIds_, 1)); MSCCLPPTHROW(mscclppCudaCalloc(&device_.expectedInboundEpochId_, 1)); - localEpochIdsRegMem_ = communicator.registerMemory(device_.epochIds_, sizeof(device_.epochIds_), connection->transport()); + localEpochIdsRegMem_ = + communicator.registerMemory(device_.epochIds_, sizeof(device_.epochIds_), connection->transport()); communicator.sendMemoryOnSetup(localEpochIdsRegMem_, connection->remoteRank(), connection->tag()); remoteEpochIdsRegMem_ = communicator.recvMemoryOnSetup(connection->remoteRank(), connection->tag()); } -MSCCLPP_API_CPP Epoch::~Epoch() { +MSCCLPP_API_CPP Epoch::~Epoch() +{ mscclppCudaFree(device_.epochIds_); mscclppCudaFree(device_.expectedInboundEpochId_); } -MSCCLPP_API_CPP void Epoch::signal() { - connection_->write(remoteEpochIdsRegMem_.get(), offsetof(EpochIds, inboundReplica_), localEpochIdsRegMem_, offsetof(EpochIds, outbound_), sizeof(device_.epochIds_)); +MSCCLPP_API_CPP void Epoch::signal() +{ + connection_->write(remoteEpochIdsRegMem_.get(), offsetof(EpochIds, inboundReplica_), localEpochIdsRegMem_, + offsetof(EpochIds, outbound_), sizeof(device_.epochIds_)); } } // namespace mscclpp diff --git a/src/errors.cc b/src/errors.cc new file mode 100644 index 00000000..d893578c --- /dev/null +++ b/src/errors.cc @@ -0,0 +1,30 @@ +#include "errors.hpp" + +namespace mscclpp { + +BaseError::BaseError(std::string message, int errorCode) : std::runtime_error(message), errorCode_(errorCode) +{ +} + +int BaseError::getErrorCode() const +{ + return errorCode_; +} + +Error::Error(std::string message, int errorCode) : BaseError(message, errorCode) +{ +} + +CudaError::CudaError(std::string message, int errorCode) : BaseError(message, errorCode) +{ +} + +CuError::CuError(std::string message, int errorCode) : BaseError(message, errorCode) +{ +} + +IbError::IbError(std::string message, int errorCode) : BaseError(message, errorCode) +{ +} + +}; // namespace mscclpp diff --git a/src/fifo.cc b/src/fifo.cc index d5d70422..49902816 100644 --- a/src/fifo.cc +++ b/src/fifo.cc @@ -1,7 +1,7 @@ #include "alloc.h" +#include "api.h" #include "checks.hpp" #include "mscclppfifo.hpp" -#include "api.h" #include #include #include diff --git a/src/ib.cc b/src/ib.cc index 1e3e0af6..b95bfb43 100644 --- a/src/ib.cc +++ b/src/ib.cc @@ -6,12 +6,12 @@ #include #include "alloc.h" +#include "api.h" #include "checks.hpp" #include "comm.h" #include "debug.h" #include "ib.hpp" #include "mscclpp.hpp" -#include "api.h" #include #include @@ -20,7 +20,7 @@ namespace mscclpp { IbMr::IbMr(void* pd, void* buff, std::size_t size) : buff(buff) { if (size == 0) { - throw std::runtime_error("invalid size: " + std::to_string(size)); + throw std::invalid_argument("invalid size: " + std::to_string(size)); } static __thread uintptr_t pageSize = 0; if (pageSize == 0) { @@ -35,7 +35,7 @@ IbMr::IbMr(void* pd, void* buff, std::size_t size) : buff(buff) if (_mr == nullptr) { std::stringstream err; err << "ibv_reg_mr failed (errno " << errno << ")"; - throw std::runtime_error(err.str()); + throw mscclpp::IbError(err.str(), errno); } this->mr = _mr; this->size = pages * pageSize; @@ -73,7 +73,7 @@ IbQp::IbQp(void* ctx, void* pd, int port) if (this->cq == nullptr) { std::stringstream err; err << "ibv_create_cq failed (errno " << errno << ")"; - throw std::runtime_error(err.str()); + throw mscclpp::IbError(err.str(), errno); } struct ibv_qp_init_attr qpInitAttr; @@ -92,14 +92,14 @@ IbQp::IbQp(void* ctx, void* pd, int port) if (_qp == nullptr) { std::stringstream err; err << "ibv_create_qp failed (errno " << errno << ")"; - throw std::runtime_error(err.str()); + throw mscclpp::IbError(err.str(), errno); } struct ibv_port_attr portAttr; if (ibv_query_port(_ctx, port, &portAttr) != 0) { std::stringstream err; err << "ibv_query_port failed (errno " << errno << ")"; - throw std::runtime_error(err.str()); + throw mscclpp::IbError(err.str(), errno); } this->info.lid = portAttr.lid; this->info.port = port; @@ -111,7 +111,7 @@ IbQp::IbQp(void* ctx, void* pd, int port) if (ibv_query_gid(_ctx, port, 0, &gid) != 0) { std::stringstream err; err << "ibv_query_gid failed (errno " << errno << ")"; - throw std::runtime_error(err.str()); + throw mscclpp::IbError(err.str(), errno); } this->info.spn = gid.global.subnet_prefix; } @@ -125,7 +125,7 @@ IbQp::IbQp(void* ctx, void* pd, int port) if (ibv_modify_qp(_qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { std::stringstream err; err << "ibv_modify_qp failed (errno " << errno << ")"; - throw std::runtime_error(err.str()); + throw mscclpp::IbError(err.str(), errno); } this->qp = _qp; this->wrn = 0; @@ -174,7 +174,7 @@ void IbQp::rtr(const IbQpInfo& info) if (ret != 0) { std::stringstream err; err << "ibv_modify_qp failed (errno " << errno << ")"; - throw std::runtime_error(err.str()); + throw mscclpp::IbError(err.str(), errno); } } @@ -194,7 +194,7 @@ void IbQp::rts() if (ret != 0) { std::stringstream err; err << "ibv_modify_qp failed (errno " << errno << ")"; - throw std::runtime_error(err.str()); + throw mscclpp::IbError(err.str(), errno); } } @@ -249,7 +249,7 @@ void IbQp::postSend() if (ret != 0) { std::stringstream err; err << "ibv_post_send failed (errno " << errno << ")"; - throw std::runtime_error(err.str()); + throw mscclpp::IbError(err.str(), errno); } this->wrn = 0; } @@ -265,7 +265,7 @@ void IbQp::postRecv(uint64_t wrId) if (ret != 0) { std::stringstream err; err << "ibv_post_recv failed (errno " << errno << ")"; - throw std::runtime_error(err.str()); + throw mscclpp::IbError(err.str(), errno); } } @@ -299,13 +299,13 @@ IbCtx::IbCtx(const std::string& devName) : devName(devName) if (this->ctx == nullptr) { std::stringstream err; err << "ibv_open_device failed (errno " << errno << ", device name << " << devName << ")"; - throw std::runtime_error(err.str()); + throw mscclpp::IbError(err.str(), errno); } this->pd = ibv_alloc_pd(reinterpret_cast(this->ctx)); if (this->pd == nullptr) { std::stringstream err; err << "ibv_alloc_pd failed (errno " << errno << ")"; - throw std::runtime_error(err.str()); + throw mscclpp::IbError(err.str(), errno); } } @@ -327,7 +327,7 @@ bool IbCtx::isPortUsable(int port) const if (ibv_query_port(reinterpret_cast(this->ctx), port, &portAttr) != 0) { std::stringstream err; err << "ibv_query_port failed (errno " << errno << ", port << " << port << ")"; - throw std::runtime_error(err.str()); + throw mscclpp::IbError(err.str(), errno); } return portAttr.state == IBV_PORT_ACTIVE && (portAttr.link_layer == IBV_LINK_LAYER_ETHERNET || portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND); @@ -339,7 +339,7 @@ int IbCtx::getAnyActivePort() const if (ibv_query_device(reinterpret_cast(this->ctx), &devAttr) != 0) { std::stringstream err; err << "ibv_query_device failed (errno " << errno << ")"; - throw std::runtime_error(err.str()); + throw mscclpp::IbError(err.str(), errno); } for (uint8_t port = 1; port <= devAttr.phys_port_cnt; ++port) { if (this->isPortUsable(port)) { @@ -354,10 +354,10 @@ IbQp* IbCtx::createQp(int port /*=-1*/) if (port == -1) { port = this->getAnyActivePort(); if (port == -1) { - throw std::runtime_error("No active port found"); + throw mscclpp::Error("No active port found", mscclppInternalError); } } else if (!this->isPortUsable(port)) { - throw std::runtime_error("invalid IB port: " + std::to_string(port)); + throw mscclpp::Error("invalid IB port: " + std::to_string(port), mscclppInternalError); } qps.emplace_back(new IbQp(this->ctx, this->pd, port)); return qps.back().get(); @@ -412,10 +412,10 @@ MSCCLPP_API_CPP std::string getIBDeviceName(Transport ibTransport) ibTransportIndex = 7; break; default: - throw std::runtime_error("Not an IB transport"); + throw std::invalid_argument("Not an IB transport"); } if (ibTransportIndex >= num) { - throw std::runtime_error("IB transport out of range"); + throw std::out_of_range("IB transport out of range"); } return devices[ibTransportIndex]->name; } @@ -444,11 +444,11 @@ MSCCLPP_API_CPP Transport getIBTransportByDeviceName(const std::string& ibDevice case 7: return Transport::IB7; default: - throw std::runtime_error("IB device index out of range"); + throw std::out_of_range("IB device index out of range"); } } } - throw std::runtime_error("IB device not found"); + throw std::invalid_argument("IB device not found"); } } // namespace mscclpp diff --git a/src/include/channel.hpp b/src/include/channel.hpp index eb4bd9e7..26d31731 100644 --- a/src/include/channel.hpp +++ b/src/include/channel.hpp @@ -3,8 +3,8 @@ #include "epoch.hpp" #include "mscclpp.hpp" -#include "proxy.hpp" #include "mscclppfifo.hpp" +#include "proxy.hpp" #include "utils.hpp" namespace mscclpp { @@ -15,10 +15,16 @@ class Channel { public: Channel(Communicator& communicator, std::shared_ptr connection) - : connection_(connection), epoch_(std::make_shared(communicator, connection)) {}; + : connection_(connection), epoch_(std::make_shared(communicator, connection)){}; - Connection& connection() { return *connection_; } - Epoch& epoch() { return *epoch_; } + Connection& connection() + { + return *connection_; + } + Epoch& epoch() + { + return *epoch_; + } private: std::shared_ptr connection_; @@ -69,8 +75,8 @@ union ChannelTrigger { __device__ ChannelTrigger(ProxyTrigger value) : value(value) { } - __device__ ChannelTrigger(TriggerType type, MemoryId dst, uint64_t dstOffset, MemoryId src, - uint64_t srcOffset, uint64_t size, int connectionId) + __device__ ChannelTrigger(TriggerType type, MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, + uint64_t size, int connectionId) { value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + size); value.snd = ((((((((connectionId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_REGMEM_HANDLE) + dst) @@ -86,15 +92,17 @@ struct DeviceChannel { DeviceChannel() = default; - DeviceChannel(ChannelId channelId, DeviceEpoch epoch, DeviceProxyFifo fifo) : channelId_(channelId), epoch_(epoch), fifo_(fifo) {} + DeviceChannel(ChannelId channelId, DeviceEpoch epoch, DeviceProxyFifo fifo) + : channelId_(channelId), epoch_(epoch), fifo_(fifo) + { + } DeviceChannel(const DeviceChannel& other) = default; DeviceChannel& operator=(DeviceChannel& other) = default; #ifdef __CUDACC__ - __forceinline__ __device__ void put(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, - uint64_t size) + __forceinline__ __device__ void put(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, uint64_t size) { fifo_.push(ChannelTrigger(TriggerData, dst, dstOffset, src, srcOffset, size, channelId_).value); } @@ -110,13 +118,11 @@ struct DeviceChannel fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, channelId_).value); } - __forceinline__ __device__ void putWithSignal(MemoryId dst, uint64_t dstOffset, MemoryId src, - uint64_t srcOffset, uint64_t size) + __forceinline__ __device__ void putWithSignal(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, + uint64_t size) { epochIncrement(); - fifo_.push( - ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, size, channelId_) - .value); + fifo_.push(ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, size, channelId_).value); } __forceinline__ __device__ void putWithSignal(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) @@ -128,16 +134,14 @@ struct DeviceChannel uint64_t srcOffset, uint64_t size) { epochIncrement(); - uint64_t curFifoHead = fifo_.push(ChannelTrigger(TriggerData | TriggerFlag | TriggerSync, dst, - dstOffset, src, srcOffset, size, channelId_) - .value); + uint64_t curFifoHead = fifo_.push( + ChannelTrigger(TriggerData | TriggerFlag | TriggerSync, dst, dstOffset, src, srcOffset, size, channelId_).value); while (*(volatile uint64_t*)&fifo_.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 && *(volatile uint64_t*)fifo_.tailReplica <= curFifoHead) ; } - __forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, MemoryId src, uint64_t offset, - uint64_t size) + __forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) { putWithSignalAndFlush(dst, offset, src, offset, size); } @@ -176,25 +180,40 @@ class DeviceChannelService; inline ProxyHandler makeChannelProxyHandler(DeviceChannelService& channelService); -class DeviceChannelService { +class DeviceChannelService +{ public: DeviceChannelService(Communicator& communicator); - ChannelId addChannel(std::shared_ptr connection) { + ChannelId addChannel(std::shared_ptr connection) + { channels_.push_back(Channel(communicator_, connection)); return channels_.size() - 1; } - MemoryId addMemory(RegisteredMemory memory) { + MemoryId addMemory(RegisteredMemory memory) + { memories_.push_back(memory); return memories_.size() - 1; } - Channel channel(ChannelId id) { return channels_[id]; } - DeviceChannel deviceChannel(ChannelId id) { return DeviceChannel(id, channels_[id].epoch().deviceEpoch(), proxy_.fifo().deviceFifo()); } + Channel channel(ChannelId id) + { + return channels_[id]; + } + DeviceChannel deviceChannel(ChannelId id) + { + return DeviceChannel(id, channels_[id].epoch().deviceEpoch(), proxy_.fifo().deviceFifo()); + } - void startProxy() { proxy_.start(); } - void stopProxy() { proxy_.stop(); } + void startProxy() + { + proxy_.start(); + } + void stopProxy() + { + proxy_.stop(); + } private: Communicator& communicator_; @@ -205,7 +224,8 @@ private: void bindThread(); - ProxyHandlerResult handleTrigger(ProxyTrigger triggerRaw) { + ProxyHandlerResult handleTrigger(ProxyTrigger triggerRaw) + { ChannelTrigger* trigger = reinterpret_cast(&triggerRaw); Channel& channel = channels_[trigger->fields.chanId]; @@ -234,7 +254,9 @@ struct SimpleDeviceChannel { SimpleDeviceChannel() = default; - SimpleDeviceChannel(DeviceChannel devChan, MemoryId dst, MemoryId src) : devChan_(devChan), dst_(dst), src_(src) {} + SimpleDeviceChannel(DeviceChannel devChan, MemoryId dst, MemoryId src) : devChan_(devChan), dst_(dst), src_(src) + { + } SimpleDeviceChannel(const SimpleDeviceChannel& other) = default; diff --git a/src/include/checks.hpp b/src/include/checks.hpp index 6473c92f..b385d6d3 100644 --- a/src/include/checks.hpp +++ b/src/include/checks.hpp @@ -8,6 +8,8 @@ #define MSCCLPP_CHECKS_HPP_ #include "debug.h" +#include "errors.hpp" + #include #include @@ -15,7 +17,8 @@ do { \ mscclppResult_t res = call; \ if (res != mscclppSuccess && res != mscclppInProgress) { \ - throw std::runtime_error(std::string("Call to " #call " failed with error code ") + mscclppGetErrorString(res)); \ + throw mscclpp::Error(std::string("Call to " #call " failed with error code ") + mscclppGetErrorString(res), \ + res); \ } \ } while (false) @@ -23,7 +26,7 @@ do { \ cudaError_t err = cmd; \ if (err != cudaSuccess) { \ - throw std::runtime_error(std::string("Cuda failure '") + cudaGetErrorString(err) + "'"); \ + throw mscclpp::CudaError(std::string("Cuda failure '") + cudaGetErrorString(err) + "'", err); \ } \ } while (false) @@ -33,7 +36,7 @@ if (err != CUDA_SUCCESS) { \ const char* errStr; \ cuGetErrorString(err, &errStr); \ - throw std::runtime_error(std::string("Cu failure '") + std::string(errStr) + "'"); \ + throw mscclpp::CuError(std::string("Cu failure '") + std::string(errStr) + "'", err); \ } \ } while (false) diff --git a/src/include/connection.hpp b/src/include/connection.hpp index e06c426a..5f764b05 100644 --- a/src/include/connection.hpp +++ b/src/include/connection.hpp @@ -17,6 +17,7 @@ class ConnectionBase : public Connection, public Setuppable { int remoteRank_; int tag_; + public: ConnectionBase(int remoteRank, int tag); diff --git a/src/include/epoch.hpp b/src/include/epoch.hpp index ffd7464d..2566a273 100644 --- a/src/include/epoch.hpp +++ b/src/include/epoch.hpp @@ -17,7 +17,8 @@ struct DeviceEpoch __forceinline__ __device__ void wait() { (*expectedInboundEpochId_) += 1; - while (*(volatile uint64_t*)&(epochIds_->inboundReplica_) < (*expectedInboundEpochId_)); + while (*(volatile uint64_t*)&(epochIds_->inboundReplica_) < (*expectedInboundEpochId_)) + ; } __forceinline__ __device__ void epochIncrement() @@ -44,9 +45,12 @@ public: void signal(); - DeviceEpoch deviceEpoch() { return device_; } + DeviceEpoch deviceEpoch() + { + return device_; + } }; } // namespace mscclpp -#endif // MSCCLPP_EPOCH_HPP_ \ No newline at end of file +#endif // MSCCLPP_EPOCH_HPP_ diff --git a/src/include/errors.hpp b/src/include/errors.hpp new file mode 100644 index 00000000..5f58f766 --- /dev/null +++ b/src/include/errors.hpp @@ -0,0 +1,46 @@ +#ifndef MSCCLPP_ERRORS_HPP_ +#define MSCCLPP_ERRORS_HPP_ + +#include + +namespace mscclpp { +class BaseError : public std::runtime_error +{ +public: + BaseError(std::string message, int errorCode); + virtual ~BaseError() = default; + int getErrorCode() const; + +private: + int errorCode_; +}; + +class Error : public BaseError +{ +public: + Error(std::string message, int errorCode); + virtual ~Error() = default; +}; + +class CudaError : public BaseError +{ +public: + CudaError(std::string message, int errorCode); + virtual ~CudaError() = default; +}; + +class CuError : public BaseError +{ +public: + CuError(std::string message, int errorCode); + virtual ~CuError() = default; +}; + +class IbError : public BaseError +{ +public: + IbError(std::string message, int errorCode); + virtual ~IbError() = default; +}; +}; // namespace mscclpp +#endif // MSCCLPP_ERRORS_HPP diff --git a/src/include/mscclpp.hpp b/src/include/mscclpp.hpp index 47ca9437..a37195d3 100644 --- a/src/include/mscclpp.hpp +++ b/src/include/mscclpp.hpp @@ -7,10 +7,10 @@ #define MSCCLPP_VERSION (MSCCLPP_MAJOR * 10000 + MSCCLPP_MINOR * 100 + MSCCLPP_PATCH) #include +#include #include #include #include -#include namespace mscclpp { @@ -37,14 +37,14 @@ public: { size_t size = data.size(); send((void*)&size, sizeof(size_t), peer, tag); - send((void*)data.data(), data.size(), peer, tag+1); + send((void*)data.data(), data.size(), peer, tag + 1); } void recv(std::vector& data, int peer, int tag) { size_t size; recv((void*)&size, sizeof(size_t), peer, tag); data.resize(size); - recv((void*)data.data(), data.size(), peer, tag+1); + recv((void*)data.data(), data.size(), peer, tag + 1); } }; @@ -239,7 +239,8 @@ class Connection; class RegisteredMemory { struct Impl; - // A shared_ptr is used since RegisteredMemory is functionally immutable, although internally some state is populated lazily. + // A shared_ptr is used since RegisteredMemory is functionally immutable, although internally some state is populated + // lazily. std::shared_ptr pimpl; public: @@ -281,17 +282,23 @@ protected: struct Setuppable { - virtual void beginSetup(std::shared_ptr) {} - virtual void endSetup(std::shared_ptr) {} + virtual void beginSetup(std::shared_ptr) + { + } + virtual void endSetup(std::shared_ptr) + { + } }; -template -class NonblockingFuture +template class NonblockingFuture { std::shared_future future; + public: NonblockingFuture() = default; - NonblockingFuture(std::shared_future&& future) : future(std::move(future)) {} + NonblockingFuture(std::shared_future&& future) : future(std::move(future)) + { + } NonblockingFuture(const NonblockingFuture&) = default; bool ready() const @@ -331,7 +338,7 @@ public: * Returns: a handle to the buffer */ RegisteredMemory registerMemory(void* ptr, size_t size, TransportFlags transports); - + void sendMemoryOnSetup(RegisteredMemory memory, int remoteRank, int tag); NonblockingFuture recvMemoryOnSetup(int remoteRank, int tag); @@ -363,7 +370,6 @@ public: private: std::unique_ptr pimpl; }; - } // namespace mscclpp namespace std { diff --git a/src/include/registered_memory.hpp b/src/include/registered_memory.hpp index bf4802ce..d1b7830d 100644 --- a/src/include/registered_memory.hpp +++ b/src/include/registered_memory.hpp @@ -2,6 +2,7 @@ #define MSCCLPP_REGISTERED_MEMORY_HPP_ #include "communicator.hpp" +#include "errors.hpp" #include "ib.hpp" #include "mscclpp.h" #include "mscclpp.hpp" @@ -16,11 +17,13 @@ struct TransportInfo // TODO: rewrite this using std::variant or something bool ibLocal; union { - struct { + struct + { cudaIpcMemHandle_t cudaIpcBaseHandle; size_t cudaIpcOffsetFromBase; }; - struct { + struct + { const IbMr* ibMr; IbMrInfo ibMrInfo; }; @@ -46,7 +49,7 @@ struct RegisteredMemory::Impl return entry; } } - throw std::runtime_error("Transport data not found"); + throw Error("Transport data not found", mscclppInternalError); } }; diff --git a/src/include/utils.hpp b/src/include/utils.hpp index 9abf9994..d1a1c7d8 100644 --- a/src/include/utils.hpp +++ b/src/include/utils.hpp @@ -8,45 +8,45 @@ namespace mscclpp { struct Timer { - std::chrono::steady_clock::time_point start; - - Timer() - { - start = std::chrono::steady_clock::now(); - } - - int64_t elapsed() - { - auto end = std::chrono::steady_clock::now(); - return std::chrono::duration_cast(end - start).count(); - } - - void reset() - { - start = std::chrono::steady_clock::now(); - } - - void print(const char* name) - { - auto end = std::chrono::steady_clock::now(); - auto elapsed = std::chrono::duration_cast(end - start).count(); - printf("%s: %ld us\n", name, elapsed); - } + std::chrono::steady_clock::time_point start; + + Timer() + { + start = std::chrono::steady_clock::now(); + } + + int64_t elapsed() + { + auto end = std::chrono::steady_clock::now(); + return std::chrono::duration_cast(end - start).count(); + } + + void reset() + { + start = std::chrono::steady_clock::now(); + } + + void print(const char* name) + { + auto end = std::chrono::steady_clock::now(); + auto elapsed = std::chrono::duration_cast(end - start).count(); + printf("%s: %ld us\n", name, elapsed); + } }; struct ScopedTimer { - Timer timer; - const char* name; - - ScopedTimer(const char* name) : name(name) - { - } - - ~ScopedTimer() - { - timer.print(name); - } + Timer timer; + const char* name; + + ScopedTimer(const char* name) : name(name) + { + } + + ~ScopedTimer() + { + timer.print(name); + } }; } // namespace mscclpp diff --git a/src/proxy_cpp.cc b/src/proxy_cpp.cc index b1626813..cd005e02 100644 --- a/src/proxy_cpp.cc +++ b/src/proxy_cpp.cc @@ -1,6 +1,6 @@ -#include "proxy.hpp" #include "api.h" #include "mscclpp.hpp" +#include "proxy.hpp" #include "utils.h" #include "utils.hpp" #include @@ -20,7 +20,8 @@ struct Proxy::Impl std::thread service; std::atomic_bool running; - Impl(ProxyHandler handler, std::function threadInit) : handler(handler), threadInit(threadInit), running(false) + Impl(ProxyHandler handler, std::function threadInit) + : handler(handler), threadInit(threadInit), running(false) { } }; @@ -45,7 +46,6 @@ MSCCLPP_API_CPP void Proxy::start() { pimpl->running = true; pimpl->service = std::thread([this] { - pimpl->threadInit(); ProxyHandler handler = this->pimpl->handler; @@ -109,4 +109,4 @@ MSCCLPP_API_CPP HostProxyFifo& Proxy::fifo() return pimpl->fifo; } -} // namespace mscclpp \ No newline at end of file +} // namespace mscclpp diff --git a/src/registered_memory.cc b/src/registered_memory.cc index fed732a0..3cb82fbf 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -88,7 +88,7 @@ MSCCLPP_API_CPP std::vector RegisteredMemory::serialize() std::copy_n(reinterpret_cast(&pimpl->hostHash), sizeof(pimpl->hostHash), std::back_inserter(result)); std::copy_n(reinterpret_cast(&pimpl->transports), sizeof(pimpl->transports), std::back_inserter(result)); if (pimpl->transportInfos.size() > std::numeric_limits::max()) { - throw std::runtime_error("Too many transport info entries"); + throw mscclpp::Error("Too many transport info entries", mscclppInternalError); } int8_t transportCount = pimpl->transportInfos.size(); std::copy_n(reinterpret_cast(&transportCount), sizeof(transportCount), std::back_inserter(result)); @@ -102,7 +102,7 @@ MSCCLPP_API_CPP std::vector RegisteredMemory::serialize() } else if (AllIBTransports.has(entry.transport)) { std::copy_n(reinterpret_cast(&entry.ibMrInfo), sizeof(entry.ibMrInfo), std::back_inserter(result)); } else { - throw std::runtime_error("Unknown transport"); + throw mscclpp::Error("Unknown transport", mscclppInternalError); } } return result; @@ -132,21 +132,23 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) std::copy_n(it, sizeof(transportInfo.transport), reinterpret_cast(&transportInfo.transport)); it += sizeof(transportInfo.transport); if (transportInfo.transport == Transport::CudaIpc) { - std::copy_n(it, sizeof(transportInfo.cudaIpcBaseHandle), reinterpret_cast(&transportInfo.cudaIpcBaseHandle)); + std::copy_n(it, sizeof(transportInfo.cudaIpcBaseHandle), + reinterpret_cast(&transportInfo.cudaIpcBaseHandle)); it += sizeof(transportInfo.cudaIpcBaseHandle); - std::copy_n(it, sizeof(transportInfo.cudaIpcOffsetFromBase), reinterpret_cast(&transportInfo.cudaIpcOffsetFromBase)); + std::copy_n(it, sizeof(transportInfo.cudaIpcOffsetFromBase), + reinterpret_cast(&transportInfo.cudaIpcOffsetFromBase)); it += sizeof(transportInfo.cudaIpcOffsetFromBase); } else if (AllIBTransports.has(transportInfo.transport)) { std::copy_n(it, sizeof(transportInfo.ibMrInfo), reinterpret_cast(&transportInfo.ibMrInfo)); it += sizeof(transportInfo.ibMrInfo); transportInfo.ibLocal = false; } else { - throw std::runtime_error("Unknown transport"); + throw mscclpp::Error("Unknown transport", mscclppInternalError); } this->transportInfos.push_back(transportInfo); } if (it != serialization.end()) { - throw std::runtime_error("Deserialization failed"); + throw mscclpp::Error("Serialization failed", mscclppInternalError); } if (transports.has(Transport::CudaIpc)) { diff --git a/src/utils.cc b/src/utils.cc index 6954a64f..d3957bb1 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -6,10 +6,10 @@ #include "utils.h" +#include #include #include #include -#include // Get current Compute Capability // int mscclppCudaCompCap() { diff --git a/tests/allgather_test_cpp.cu b/tests/allgather_test_cpp.cu index ad473f8f..ddfd51d8 100644 --- a/tests/allgather_test_cpp.cu +++ b/tests/allgather_test_cpp.cu @@ -1,5 +1,6 @@ #include "mscclpp.h" #include "mscclpp.hpp" + #include "channel.hpp" #ifdef MSCCLPP_USE_MPI_FOR_TESTS @@ -71,8 +72,8 @@ __device__ void allgather0(mscclpp::channel::SimpleDeviceChannel devChan, int ra devChan.wait(); } -__device__ void localAllGather(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int world_size, int nranksPerNode, - int remoteRank, uint64_t offset, uint64_t size) +__device__ void localAllGather(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int world_size, + int nranksPerNode, int remoteRank, uint64_t offset, uint64_t size) { // this allgather algorithm works as follows: // Step 1: GPU rank i sends data to GPU rank (i+1) % nranksPerNode @@ -131,7 +132,7 @@ __device__ void allgather2(mscclpp::channel::SimpleDeviceChannel devChan, int ra // opposite side if ((threadIdx.x % 32) == 0) devChan.putWithSignal(rank * nelemsPerGPU * sizeof(int), - (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int)); + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int)); if ((threadIdx.x % 32) == 0) devChan.wait(); } @@ -150,9 +151,8 @@ __device__ void allgather2(mscclpp::channel::SimpleDeviceChannel devChan, int ra if (remoteRank % nranksPerNode == rank % nranksPerNode) { // opposite side if ((threadIdx.x % 32) == 0) - devChan.putWithSignal((rank * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * - sizeof(int), - nelemsPerGPU / pipelineSize * sizeof(int)); + devChan.putWithSignal((rank * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * sizeof(int), + nelemsPerGPU / pipelineSize * sizeof(int)); if ((threadIdx.x % 32) == 0) devChan.wait(); } @@ -226,7 +226,8 @@ void initializeAndAllocateAllGatherData(int rank, int world_size, size_t dataSiz CUDACHECK(cudaMemcpy(*data_d, *data_h, dataSize, cudaMemcpyHostToDevice)); } -void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& comm, mscclpp::channel::DeviceChannelService& channelService, int* data_d, size_t dataSize) +void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& comm, + mscclpp::channel::DeviceChannelService& channelService, int* data_d, size_t dataSize) { int thisNode = rankToNode(rank); int cudaNum = rankToLocalRank(rank); @@ -258,12 +259,13 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co std::vector devChannels; for (size_t i = 0; i < channelIds.size(); ++i) { devChannels.push_back(mscclpp::channel::SimpleDeviceChannel(channelService.deviceChannel(channelIds[i]), - channelService.addMemory(remoteMemories[i].get()), channelService.addMemory(localMemories[i]))); + channelService.addMemory(remoteMemories[i].get()), + channelService.addMemory(localMemories[i]))); } assert(devChannels.size() < sizeof(constDevChans) / sizeof(mscclpp::channel::SimpleDeviceChannel)); - CUDACHECK( - cudaMemcpyToSymbol(constDevChans, devChannels.data(), sizeof(mscclpp::channel::SimpleDeviceChannel) * devChannels.size())); + CUDACHECK(cudaMemcpyToSymbol(constDevChans, devChannels.data(), + sizeof(mscclpp::channel::SimpleDeviceChannel) * devChannels.size())); } void printUsage(const char* prog, bool isMpi) diff --git a/tests/communicator_test_cpp.cu b/tests/communicator_test_cpp.cu index 56c8592e..345ba1fc 100644 --- a/tests/communicator_test_cpp.cu +++ b/tests/communicator_test_cpp.cu @@ -1,5 +1,5 @@ -#include "mscclpp.hpp" #include "epoch.hpp" +#include "mscclpp.hpp" #include #include @@ -24,26 +24,33 @@ mscclpp::Transport findIb(int localRank) return IBs[localRank]; } -void register_all_memories(mscclpp::Communicator& communicator, int rank, int worldSize, void* devicePtr, size_t deviceBufferSize, mscclpp::Transport myIbDevice, mscclpp::RegisteredMemory& localMemory, std::unordered_map& remoteMemory){ +void register_all_memories(mscclpp::Communicator& communicator, int rank, int worldSize, void* devicePtr, + size_t deviceBufferSize, mscclpp::Transport myIbDevice, + mscclpp::RegisteredMemory& localMemory, + std::unordered_map& remoteMemory) +{ localMemory = communicator.registerMemory(devicePtr, deviceBufferSize, mscclpp::Transport::CudaIpc | myIbDevice); std::unordered_map> futureRemoteMemory; for (int i = 0; i < worldSize; i++) { - if (i != rank){ + if (i != rank) { communicator.sendMemoryOnSetup(localMemory, i, 0); futureRemoteMemory[i] = communicator.recvMemoryOnSetup(i, 0); } } communicator.setup(); for (int i = 0; i < worldSize; i++) { - if (i != rank){ + if (i != rank) { remoteMemory[i] = futureRemoteMemory[i].get(); } } } -void make_connections(mscclpp::Communicator& communicator, int rank, int worldSize, int nRanksPerNode, mscclpp::Transport myIbDevice, std::unordered_map>& connections){ +void make_connections(mscclpp::Communicator& communicator, int rank, int worldSize, int nRanksPerNode, + mscclpp::Transport myIbDevice, + std::unordered_map>& connections) +{ for (int i = 0; i < worldSize; i++) { - if (i != rank){ + if (i != rank) { if (i / nRanksPerNode == rank / nRanksPerNode) { connections[i] = communicator.connectOnSetup(i, 0, mscclpp::Transport::CudaIpc); } else { @@ -54,35 +61,40 @@ void make_connections(mscclpp::Communicator& communicator, int rank, int worldSi communicator.setup(); } -void write_remote(int rank, int worldSize, std::unordered_map>& connections, - std::unordered_map& remoteRegisteredMemories, mscclpp::RegisteredMemory& registeredMemory, int dataCountPerRank){ +void write_remote(int rank, int worldSize, std::unordered_map>& connections, + std::unordered_map& remoteRegisteredMemories, + mscclpp::RegisteredMemory& registeredMemory, int dataCountPerRank) +{ for (int i = 0; i < worldSize; i++) { if (i != rank) { auto& conn = connections.at(i); auto& peerMemory = remoteRegisteredMemories.at(i); - conn->write(peerMemory, rank * dataCountPerRank * sizeof(int), registeredMemory, rank * dataCountPerRank*sizeof(int), dataCountPerRank*sizeof(int)); + conn->write(peerMemory, rank * dataCountPerRank * sizeof(int), registeredMemory, + rank * dataCountPerRank * sizeof(int), dataCountPerRank * sizeof(int)); conn->flush(); } } } -void device_buffer_init(int rank, int worldSize, int dataCount, std::vector& devicePtr){ - for (int n = 0; n < (int)devicePtr.size(); n++){ +void device_buffer_init(int rank, int worldSize, int dataCount, std::vector& devicePtr) +{ + for (int n = 0; n < (int)devicePtr.size(); n++) { std::vector hostBuffer(dataCount, 0); for (int i = 0; i < dataCount; i++) { hostBuffer[i] = rank + n * worldSize; } - CUDATHROW(cudaMemcpy(devicePtr[n], hostBuffer.data(), dataCount*sizeof(int), cudaMemcpyHostToDevice)); + CUDATHROW(cudaMemcpy(devicePtr[n], hostBuffer.data(), dataCount * sizeof(int), cudaMemcpyHostToDevice)); } CUDATHROW(cudaDeviceSynchronize()); } -bool test_device_buffer_write_correctness(int worldSize, int dataCount, std::vector& devicePtr){ - for (int n = 0; n < (int)devicePtr.size(); n++){ +bool test_device_buffer_write_correctness(int worldSize, int dataCount, std::vector& devicePtr) +{ + for (int n = 0; n < (int)devicePtr.size(); n++) { std::vector hostBuffer(dataCount, 0); - CUDATHROW(cudaMemcpy(hostBuffer.data(), devicePtr[n], dataCount*sizeof(int), cudaMemcpyDeviceToHost)); + CUDATHROW(cudaMemcpy(hostBuffer.data(), devicePtr[n], dataCount * sizeof(int), cudaMemcpyDeviceToHost)); for (int i = 0; i < worldSize; i++) { - for (int j = i*dataCount/worldSize; j < (i+1)*dataCount/worldSize; j++) { + for (int j = i * dataCount / worldSize; j < (i + 1) * dataCount / worldSize; j++) { if (hostBuffer[j] != i + n * worldSize) { return false; } @@ -92,8 +104,11 @@ bool test_device_buffer_write_correctness(int worldSize, int dataCount, std::vec return true; } -void test_write(int rank, int worldSize, int deviceBufferSize, std::shared_ptr bootstrap, std::unordered_map>& connections, - std::vector>& remoteMemory, std::vector& localMemory, std::vector& devicePtr, int numBuffers){ +void test_write(int rank, int worldSize, int deviceBufferSize, std::shared_ptr bootstrap, + std::unordered_map>& connections, + std::vector>& remoteMemory, + std::vector& localMemory, std::vector& devicePtr, int numBuffers) +{ assert((deviceBufferSize / sizeof(int)) % worldSize == 0); size_t dataCount = deviceBufferSize / sizeof(int); @@ -102,8 +117,8 @@ void test_write(int rank, int worldSize, int deviceBufferSize, std::shared_ptrbarrier(); if (bootstrap->getRank() == 0) std::cout << "CUDA memory initialization passed" << std::endl; - - for (int n = 0; n < numBuffers; n++){ + + for (int n = 0; n < numBuffers; n++) { write_remote(rank, worldSize, connections, remoteMemory[n], localMemory[n], dataCount / worldSize); } bootstrap->barrier(); @@ -116,7 +131,7 @@ void test_write(int rank, int worldSize, int deviceBufferSize, std::shared_ptr bootstrap, std::unordered_map>& connections, - std::vector>& remoteMemory, std::vector& localMemory, std::vector& devicePtr, std::unordered_map> epochs, int numBuffers){ +void test_write_with_epochs(int rank, int worldSize, int deviceBufferSize, + std::shared_ptr bootstrap, + std::unordered_map>& connections, + std::vector>& remoteMemory, + std::vector& localMemory, std::vector& devicePtr, + std::unordered_map> epochs, int numBuffers) +{ assert((deviceBufferSize / sizeof(int)) % worldSize == 0); size_t dataCount = deviceBufferSize / sizeof(int); @@ -153,8 +175,8 @@ void test_write_with_epochs(int rank, int worldSize, int deviceBufferSize, std:: mscclpp::DeviceEpoch* deviceEpochs; CUDATHROW(cudaMalloc(&deviceEpochs, sizeof(mscclpp::DeviceEpoch) * worldSize)); - for (int i = 0; i < worldSize; i++){ - if (i != rank){ + for (int i = 0; i < worldSize; i++) { + if (i != rank) { mscclpp::DeviceEpoch deviceEpoch = epochs[i]->deviceEpoch(); CUDATHROW(cudaMemcpy(&deviceEpochs[i], &deviceEpoch, sizeof(mscclpp::DeviceEpoch), cudaMemcpyHostToDevice)); } @@ -165,16 +187,15 @@ void test_write_with_epochs(int rank, int worldSize, int deviceBufferSize, std:: if (bootstrap->getRank() == 0) std::cout << "CUDA device epochs are created" << std::endl; - - for (int n = 0; n < numBuffers; n++){ + for (int n = 0; n < numBuffers; n++) { write_remote(rank, worldSize, connections, remoteMemory[n], localMemory[n], dataCount / worldSize); } increament_epochs<<<1, worldSize>>>(deviceEpochs, rank, worldSize); CUDATHROW(cudaDeviceSynchronize()); - for (int i = 0; i < worldSize; i++){ - if (i != rank){ + for (int i = 0; i < worldSize; i++) { + if (i != rank) { epochs[i]->signal(); } } @@ -182,13 +203,14 @@ void test_write_with_epochs(int rank, int worldSize, int deviceBufferSize, std:: wait_epochs<<<1, worldSize>>>(deviceEpochs, rank, worldSize); CUDATHROW(cudaDeviceSynchronize()); - if (!test_device_buffer_write_correctness(worldSize, dataCount, devicePtr)){ + if (!test_device_buffer_write_correctness(worldSize, dataCount, devicePtr)) { throw std::runtime_error("unexpected result."); } bootstrap->barrier(); if (bootstrap->getRank() == 0) - std::cout << "--- Testing writes with singal for " << std::to_string(numBuffers) << " buffers passed ---" << std::endl; + std::cout << "--- Testing writes with singal for " << std::to_string(numBuffers) << " buffers passed ---" + << std::endl; } void test_communicator(int rank, int worldSize, int nranksPerNode) @@ -213,8 +235,8 @@ void test_communicator(int rank, int worldSize, int nranksPerNode) int numBuffers = 10; std::vector devicePtr(numBuffers); - int deviceBufferSize = 1024*1024; - + int deviceBufferSize = 1024 * 1024; + std::vector localMemory(numBuffers); std::vector> remoteMemory(numBuffers); @@ -222,13 +244,15 @@ void test_communicator(int rank, int worldSize, int nranksPerNode) if (n % 100 == 0) std::cout << "Registering memory for " << std::to_string(n) << " buffers" << std::endl; CUDATHROW(cudaMalloc(&devicePtr[n], deviceBufferSize)); - register_all_memories(communicator, rank, worldSize, devicePtr[n], deviceBufferSize, myIbDevice, localMemory[n], remoteMemory[n]); + register_all_memories(communicator, rank, worldSize, devicePtr[n], deviceBufferSize, myIbDevice, localMemory[n], + remoteMemory[n]); } bootstrap->barrier(); if (bootstrap->getRank() == 0) std::cout << "Memory registration for " << std::to_string(numBuffers) << " buffers passed" << std::endl; - test_write(rank, worldSize, deviceBufferSize, bootstrap, connections, remoteMemory, localMemory, devicePtr, numBuffers); + test_write(rank, worldSize, deviceBufferSize, bootstrap, connections, remoteMemory, localMemory, devicePtr, + numBuffers); if (bootstrap->getRank() == 0) std::cout << "--- Testing vanialla writes passed ---" << std::endl; @@ -242,12 +266,13 @@ void test_communicator(int rank, int worldSize, int nranksPerNode) if (bootstrap->getRank() == 0) std::cout << "Epochs are created" << std::endl; - test_write_with_epochs(rank, worldSize, deviceBufferSize, bootstrap, connections, remoteMemory, localMemory, devicePtr, epochs, numBuffers); + test_write_with_epochs(rank, worldSize, deviceBufferSize, bootstrap, connections, remoteMemory, localMemory, + devicePtr, epochs, numBuffers); if (bootstrap->getRank() == 0) std::cout << "--- MSCCLPP::Communicator tests passed! ---" << std::endl; - for (int n = 0; n < numBuffers; n++){ + for (int n = 0; n < numBuffers; n++) { CUDATHROW(cudaFree(devicePtr[n])); } } @@ -269,4 +294,4 @@ int main(int argc, char** argv) MPI_Finalize(); return 0; -} \ No newline at end of file +}