mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 01:10:22 +00:00
4
Makefile
4
Makefile
@@ -121,7 +121,7 @@ LDFLAGS := $(NVLDFLAGS) $(GDRCOPY_LDFLAGS) -libverbs -lnuma
|
||||
LIBSRCS := $(addprefix src/,debug.cc utils.cc init.cc proxy.cc ib.cc config.cc)
|
||||
LIBSRCS += $(addprefix src/bootstrap/,bootstrap.cc socket.cc)
|
||||
LIBSRCS += $(addprefix src/,communicator.cc connection.cc registered_memory.cc)
|
||||
LIBSRCS += $(addprefix src/,epoch.cc proxy_cpp.cc fifo.cc channel.cc)
|
||||
LIBSRCS += $(addprefix src/,epoch.cc proxy_cpp.cc fifo.cc channel.cc errors.cc)
|
||||
ifneq ($(NPKIT), 0)
|
||||
LIBSRCS += $(addprefix src/misc/,npkit.cc)
|
||||
endif
|
||||
@@ -135,7 +135,7 @@ HEADERS := $(wildcard src/include/*.h)
|
||||
CPPSOURCES := $(shell find ./ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)' -not -path "./build/*" -not -path "./python/*")
|
||||
PYTHONCPPSOURCES := $(shell find ./python/src/ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)')
|
||||
|
||||
INCEXPORTS := mscclpp.h mscclppfifo.h mscclpp.hpp mscclppfifo.hpp epoch.hpp
|
||||
INCEXPORTS := mscclpp.h mscclppfifo.h mscclpp.hpp mscclppfifo.hpp epoch.hpp errors.hpp
|
||||
INCTARGETS := $(INCEXPORTS:%=$(BUILDDIR)/$(INCDIR)/%)
|
||||
|
||||
LIBNAME := libmscclpp.so
|
||||
|
||||
@@ -194,13 +194,15 @@ void Bootstrap::Impl::getRemoteAddresses(mscclppSocket* listenSock, std::vector<
|
||||
MSCCLPPTHROW(mscclppSocketClose(&sock));
|
||||
|
||||
if (this->nRanks_ != info.nRanks) {
|
||||
throw std::runtime_error("Bootstrap Root : mismatch in rank count from procs " + std::to_string(this->nRanks_) +
|
||||
" : " + std::to_string(info.nRanks));
|
||||
throw mscclpp::Error("Bootstrap Root : mismatch in rank count from procs " + std::to_string(this->nRanks_) + " : " +
|
||||
std::to_string(info.nRanks),
|
||||
mscclppInternalError);
|
||||
}
|
||||
|
||||
if (std::memcmp(&zero, &rankAddressesRoot[info.rank], sizeof(mscclppSocketAddress)) != 0) {
|
||||
throw std::runtime_error("Bootstrap Root : rank " + std::to_string(info.rank) + " of " +
|
||||
std::to_string(this->nRanks_) + " has already checked in");
|
||||
throw mscclpp::Error("Bootstrap Root : rank " + std::to_string(info.rank) + " of " + std::to_string(this->nRanks_) +
|
||||
" has already checked in",
|
||||
mscclppInternalError);
|
||||
}
|
||||
|
||||
// Save the connection handle for that rank
|
||||
@@ -269,16 +271,17 @@ void Bootstrap::Impl::netInit(std::string ipPortPair)
|
||||
if (!ipPortPair.empty()) {
|
||||
mscclppSocketAddress remoteAddr;
|
||||
if (mscclppSocketGetAddrFromString(&remoteAddr, ipPortPair.c_str()) != mscclppSuccess) {
|
||||
throw std::runtime_error(
|
||||
"Invalid ipPortPair, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>");
|
||||
throw mscclpp::Error(
|
||||
"Invalid ipPortPair, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>",
|
||||
mscclppInvalidArgument);
|
||||
}
|
||||
if (mscclppFindInterfaceMatchSubnet(netIfName_, &netIfAddr_, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
|
||||
throw std::runtime_error("NET/Socket : No usable listening interface found");
|
||||
throw mscclpp::Error("NET/Socket : No usable listening interface found", mscclppInternalError);
|
||||
}
|
||||
} else {
|
||||
int ret = mscclppFindInterfaces(netIfName_, &netIfAddr_, MAX_IF_NAME_SIZE, 1);
|
||||
if (ret <= 0) {
|
||||
throw std::runtime_error("Bootstrap : no socket interface found");
|
||||
throw mscclpp::Error("Bootstrap : no socket interface found", mscclppInternalError);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -390,8 +393,9 @@ void Bootstrap::Impl::netRecv(mscclppSocket* sock, void* data, int size)
|
||||
int recvSize;
|
||||
MSCCLPPTHROW(mscclppSocketRecv(sock, &recvSize, sizeof(int)));
|
||||
if (recvSize > size) {
|
||||
throw std::runtime_error("Message truncated : received " + std::to_string(recvSize) + " bytes instead of " +
|
||||
std::to_string(size));
|
||||
throw mscclpp::Error("Message truncated : received " + std::to_string(recvSize) + " bytes instead of " +
|
||||
std::to_string(size),
|
||||
mscclppInternalError);
|
||||
}
|
||||
MSCCLPPTHROW(mscclppSocketRecv(sock, data, std::min(recvSize, size)));
|
||||
}
|
||||
@@ -1058,4 +1062,4 @@ mscclppResult_t bootstrapAbort(void* commState)
|
||||
free(state->peerProxyAddresses);
|
||||
free(state);
|
||||
return mscclppSuccess;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
#include "channel.hpp"
|
||||
#include "utils.h"
|
||||
#include "checks.hpp"
|
||||
#include "api.h"
|
||||
#include "checks.hpp"
|
||||
#include "debug.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace mscclpp {
|
||||
namespace channel {
|
||||
|
||||
MSCCLPP_API_CPP DeviceChannelService::DeviceChannelService(Communicator& communicator) : communicator_(communicator),
|
||||
proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, [&]() { bindThread(); }) {
|
||||
MSCCLPP_API_CPP DeviceChannelService::DeviceChannelService(Communicator& communicator)
|
||||
: communicator_(communicator),
|
||||
proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, [&]() { bindThread(); })
|
||||
{
|
||||
int cudaDevice;
|
||||
CUDATHROW(cudaGetDevice(&cudaDevice));
|
||||
MSCCLPPTHROW(getDeviceNumaNode(cudaDevice, &deviceNumaNode));
|
||||
@@ -23,4 +25,4 @@ MSCCLPP_API_CPP void DeviceChannelService::bindThread()
|
||||
}
|
||||
|
||||
} // namespace channel
|
||||
} // namespace mscclpp
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -59,8 +59,9 @@ MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t
|
||||
|
||||
struct MemorySender : public Setuppable
|
||||
{
|
||||
MemorySender(RegisteredMemory memory, int remoteRank, int tag)
|
||||
: memory_(memory), remoteRank_(remoteRank), tag_(tag) {}
|
||||
MemorySender(RegisteredMemory memory, int remoteRank, int tag) : memory_(memory), remoteRank_(remoteRank), tag_(tag)
|
||||
{
|
||||
}
|
||||
|
||||
void beginSetup(std::shared_ptr<BaseBootstrap> bootstrap) override
|
||||
{
|
||||
@@ -79,8 +80,9 @@ MSCCLPP_API_CPP void Communicator::sendMemoryOnSetup(RegisteredMemory memory, in
|
||||
|
||||
struct MemoryReceiver : public Setuppable
|
||||
{
|
||||
MemoryReceiver(int remoteRank, int tag)
|
||||
: remoteRank_(remoteRank), tag_(tag) {}
|
||||
MemoryReceiver(int remoteRank, int tag) : remoteRank_(remoteRank), tag_(tag)
|
||||
{
|
||||
}
|
||||
|
||||
void endSetup(std::shared_ptr<BaseBootstrap> bootstrap) override
|
||||
{
|
||||
@@ -112,7 +114,7 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(int rem
|
||||
<< pimpl->rankToHash_[pimpl->bootstrap_->getRank()] << ")"
|
||||
<< " != " << pimpl->bootstrap_->getRank() << "(" << std::hex
|
||||
<< pimpl->rankToHash_[pimpl->bootstrap_->getRank()] << ")";
|
||||
throw std::runtime_error(ss.str());
|
||||
throw mscclpp::Error(ss.str(), mscclppInternalError);
|
||||
}
|
||||
auto cudaIpcConn = std::make_shared<CudaIpcConnection>(remoteRank, tag);
|
||||
conn = cudaIpcConn;
|
||||
@@ -126,7 +128,7 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(int rem
|
||||
pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()],
|
||||
getIBDeviceName(transport).c_str(), remoteRank, pimpl->rankToHash_[remoteRank]);
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported transport");
|
||||
throw mscclpp::Error("Unsupported transport", mscclppInvalidArgument);
|
||||
}
|
||||
pimpl->connections_.push_back(conn);
|
||||
addSetup(conn);
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
#include <algorithm>
|
||||
#include "connection.hpp"
|
||||
#include "checks.hpp"
|
||||
#include "infiniband/verbs.h"
|
||||
#include "npkit/npkit.h"
|
||||
#include "registered_memory.hpp"
|
||||
#include "utils.hpp"
|
||||
#include <algorithm>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
void validateTransport(RegisteredMemory mem, Transport transport)
|
||||
{
|
||||
if (!mem.transports().has(transport)) {
|
||||
throw std::runtime_error("mem does not support transport");
|
||||
throw Error("RegisteredMemory does not support transport", mscclppInvalidArgument);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,11 +24,19 @@ std::shared_ptr<RegisteredMemory::Impl> Connection::getRegisteredMemoryImpl(Regi
|
||||
|
||||
// ConnectionBase
|
||||
|
||||
ConnectionBase::ConnectionBase(int remoteRank, int tag) : remoteRank_(remoteRank), tag_(tag) {}
|
||||
ConnectionBase::ConnectionBase(int remoteRank, int tag) : remoteRank_(remoteRank), tag_(tag)
|
||||
{
|
||||
}
|
||||
|
||||
int ConnectionBase::remoteRank() { return remoteRank_; }
|
||||
int ConnectionBase::remoteRank()
|
||||
{
|
||||
return remoteRank_;
|
||||
}
|
||||
|
||||
int ConnectionBase::tag() { return tag_; }
|
||||
int ConnectionBase::tag()
|
||||
{
|
||||
return tag_;
|
||||
}
|
||||
|
||||
// CudaIpcConnection
|
||||
|
||||
@@ -99,11 +107,11 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem
|
||||
|
||||
auto dstTransportInfo = getRegisteredMemoryImpl(dst)->getTransportInfo(remoteTransport());
|
||||
if (dstTransportInfo.ibLocal) {
|
||||
throw std::runtime_error("dst is local, which is not supported");
|
||||
throw Error("dst is local, which is not supported", mscclppInvalidArgument);
|
||||
}
|
||||
auto srcTransportInfo = getRegisteredMemoryImpl(src)->getTransportInfo(transport());
|
||||
if (!srcTransportInfo.ibLocal) {
|
||||
throw std::runtime_error("src is remote, which is not supported");
|
||||
throw Error("src is remote, which is not supported", mscclppInvalidArgument);
|
||||
}
|
||||
|
||||
auto dstMrInfo = dstTransportInfo.ibMrInfo;
|
||||
@@ -113,7 +121,8 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem
|
||||
/*signaled=*/true);
|
||||
numSignaledSends++;
|
||||
qp->postSend();
|
||||
INFO(MSCCLPP_NET, "IBConnection write: from %p to %p, size %lu", (uint8_t*)srcMr->getBuff() + srcOffset, (uint8_t*)dstMrInfo.addr + dstOffset, size);
|
||||
INFO(MSCCLPP_NET, "IBConnection write: from %p to %p, size %lu", (uint8_t*)srcMr->getBuff() + srcOffset,
|
||||
(uint8_t*)dstMrInfo.addr + dstOffset, size);
|
||||
// npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)size);
|
||||
}
|
||||
|
||||
@@ -123,16 +132,19 @@ void IBConnection::flush()
|
||||
while (numSignaledSends) {
|
||||
int wcNum = qp->pollCq();
|
||||
if (wcNum < 0) {
|
||||
throw std::runtime_error("pollCq failed: error no " + std::to_string(errno));
|
||||
throw mscclpp::IbError("pollCq failed: error no " + std::to_string(errno), errno);
|
||||
}
|
||||
|
||||
auto elapsed = timer.elapsed();
|
||||
if (elapsed > MSCCLPP_POLLING_WAIT)
|
||||
throw std::runtime_error("pollCq is stuck: waited for " + std::to_string(elapsed) + " seconds. Expected " + std::to_string(numSignaledSends) + " signals");
|
||||
if (elapsed > MSCCLPP_POLLING_WAIT) {
|
||||
throw Error("pollCq is stuck: waited for " + std::to_string(elapsed) + " seconds. Expected " +
|
||||
std::to_string(numSignaledSends) + " signals",
|
||||
mscclppInternalError);
|
||||
}
|
||||
for (int i = 0; i < wcNum; ++i) {
|
||||
const struct ibv_wc* wc = reinterpret_cast<const struct ibv_wc*>(qp->getWc(i));
|
||||
if (wc->status != IBV_WC_SUCCESS) {
|
||||
throw std::runtime_error("pollCq failed: status " + std::to_string(wc->status));
|
||||
throw mscclpp::IbError("pollCq failed: status " + std::to_string(wc->status), wc->status);
|
||||
}
|
||||
if (wc->opcode == IBV_WC_RDMA_WRITE) {
|
||||
numSignaledSends--;
|
||||
|
||||
18
src/epoch.cc
18
src/epoch.cc
@@ -1,26 +1,32 @@
|
||||
#include "epoch.hpp"
|
||||
#include "checks.hpp"
|
||||
#include "alloc.h"
|
||||
#include "api.h"
|
||||
#include "checks.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
MSCCLPP_API_CPP Epoch::Epoch(Communicator& communicator, std::shared_ptr<Connection> connection) : connection_(connection) {
|
||||
MSCCLPP_API_CPP Epoch::Epoch(Communicator& communicator, std::shared_ptr<Connection> connection)
|
||||
: connection_(connection)
|
||||
{
|
||||
MSCCLPPTHROW(mscclppCudaCalloc(&device_.epochIds_, 1));
|
||||
MSCCLPPTHROW(mscclppCudaCalloc(&device_.expectedInboundEpochId_, 1));
|
||||
|
||||
localEpochIdsRegMem_ = communicator.registerMemory(device_.epochIds_, sizeof(device_.epochIds_), connection->transport());
|
||||
localEpochIdsRegMem_ =
|
||||
communicator.registerMemory(device_.epochIds_, sizeof(device_.epochIds_), connection->transport());
|
||||
communicator.sendMemoryOnSetup(localEpochIdsRegMem_, connection->remoteRank(), connection->tag());
|
||||
remoteEpochIdsRegMem_ = communicator.recvMemoryOnSetup(connection->remoteRank(), connection->tag());
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Epoch::~Epoch() {
|
||||
MSCCLPP_API_CPP Epoch::~Epoch()
|
||||
{
|
||||
mscclppCudaFree(device_.epochIds_);
|
||||
mscclppCudaFree(device_.expectedInboundEpochId_);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Epoch::signal() {
|
||||
connection_->write(remoteEpochIdsRegMem_.get(), offsetof(EpochIds, inboundReplica_), localEpochIdsRegMem_, offsetof(EpochIds, outbound_), sizeof(device_.epochIds_));
|
||||
MSCCLPP_API_CPP void Epoch::signal()
|
||||
{
|
||||
connection_->write(remoteEpochIdsRegMem_.get(), offsetof(EpochIds, inboundReplica_), localEpochIdsRegMem_,
|
||||
offsetof(EpochIds, outbound_), sizeof(device_.epochIds_));
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
30
src/errors.cc
Normal file
30
src/errors.cc
Normal file
@@ -0,0 +1,30 @@
|
||||
#include "errors.hpp"

#include <utility>

namespace mscclpp {

/// Stores `message` in the std::runtime_error base (reported by what()) and keeps `errorCode`
/// for getErrorCode().
BaseError::BaseError(std::string message, int errorCode) : std::runtime_error(message), errorCode_(errorCode)
{
}

/// @return The error code supplied at construction.
int BaseError::getErrorCode() const
{
  return errorCode_;
}

// The derived constructors receive `message` by value, so it is moved (not copied a second time)
// into the BaseError constructor.

Error::Error(std::string message, int errorCode) : BaseError(std::move(message), errorCode)
{
}

CudaError::CudaError(std::string message, int errorCode) : BaseError(std::move(message), errorCode)
{
}

CuError::CuError(std::string message, int errorCode) : BaseError(std::move(message), errorCode)
{
}

IbError::IbError(std::string message, int errorCode) : BaseError(std::move(message), errorCode)
{
}

} // namespace mscclpp
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "alloc.h"
|
||||
#include "api.h"
|
||||
#include "checks.hpp"
|
||||
#include "mscclppfifo.hpp"
|
||||
#include "api.h"
|
||||
#include <cuda_runtime.h>
|
||||
#include <emmintrin.h>
|
||||
#include <stdexcept>
|
||||
|
||||
44
src/ib.cc
44
src/ib.cc
@@ -6,12 +6,12 @@
|
||||
#include <unistd.h>
|
||||
|
||||
#include "alloc.h"
|
||||
#include "api.h"
|
||||
#include "checks.hpp"
|
||||
#include "comm.h"
|
||||
#include "debug.h"
|
||||
#include "ib.hpp"
|
||||
#include "mscclpp.hpp"
|
||||
#include "api.h"
|
||||
#include <infiniband/verbs.h>
|
||||
#include <string>
|
||||
|
||||
@@ -20,7 +20,7 @@ namespace mscclpp {
|
||||
IbMr::IbMr(void* pd, void* buff, std::size_t size) : buff(buff)
|
||||
{
|
||||
if (size == 0) {
|
||||
throw std::runtime_error("invalid size: " + std::to_string(size));
|
||||
throw std::invalid_argument("invalid size: " + std::to_string(size));
|
||||
}
|
||||
static __thread uintptr_t pageSize = 0;
|
||||
if (pageSize == 0) {
|
||||
@@ -35,7 +35,7 @@ IbMr::IbMr(void* pd, void* buff, std::size_t size) : buff(buff)
|
||||
if (_mr == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_reg_mr failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
}
|
||||
this->mr = _mr;
|
||||
this->size = pages * pageSize;
|
||||
@@ -73,7 +73,7 @@ IbQp::IbQp(void* ctx, void* pd, int port)
|
||||
if (this->cq == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_create_cq failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
}
|
||||
|
||||
struct ibv_qp_init_attr qpInitAttr;
|
||||
@@ -92,14 +92,14 @@ IbQp::IbQp(void* ctx, void* pd, int port)
|
||||
if (_qp == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_create_qp failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
}
|
||||
|
||||
struct ibv_port_attr portAttr;
|
||||
if (ibv_query_port(_ctx, port, &portAttr) != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_query_port failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
}
|
||||
this->info.lid = portAttr.lid;
|
||||
this->info.port = port;
|
||||
@@ -111,7 +111,7 @@ IbQp::IbQp(void* ctx, void* pd, int port)
|
||||
if (ibv_query_gid(_ctx, port, 0, &gid) != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_query_gid failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
}
|
||||
this->info.spn = gid.global.subnet_prefix;
|
||||
}
|
||||
@@ -125,7 +125,7 @@ IbQp::IbQp(void* ctx, void* pd, int port)
|
||||
if (ibv_modify_qp(_qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_modify_qp failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
}
|
||||
this->qp = _qp;
|
||||
this->wrn = 0;
|
||||
@@ -174,7 +174,7 @@ void IbQp::rtr(const IbQpInfo& info)
|
||||
if (ret != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_modify_qp failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -194,7 +194,7 @@ void IbQp::rts()
|
||||
if (ret != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_modify_qp failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -249,7 +249,7 @@ void IbQp::postSend()
|
||||
if (ret != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_post_send failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
}
|
||||
this->wrn = 0;
|
||||
}
|
||||
@@ -265,7 +265,7 @@ void IbQp::postRecv(uint64_t wrId)
|
||||
if (ret != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_post_recv failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -299,13 +299,13 @@ IbCtx::IbCtx(const std::string& devName) : devName(devName)
|
||||
if (this->ctx == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_open_device failed (errno " << errno << ", device name << " << devName << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
}
|
||||
this->pd = ibv_alloc_pd(reinterpret_cast<struct ibv_context*>(this->ctx));
|
||||
if (this->pd == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_alloc_pd failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -327,7 +327,7 @@ bool IbCtx::isPortUsable(int port) const
|
||||
if (ibv_query_port(reinterpret_cast<struct ibv_context*>(this->ctx), port, &portAttr) != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_query_port failed (errno " << errno << ", port << " << port << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
}
|
||||
return portAttr.state == IBV_PORT_ACTIVE &&
|
||||
(portAttr.link_layer == IBV_LINK_LAYER_ETHERNET || portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND);
|
||||
@@ -339,7 +339,7 @@ int IbCtx::getAnyActivePort() const
|
||||
if (ibv_query_device(reinterpret_cast<struct ibv_context*>(this->ctx), &devAttr) != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_query_device failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
}
|
||||
for (uint8_t port = 1; port <= devAttr.phys_port_cnt; ++port) {
|
||||
if (this->isPortUsable(port)) {
|
||||
@@ -354,10 +354,10 @@ IbQp* IbCtx::createQp(int port /*=-1*/)
|
||||
if (port == -1) {
|
||||
port = this->getAnyActivePort();
|
||||
if (port == -1) {
|
||||
throw std::runtime_error("No active port found");
|
||||
throw mscclpp::Error("No active port found", mscclppInternalError);
|
||||
}
|
||||
} else if (!this->isPortUsable(port)) {
|
||||
throw std::runtime_error("invalid IB port: " + std::to_string(port));
|
||||
throw mscclpp::Error("invalid IB port: " + std::to_string(port), mscclppInternalError);
|
||||
}
|
||||
qps.emplace_back(new IbQp(this->ctx, this->pd, port));
|
||||
return qps.back().get();
|
||||
@@ -412,10 +412,10 @@ MSCCLPP_API_CPP std::string getIBDeviceName(Transport ibTransport)
|
||||
ibTransportIndex = 7;
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Not an IB transport");
|
||||
throw std::invalid_argument("Not an IB transport");
|
||||
}
|
||||
if (ibTransportIndex >= num) {
|
||||
throw std::runtime_error("IB transport out of range");
|
||||
throw std::out_of_range("IB transport out of range");
|
||||
}
|
||||
return devices[ibTransportIndex]->name;
|
||||
}
|
||||
@@ -444,11 +444,11 @@ MSCCLPP_API_CPP Transport getIBTransportByDeviceName(const std::string& ibDevice
|
||||
case 7:
|
||||
return Transport::IB7;
|
||||
default:
|
||||
throw std::runtime_error("IB device index out of range");
|
||||
throw std::out_of_range("IB device index out of range");
|
||||
}
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("IB device not found");
|
||||
throw std::invalid_argument("IB device not found");
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
|
||||
#include "epoch.hpp"
|
||||
#include "mscclpp.hpp"
|
||||
#include "proxy.hpp"
|
||||
#include "mscclppfifo.hpp"
|
||||
#include "proxy.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
@@ -15,10 +15,16 @@ class Channel
|
||||
{
|
||||
public:
|
||||
Channel(Communicator& communicator, std::shared_ptr<Connection> connection)
|
||||
: connection_(connection), epoch_(std::make_shared<Epoch>(communicator, connection)) {};
|
||||
: connection_(connection), epoch_(std::make_shared<Epoch>(communicator, connection)){};
|
||||
|
||||
Connection& connection() { return *connection_; }
|
||||
Epoch& epoch() { return *epoch_; }
|
||||
Connection& connection()
|
||||
{
|
||||
return *connection_;
|
||||
}
|
||||
Epoch& epoch()
|
||||
{
|
||||
return *epoch_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<Connection> connection_;
|
||||
@@ -69,8 +75,8 @@ union ChannelTrigger {
|
||||
__device__ ChannelTrigger(ProxyTrigger value) : value(value)
|
||||
{
|
||||
}
|
||||
__device__ ChannelTrigger(TriggerType type, MemoryId dst, uint64_t dstOffset, MemoryId src,
|
||||
uint64_t srcOffset, uint64_t size, int connectionId)
|
||||
__device__ ChannelTrigger(TriggerType type, MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
|
||||
uint64_t size, int connectionId)
|
||||
{
|
||||
value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + size);
|
||||
value.snd = ((((((((connectionId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_REGMEM_HANDLE) + dst)
|
||||
@@ -86,15 +92,17 @@ struct DeviceChannel
|
||||
{
|
||||
DeviceChannel() = default;
|
||||
|
||||
DeviceChannel(ChannelId channelId, DeviceEpoch epoch, DeviceProxyFifo fifo) : channelId_(channelId), epoch_(epoch), fifo_(fifo) {}
|
||||
DeviceChannel(ChannelId channelId, DeviceEpoch epoch, DeviceProxyFifo fifo)
|
||||
: channelId_(channelId), epoch_(epoch), fifo_(fifo)
|
||||
{
|
||||
}
|
||||
|
||||
DeviceChannel(const DeviceChannel& other) = default;
|
||||
|
||||
DeviceChannel& operator=(DeviceChannel& other) = default;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
__forceinline__ __device__ void put(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
|
||||
uint64_t size)
|
||||
__forceinline__ __device__ void put(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
fifo_.push(ChannelTrigger(TriggerData, dst, dstOffset, src, srcOffset, size, channelId_).value);
|
||||
}
|
||||
@@ -110,13 +118,11 @@ struct DeviceChannel
|
||||
fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, channelId_).value);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(MemoryId dst, uint64_t dstOffset, MemoryId src,
|
||||
uint64_t srcOffset, uint64_t size)
|
||||
__forceinline__ __device__ void putWithSignal(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
|
||||
uint64_t size)
|
||||
{
|
||||
epochIncrement();
|
||||
fifo_.push(
|
||||
ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, size, channelId_)
|
||||
.value);
|
||||
fifo_.push(ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, size, channelId_).value);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size)
|
||||
@@ -128,16 +134,14 @@ struct DeviceChannel
|
||||
uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
epochIncrement();
|
||||
uint64_t curFifoHead = fifo_.push(ChannelTrigger(TriggerData | TriggerFlag | TriggerSync, dst,
|
||||
dstOffset, src, srcOffset, size, channelId_)
|
||||
.value);
|
||||
uint64_t curFifoHead = fifo_.push(
|
||||
ChannelTrigger(TriggerData | TriggerFlag | TriggerSync, dst, dstOffset, src, srcOffset, size, channelId_).value);
|
||||
while (*(volatile uint64_t*)&fifo_.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
|
||||
*(volatile uint64_t*)fifo_.tailReplica <= curFifoHead)
|
||||
;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, MemoryId src, uint64_t offset,
|
||||
uint64_t size)
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size)
|
||||
{
|
||||
putWithSignalAndFlush(dst, offset, src, offset, size);
|
||||
}
|
||||
@@ -176,25 +180,40 @@ class DeviceChannelService;
|
||||
|
||||
inline ProxyHandler makeChannelProxyHandler(DeviceChannelService& channelService);
|
||||
|
||||
class DeviceChannelService {
|
||||
class DeviceChannelService
|
||||
{
|
||||
public:
|
||||
DeviceChannelService(Communicator& communicator);
|
||||
|
||||
ChannelId addChannel(std::shared_ptr<Connection> connection) {
|
||||
ChannelId addChannel(std::shared_ptr<Connection> connection)
|
||||
{
|
||||
channels_.push_back(Channel(communicator_, connection));
|
||||
return channels_.size() - 1;
|
||||
}
|
||||
|
||||
MemoryId addMemory(RegisteredMemory memory) {
|
||||
MemoryId addMemory(RegisteredMemory memory)
|
||||
{
|
||||
memories_.push_back(memory);
|
||||
return memories_.size() - 1;
|
||||
}
|
||||
|
||||
Channel channel(ChannelId id) { return channels_[id]; }
|
||||
DeviceChannel deviceChannel(ChannelId id) { return DeviceChannel(id, channels_[id].epoch().deviceEpoch(), proxy_.fifo().deviceFifo()); }
|
||||
Channel channel(ChannelId id)
|
||||
{
|
||||
return channels_[id];
|
||||
}
|
||||
DeviceChannel deviceChannel(ChannelId id)
|
||||
{
|
||||
return DeviceChannel(id, channels_[id].epoch().deviceEpoch(), proxy_.fifo().deviceFifo());
|
||||
}
|
||||
|
||||
void startProxy() { proxy_.start(); }
|
||||
void stopProxy() { proxy_.stop(); }
|
||||
void startProxy()
|
||||
{
|
||||
proxy_.start();
|
||||
}
|
||||
void stopProxy()
|
||||
{
|
||||
proxy_.stop();
|
||||
}
|
||||
|
||||
private:
|
||||
Communicator& communicator_;
|
||||
@@ -205,7 +224,8 @@ private:
|
||||
|
||||
void bindThread();
|
||||
|
||||
ProxyHandlerResult handleTrigger(ProxyTrigger triggerRaw) {
|
||||
ProxyHandlerResult handleTrigger(ProxyTrigger triggerRaw)
|
||||
{
|
||||
ChannelTrigger* trigger = reinterpret_cast<ChannelTrigger*>(&triggerRaw);
|
||||
Channel& channel = channels_[trigger->fields.chanId];
|
||||
|
||||
@@ -234,7 +254,9 @@ struct SimpleDeviceChannel
|
||||
{
|
||||
SimpleDeviceChannel() = default;
|
||||
|
||||
SimpleDeviceChannel(DeviceChannel devChan, MemoryId dst, MemoryId src) : devChan_(devChan), dst_(dst), src_(src) {}
|
||||
SimpleDeviceChannel(DeviceChannel devChan, MemoryId dst, MemoryId src) : devChan_(devChan), dst_(dst), src_(src)
|
||||
{
|
||||
}
|
||||
|
||||
SimpleDeviceChannel(const SimpleDeviceChannel& other) = default;
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@
|
||||
#define MSCCLPP_CHECKS_HPP_
|
||||
|
||||
#include "debug.h"
|
||||
#include "errors.hpp"
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
@@ -15,7 +17,8 @@
|
||||
do { \
|
||||
mscclppResult_t res = call; \
|
||||
if (res != mscclppSuccess && res != mscclppInProgress) { \
|
||||
throw std::runtime_error(std::string("Call to " #call " failed with error code ") + mscclppGetErrorString(res)); \
|
||||
throw mscclpp::Error(std::string("Call to " #call " failed with error code ") + mscclppGetErrorString(res), \
|
||||
res); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
@@ -23,7 +26,7 @@
|
||||
do { \
|
||||
cudaError_t err = cmd; \
|
||||
if (err != cudaSuccess) { \
|
||||
throw std::runtime_error(std::string("Cuda failure '") + cudaGetErrorString(err) + "'"); \
|
||||
throw mscclpp::CudaError(std::string("Cuda failure '") + cudaGetErrorString(err) + "'", err); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
@@ -33,7 +36,7 @@
|
||||
if (err != CUDA_SUCCESS) { \
|
||||
const char* errStr; \
|
||||
cuGetErrorString(err, &errStr); \
|
||||
throw std::runtime_error(std::string("Cu failure '") + std::string(errStr) + "'"); \
|
||||
throw mscclpp::CuError(std::string("Cu failure '") + std::string(errStr) + "'", err); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ class ConnectionBase : public Connection, public Setuppable
|
||||
{
|
||||
int remoteRank_;
|
||||
int tag_;
|
||||
|
||||
public:
|
||||
ConnectionBase(int remoteRank, int tag);
|
||||
|
||||
|
||||
@@ -17,7 +17,8 @@ struct DeviceEpoch
|
||||
__forceinline__ __device__ void wait()
|
||||
{
|
||||
(*expectedInboundEpochId_) += 1;
|
||||
while (*(volatile uint64_t*)&(epochIds_->inboundReplica_) < (*expectedInboundEpochId_));
|
||||
while (*(volatile uint64_t*)&(epochIds_->inboundReplica_) < (*expectedInboundEpochId_))
|
||||
;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void epochIncrement()
|
||||
@@ -44,9 +45,12 @@ public:
|
||||
|
||||
void signal();
|
||||
|
||||
DeviceEpoch deviceEpoch() { return device_; }
|
||||
DeviceEpoch deviceEpoch()
|
||||
{
|
||||
return device_;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_EPOCH_HPP_
|
||||
#endif // MSCCLPP_EPOCH_HPP_
|
||||
|
||||
46
src/include/errors.hpp
Normal file
46
src/include/errors.hpp
Normal file
@@ -0,0 +1,46 @@
|
||||
#ifndef MSCCLPP_ERRORS_HPP_
#define MSCCLPP_ERRORS_HPP_

#include <stdexcept>
#include <string>

namespace mscclpp {

/// Common base of the mscclpp exception types: a std::runtime_error that additionally carries an
/// integer error code.
class BaseError : public std::runtime_error
{
public:
  /// @param message   Text reported by what().
  /// @param errorCode Code later returned by getErrorCode().
  BaseError(std::string message, int errorCode);
  virtual ~BaseError() = default;

  /// @return The error code supplied at construction.
  int getErrorCode() const;

private:
  int errorCode_;
};

/// Error raised by mscclpp itself; call sites pass an mscclppResult_t value as the code.
class Error : public BaseError
{
public:
  Error(std::string message, int errorCode);
  virtual ~Error() = default;
};

/// Failure of a CUDA runtime call; the CUDATHROW check passes the cudaError_t value as the code.
class CudaError : public BaseError
{
public:
  CudaError(std::string message, int errorCode);
  virtual ~CudaError() = default;
};

/// Failure of a CUDA driver (cu*) call; the check macro passes the driver result as the code.
class CuError : public BaseError
{
public:
  CuError(std::string message, int errorCode);
  virtual ~CuError() = default;
};

/// Failure of an InfiniBand verbs operation; call sites pass errno or a work-completion status.
class IbError : public BaseError
{
public:
  IbError(std::string message, int errorCode);
  virtual ~IbError() = default;
};

} // namespace mscclpp

#endif // MSCCLPP_ERRORS_HPP_
|
||||
@@ -7,10 +7,10 @@
|
||||
#define MSCCLPP_VERSION (MSCCLPP_MAJOR * 10000 + MSCCLPP_MINOR * 100 + MSCCLPP_PATCH)
|
||||
|
||||
#include <bitset>
|
||||
#include <future>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <future>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
@@ -37,14 +37,14 @@ public:
|
||||
{
|
||||
size_t size = data.size();
|
||||
send((void*)&size, sizeof(size_t), peer, tag);
|
||||
send((void*)data.data(), data.size(), peer, tag+1);
|
||||
send((void*)data.data(), data.size(), peer, tag + 1);
|
||||
}
|
||||
void recv(std::vector<char>& data, int peer, int tag)
|
||||
{
|
||||
size_t size;
|
||||
recv((void*)&size, sizeof(size_t), peer, tag);
|
||||
data.resize(size);
|
||||
recv((void*)data.data(), data.size(), peer, tag+1);
|
||||
recv((void*)data.data(), data.size(), peer, tag + 1);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -239,7 +239,8 @@ class Connection;
|
||||
class RegisteredMemory
|
||||
{
|
||||
struct Impl;
|
||||
// A shared_ptr is used since RegisteredMemory is functionally immutable, although internally some state is populated lazily.
|
||||
// A shared_ptr is used since RegisteredMemory is functionally immutable, although internally some state is populated
|
||||
// lazily.
|
||||
std::shared_ptr<Impl> pimpl;
|
||||
|
||||
public:
|
||||
@@ -281,17 +282,23 @@ protected:
|
||||
|
||||
struct Setuppable
|
||||
{
|
||||
virtual void beginSetup(std::shared_ptr<BaseBootstrap>) {}
|
||||
virtual void endSetup(std::shared_ptr<BaseBootstrap>) {}
|
||||
virtual void beginSetup(std::shared_ptr<BaseBootstrap>)
|
||||
{
|
||||
}
|
||||
virtual void endSetup(std::shared_ptr<BaseBootstrap>)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
class NonblockingFuture
|
||||
template <typename T> class NonblockingFuture
|
||||
{
|
||||
std::shared_future<T> future;
|
||||
|
||||
public:
|
||||
NonblockingFuture() = default;
|
||||
NonblockingFuture(std::shared_future<T>&& future) : future(std::move(future)) {}
|
||||
NonblockingFuture(std::shared_future<T>&& future) : future(std::move(future))
|
||||
{
|
||||
}
|
||||
NonblockingFuture(const NonblockingFuture&) = default;
|
||||
|
||||
bool ready() const
|
||||
@@ -331,7 +338,7 @@ public:
|
||||
* Returns: a handle to the buffer
|
||||
*/
|
||||
RegisteredMemory registerMemory(void* ptr, size_t size, TransportFlags transports);
|
||||
|
||||
|
||||
void sendMemoryOnSetup(RegisteredMemory memory, int remoteRank, int tag);
|
||||
|
||||
NonblockingFuture<RegisteredMemory> recvMemoryOnSetup(int remoteRank, int tag);
|
||||
@@ -363,7 +370,6 @@ public:
|
||||
private:
|
||||
std::unique_ptr<Impl> pimpl;
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
namespace std {
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#define MSCCLPP_REGISTERED_MEMORY_HPP_
|
||||
|
||||
#include "communicator.hpp"
|
||||
#include "errors.hpp"
|
||||
#include "ib.hpp"
|
||||
#include "mscclpp.h"
|
||||
#include "mscclpp.hpp"
|
||||
@@ -16,11 +17,13 @@ struct TransportInfo
|
||||
// TODO: rewrite this using std::variant or something
|
||||
bool ibLocal;
|
||||
union {
|
||||
struct {
|
||||
struct
|
||||
{
|
||||
cudaIpcMemHandle_t cudaIpcBaseHandle;
|
||||
size_t cudaIpcOffsetFromBase;
|
||||
};
|
||||
struct {
|
||||
struct
|
||||
{
|
||||
const IbMr* ibMr;
|
||||
IbMrInfo ibMrInfo;
|
||||
};
|
||||
@@ -46,7 +49,7 @@ struct RegisteredMemory::Impl
|
||||
return entry;
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("Transport data not found");
|
||||
throw Error("Transport data not found", mscclppInternalError);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -8,45 +8,45 @@ namespace mscclpp {
|
||||
|
||||
struct Timer
|
||||
{
|
||||
std::chrono::steady_clock::time_point start;
|
||||
|
||||
Timer()
|
||||
{
|
||||
start = std::chrono::steady_clock::now();
|
||||
}
|
||||
|
||||
int64_t elapsed()
|
||||
{
|
||||
auto end = std::chrono::steady_clock::now();
|
||||
return std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
|
||||
}
|
||||
|
||||
void reset()
|
||||
{
|
||||
start = std::chrono::steady_clock::now();
|
||||
}
|
||||
|
||||
void print(const char* name)
|
||||
{
|
||||
auto end = std::chrono::steady_clock::now();
|
||||
auto elapsed = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
|
||||
printf("%s: %ld us\n", name, elapsed);
|
||||
}
|
||||
std::chrono::steady_clock::time_point start;
|
||||
|
||||
Timer()
|
||||
{
|
||||
start = std::chrono::steady_clock::now();
|
||||
}
|
||||
|
||||
int64_t elapsed()
|
||||
{
|
||||
auto end = std::chrono::steady_clock::now();
|
||||
return std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
|
||||
}
|
||||
|
||||
void reset()
|
||||
{
|
||||
start = std::chrono::steady_clock::now();
|
||||
}
|
||||
|
||||
void print(const char* name)
|
||||
{
|
||||
auto end = std::chrono::steady_clock::now();
|
||||
auto elapsed = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
|
||||
printf("%s: %ld us\n", name, elapsed);
|
||||
}
|
||||
};
|
||||
|
||||
struct ScopedTimer
|
||||
{
|
||||
Timer timer;
|
||||
const char* name;
|
||||
|
||||
ScopedTimer(const char* name) : name(name)
|
||||
{
|
||||
}
|
||||
|
||||
~ScopedTimer()
|
||||
{
|
||||
timer.print(name);
|
||||
}
|
||||
Timer timer;
|
||||
const char* name;
|
||||
|
||||
ScopedTimer(const char* name) : name(name)
|
||||
{
|
||||
}
|
||||
|
||||
~ScopedTimer()
|
||||
{
|
||||
timer.print(name);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#include "proxy.hpp"
|
||||
#include "api.h"
|
||||
#include "mscclpp.hpp"
|
||||
#include "proxy.hpp"
|
||||
#include "utils.h"
|
||||
#include "utils.hpp"
|
||||
#include <atomic>
|
||||
@@ -20,7 +20,8 @@ struct Proxy::Impl
|
||||
std::thread service;
|
||||
std::atomic_bool running;
|
||||
|
||||
Impl(ProxyHandler handler, std::function<void()> threadInit) : handler(handler), threadInit(threadInit), running(false)
|
||||
Impl(ProxyHandler handler, std::function<void()> threadInit)
|
||||
: handler(handler), threadInit(threadInit), running(false)
|
||||
{
|
||||
}
|
||||
};
|
||||
@@ -45,7 +46,6 @@ MSCCLPP_API_CPP void Proxy::start()
|
||||
{
|
||||
pimpl->running = true;
|
||||
pimpl->service = std::thread([this] {
|
||||
|
||||
pimpl->threadInit();
|
||||
|
||||
ProxyHandler handler = this->pimpl->handler;
|
||||
@@ -109,4 +109,4 @@ MSCCLPP_API_CPP HostProxyFifo& Proxy::fifo()
|
||||
return pimpl->fifo;
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -88,7 +88,7 @@ MSCCLPP_API_CPP std::vector<char> RegisteredMemory::serialize()
|
||||
std::copy_n(reinterpret_cast<char*>(&pimpl->hostHash), sizeof(pimpl->hostHash), std::back_inserter(result));
|
||||
std::copy_n(reinterpret_cast<char*>(&pimpl->transports), sizeof(pimpl->transports), std::back_inserter(result));
|
||||
if (pimpl->transportInfos.size() > std::numeric_limits<int8_t>::max()) {
|
||||
throw std::runtime_error("Too many transport info entries");
|
||||
throw mscclpp::Error("Too many transport info entries", mscclppInternalError);
|
||||
}
|
||||
int8_t transportCount = pimpl->transportInfos.size();
|
||||
std::copy_n(reinterpret_cast<char*>(&transportCount), sizeof(transportCount), std::back_inserter(result));
|
||||
@@ -102,7 +102,7 @@ MSCCLPP_API_CPP std::vector<char> RegisteredMemory::serialize()
|
||||
} else if (AllIBTransports.has(entry.transport)) {
|
||||
std::copy_n(reinterpret_cast<char*>(&entry.ibMrInfo), sizeof(entry.ibMrInfo), std::back_inserter(result));
|
||||
} else {
|
||||
throw std::runtime_error("Unknown transport");
|
||||
throw mscclpp::Error("Unknown transport", mscclppInternalError);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
@@ -132,21 +132,23 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization)
|
||||
std::copy_n(it, sizeof(transportInfo.transport), reinterpret_cast<char*>(&transportInfo.transport));
|
||||
it += sizeof(transportInfo.transport);
|
||||
if (transportInfo.transport == Transport::CudaIpc) {
|
||||
std::copy_n(it, sizeof(transportInfo.cudaIpcBaseHandle), reinterpret_cast<char*>(&transportInfo.cudaIpcBaseHandle));
|
||||
std::copy_n(it, sizeof(transportInfo.cudaIpcBaseHandle),
|
||||
reinterpret_cast<char*>(&transportInfo.cudaIpcBaseHandle));
|
||||
it += sizeof(transportInfo.cudaIpcBaseHandle);
|
||||
std::copy_n(it, sizeof(transportInfo.cudaIpcOffsetFromBase), reinterpret_cast<char*>(&transportInfo.cudaIpcOffsetFromBase));
|
||||
std::copy_n(it, sizeof(transportInfo.cudaIpcOffsetFromBase),
|
||||
reinterpret_cast<char*>(&transportInfo.cudaIpcOffsetFromBase));
|
||||
it += sizeof(transportInfo.cudaIpcOffsetFromBase);
|
||||
} else if (AllIBTransports.has(transportInfo.transport)) {
|
||||
std::copy_n(it, sizeof(transportInfo.ibMrInfo), reinterpret_cast<char*>(&transportInfo.ibMrInfo));
|
||||
it += sizeof(transportInfo.ibMrInfo);
|
||||
transportInfo.ibLocal = false;
|
||||
} else {
|
||||
throw std::runtime_error("Unknown transport");
|
||||
throw mscclpp::Error("Unknown transport", mscclppInternalError);
|
||||
}
|
||||
this->transportInfos.push_back(transportInfo);
|
||||
}
|
||||
if (it != serialization.end()) {
|
||||
throw std::runtime_error("Deserialization failed");
|
||||
throw mscclpp::Error("Serialization failed", mscclppInternalError);
|
||||
}
|
||||
|
||||
if (transports.has(Transport::CudaIpc)) {
|
||||
|
||||
@@ -6,10 +6,10 @@
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
#include <memory>
|
||||
#include <numa.h>
|
||||
#include <stdlib.h>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
// Get current Compute Capability
|
||||
// int mscclppCudaCompCap() {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "mscclpp.h"
|
||||
#include "mscclpp.hpp"
|
||||
|
||||
#include "channel.hpp"
|
||||
|
||||
#ifdef MSCCLPP_USE_MPI_FOR_TESTS
|
||||
@@ -71,8 +72,8 @@ __device__ void allgather0(mscclpp::channel::SimpleDeviceChannel devChan, int ra
|
||||
devChan.wait();
|
||||
}
|
||||
|
||||
__device__ void localAllGather(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int world_size, int nranksPerNode,
|
||||
int remoteRank, uint64_t offset, uint64_t size)
|
||||
__device__ void localAllGather(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int world_size,
|
||||
int nranksPerNode, int remoteRank, uint64_t offset, uint64_t size)
|
||||
{
|
||||
// this allgather algorithm works as follows:
|
||||
// Step 1: GPU rank i sends data to GPU rank (i+1) % nranksPerNode
|
||||
@@ -131,7 +132,7 @@ __device__ void allgather2(mscclpp::channel::SimpleDeviceChannel devChan, int ra
|
||||
// opposite side
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devChan.putWithSignal(rank * nelemsPerGPU * sizeof(int),
|
||||
(nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int));
|
||||
(nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int));
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devChan.wait();
|
||||
}
|
||||
@@ -150,9 +151,8 @@ __device__ void allgather2(mscclpp::channel::SimpleDeviceChannel devChan, int ra
|
||||
if (remoteRank % nranksPerNode == rank % nranksPerNode) {
|
||||
// opposite side
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devChan.putWithSignal((rank * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) *
|
||||
sizeof(int),
|
||||
nelemsPerGPU / pipelineSize * sizeof(int));
|
||||
devChan.putWithSignal((rank * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * sizeof(int),
|
||||
nelemsPerGPU / pipelineSize * sizeof(int));
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devChan.wait();
|
||||
}
|
||||
@@ -226,7 +226,8 @@ void initializeAndAllocateAllGatherData(int rank, int world_size, size_t dataSiz
|
||||
CUDACHECK(cudaMemcpy(*data_d, *data_h, dataSize, cudaMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& comm, mscclpp::channel::DeviceChannelService& channelService, int* data_d, size_t dataSize)
|
||||
void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& comm,
|
||||
mscclpp::channel::DeviceChannelService& channelService, int* data_d, size_t dataSize)
|
||||
{
|
||||
int thisNode = rankToNode(rank);
|
||||
int cudaNum = rankToLocalRank(rank);
|
||||
@@ -258,12 +259,13 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co
|
||||
std::vector<mscclpp::channel::SimpleDeviceChannel> devChannels;
|
||||
for (size_t i = 0; i < channelIds.size(); ++i) {
|
||||
devChannels.push_back(mscclpp::channel::SimpleDeviceChannel(channelService.deviceChannel(channelIds[i]),
|
||||
channelService.addMemory(remoteMemories[i].get()), channelService.addMemory(localMemories[i])));
|
||||
channelService.addMemory(remoteMemories[i].get()),
|
||||
channelService.addMemory(localMemories[i])));
|
||||
}
|
||||
|
||||
assert(devChannels.size() < sizeof(constDevChans) / sizeof(mscclpp::channel::SimpleDeviceChannel));
|
||||
CUDACHECK(
|
||||
cudaMemcpyToSymbol(constDevChans, devChannels.data(), sizeof(mscclpp::channel::SimpleDeviceChannel) * devChannels.size()));
|
||||
CUDACHECK(cudaMemcpyToSymbol(constDevChans, devChannels.data(),
|
||||
sizeof(mscclpp::channel::SimpleDeviceChannel) * devChannels.size()));
|
||||
}
|
||||
|
||||
void printUsage(const char* prog, bool isMpi)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#include "mscclpp.hpp"
|
||||
#include "epoch.hpp"
|
||||
#include "mscclpp.hpp"
|
||||
|
||||
#include <cassert>
|
||||
#include <cuda_runtime.h>
|
||||
@@ -24,26 +24,33 @@ mscclpp::Transport findIb(int localRank)
|
||||
return IBs[localRank];
|
||||
}
|
||||
|
||||
void register_all_memories(mscclpp::Communicator& communicator, int rank, int worldSize, void* devicePtr, size_t deviceBufferSize, mscclpp::Transport myIbDevice, mscclpp::RegisteredMemory& localMemory, std::unordered_map<int, mscclpp::RegisteredMemory>& remoteMemory){
|
||||
void register_all_memories(mscclpp::Communicator& communicator, int rank, int worldSize, void* devicePtr,
|
||||
size_t deviceBufferSize, mscclpp::Transport myIbDevice,
|
||||
mscclpp::RegisteredMemory& localMemory,
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory>& remoteMemory)
|
||||
{
|
||||
localMemory = communicator.registerMemory(devicePtr, deviceBufferSize, mscclpp::Transport::CudaIpc | myIbDevice);
|
||||
std::unordered_map<int, mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> futureRemoteMemory;
|
||||
for (int i = 0; i < worldSize; i++) {
|
||||
if (i != rank){
|
||||
if (i != rank) {
|
||||
communicator.sendMemoryOnSetup(localMemory, i, 0);
|
||||
futureRemoteMemory[i] = communicator.recvMemoryOnSetup(i, 0);
|
||||
}
|
||||
}
|
||||
communicator.setup();
|
||||
for (int i = 0; i < worldSize; i++) {
|
||||
if (i != rank){
|
||||
if (i != rank) {
|
||||
remoteMemory[i] = futureRemoteMemory[i].get();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void make_connections(mscclpp::Communicator& communicator, int rank, int worldSize, int nRanksPerNode, mscclpp::Transport myIbDevice, std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections){
|
||||
void make_connections(mscclpp::Communicator& communicator, int rank, int worldSize, int nRanksPerNode,
|
||||
mscclpp::Transport myIbDevice,
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections)
|
||||
{
|
||||
for (int i = 0; i < worldSize; i++) {
|
||||
if (i != rank){
|
||||
if (i != rank) {
|
||||
if (i / nRanksPerNode == rank / nRanksPerNode) {
|
||||
connections[i] = communicator.connectOnSetup(i, 0, mscclpp::Transport::CudaIpc);
|
||||
} else {
|
||||
@@ -54,35 +61,40 @@ void make_connections(mscclpp::Communicator& communicator, int rank, int worldSi
|
||||
communicator.setup();
|
||||
}
|
||||
|
||||
void write_remote(int rank, int worldSize, std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory>& remoteRegisteredMemories, mscclpp::RegisteredMemory& registeredMemory, int dataCountPerRank){
|
||||
void write_remote(int rank, int worldSize, std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory>& remoteRegisteredMemories,
|
||||
mscclpp::RegisteredMemory& registeredMemory, int dataCountPerRank)
|
||||
{
|
||||
for (int i = 0; i < worldSize; i++) {
|
||||
if (i != rank) {
|
||||
auto& conn = connections.at(i);
|
||||
auto& peerMemory = remoteRegisteredMemories.at(i);
|
||||
conn->write(peerMemory, rank * dataCountPerRank * sizeof(int), registeredMemory, rank * dataCountPerRank*sizeof(int), dataCountPerRank*sizeof(int));
|
||||
conn->write(peerMemory, rank * dataCountPerRank * sizeof(int), registeredMemory,
|
||||
rank * dataCountPerRank * sizeof(int), dataCountPerRank * sizeof(int));
|
||||
conn->flush();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void device_buffer_init(int rank, int worldSize, int dataCount, std::vector<int*>& devicePtr){
|
||||
for (int n = 0; n < (int)devicePtr.size(); n++){
|
||||
void device_buffer_init(int rank, int worldSize, int dataCount, std::vector<int*>& devicePtr)
|
||||
{
|
||||
for (int n = 0; n < (int)devicePtr.size(); n++) {
|
||||
std::vector<int> hostBuffer(dataCount, 0);
|
||||
for (int i = 0; i < dataCount; i++) {
|
||||
hostBuffer[i] = rank + n * worldSize;
|
||||
}
|
||||
CUDATHROW(cudaMemcpy(devicePtr[n], hostBuffer.data(), dataCount*sizeof(int), cudaMemcpyHostToDevice));
|
||||
CUDATHROW(cudaMemcpy(devicePtr[n], hostBuffer.data(), dataCount * sizeof(int), cudaMemcpyHostToDevice));
|
||||
}
|
||||
CUDATHROW(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
bool test_device_buffer_write_correctness(int worldSize, int dataCount, std::vector<int*>& devicePtr){
|
||||
for (int n = 0; n < (int)devicePtr.size(); n++){
|
||||
bool test_device_buffer_write_correctness(int worldSize, int dataCount, std::vector<int*>& devicePtr)
|
||||
{
|
||||
for (int n = 0; n < (int)devicePtr.size(); n++) {
|
||||
std::vector<int> hostBuffer(dataCount, 0);
|
||||
CUDATHROW(cudaMemcpy(hostBuffer.data(), devicePtr[n], dataCount*sizeof(int), cudaMemcpyDeviceToHost));
|
||||
CUDATHROW(cudaMemcpy(hostBuffer.data(), devicePtr[n], dataCount * sizeof(int), cudaMemcpyDeviceToHost));
|
||||
for (int i = 0; i < worldSize; i++) {
|
||||
for (int j = i*dataCount/worldSize; j < (i+1)*dataCount/worldSize; j++) {
|
||||
for (int j = i * dataCount / worldSize; j < (i + 1) * dataCount / worldSize; j++) {
|
||||
if (hostBuffer[j] != i + n * worldSize) {
|
||||
return false;
|
||||
}
|
||||
@@ -92,8 +104,11 @@ bool test_device_buffer_write_correctness(int worldSize, int dataCount, std::vec
|
||||
return true;
|
||||
}
|
||||
|
||||
void test_write(int rank, int worldSize, int deviceBufferSize, std::shared_ptr<mscclpp::BaseBootstrap> bootstrap, std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::vector<std::unordered_map<int, mscclpp::RegisteredMemory>>& remoteMemory, std::vector<mscclpp::RegisteredMemory>& localMemory, std::vector<int*>& devicePtr, int numBuffers){
|
||||
void test_write(int rank, int worldSize, int deviceBufferSize, std::shared_ptr<mscclpp::BaseBootstrap> bootstrap,
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::vector<std::unordered_map<int, mscclpp::RegisteredMemory>>& remoteMemory,
|
||||
std::vector<mscclpp::RegisteredMemory>& localMemory, std::vector<int*>& devicePtr, int numBuffers)
|
||||
{
|
||||
|
||||
assert((deviceBufferSize / sizeof(int)) % worldSize == 0);
|
||||
size_t dataCount = deviceBufferSize / sizeof(int);
|
||||
@@ -102,8 +117,8 @@ void test_write(int rank, int worldSize, int deviceBufferSize, std::shared_ptr<m
|
||||
bootstrap->barrier();
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "CUDA memory initialization passed" << std::endl;
|
||||
|
||||
for (int n = 0; n < numBuffers; n++){
|
||||
|
||||
for (int n = 0; n < numBuffers; n++) {
|
||||
write_remote(rank, worldSize, connections, remoteMemory[n], localMemory[n], dataCount / worldSize);
|
||||
}
|
||||
bootstrap->barrier();
|
||||
@@ -116,7 +131,7 @@ void test_write(int rank, int worldSize, int deviceBufferSize, std::shared_ptr<m
|
||||
do {
|
||||
ready = test_device_buffer_write_correctness(worldSize, dataCount, devicePtr);
|
||||
niter++;
|
||||
if (niter == 10000){
|
||||
if (niter == 10000) {
|
||||
throw std::runtime_error("Polling is stuck.");
|
||||
}
|
||||
} while (!ready);
|
||||
@@ -126,22 +141,29 @@ void test_write(int rank, int worldSize, int deviceBufferSize, std::shared_ptr<m
|
||||
std::cout << "Polling for " << std::to_string(numBuffers) << " buffers passed" << std::endl;
|
||||
}
|
||||
|
||||
__global__ void increament_epochs(mscclpp::DeviceEpoch* deviceEpochs, int rank, int worldSize){
|
||||
__global__ void increament_epochs(mscclpp::DeviceEpoch* deviceEpochs, int rank, int worldSize)
|
||||
{
|
||||
int tid = threadIdx.x;
|
||||
if (tid != rank && tid < worldSize){
|
||||
if (tid != rank && tid < worldSize) {
|
||||
deviceEpochs[tid].epochIncrement();
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void wait_epochs(mscclpp::DeviceEpoch* deviceEpochs, int rank, int worldSize){
|
||||
__global__ void wait_epochs(mscclpp::DeviceEpoch* deviceEpochs, int rank, int worldSize)
|
||||
{
|
||||
int tid = threadIdx.x;
|
||||
if (tid != rank && tid < worldSize){
|
||||
if (tid != rank && tid < worldSize) {
|
||||
deviceEpochs[tid].wait();
|
||||
}
|
||||
}
|
||||
|
||||
void test_write_with_epochs(int rank, int worldSize, int deviceBufferSize, std::shared_ptr<mscclpp::BaseBootstrap> bootstrap, std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::vector<std::unordered_map<int, mscclpp::RegisteredMemory>>& remoteMemory, std::vector<mscclpp::RegisteredMemory>& localMemory, std::vector<int*>& devicePtr, std::unordered_map<int, std::shared_ptr<mscclpp::Epoch>> epochs, int numBuffers){
|
||||
void test_write_with_epochs(int rank, int worldSize, int deviceBufferSize,
|
||||
std::shared_ptr<mscclpp::BaseBootstrap> bootstrap,
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::vector<std::unordered_map<int, mscclpp::RegisteredMemory>>& remoteMemory,
|
||||
std::vector<mscclpp::RegisteredMemory>& localMemory, std::vector<int*>& devicePtr,
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Epoch>> epochs, int numBuffers)
|
||||
{
|
||||
|
||||
assert((deviceBufferSize / sizeof(int)) % worldSize == 0);
|
||||
size_t dataCount = deviceBufferSize / sizeof(int);
|
||||
@@ -153,8 +175,8 @@ void test_write_with_epochs(int rank, int worldSize, int deviceBufferSize, std::
|
||||
|
||||
mscclpp::DeviceEpoch* deviceEpochs;
|
||||
CUDATHROW(cudaMalloc(&deviceEpochs, sizeof(mscclpp::DeviceEpoch) * worldSize));
|
||||
for (int i = 0; i < worldSize; i++){
|
||||
if (i != rank){
|
||||
for (int i = 0; i < worldSize; i++) {
|
||||
if (i != rank) {
|
||||
mscclpp::DeviceEpoch deviceEpoch = epochs[i]->deviceEpoch();
|
||||
CUDATHROW(cudaMemcpy(&deviceEpochs[i], &deviceEpoch, sizeof(mscclpp::DeviceEpoch), cudaMemcpyHostToDevice));
|
||||
}
|
||||
@@ -165,16 +187,15 @@ void test_write_with_epochs(int rank, int worldSize, int deviceBufferSize, std::
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "CUDA device epochs are created" << std::endl;
|
||||
|
||||
|
||||
for (int n = 0; n < numBuffers; n++){
|
||||
for (int n = 0; n < numBuffers; n++) {
|
||||
write_remote(rank, worldSize, connections, remoteMemory[n], localMemory[n], dataCount / worldSize);
|
||||
}
|
||||
|
||||
increament_epochs<<<1, worldSize>>>(deviceEpochs, rank, worldSize);
|
||||
CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
for (int i = 0; i < worldSize; i++){
|
||||
if (i != rank){
|
||||
for (int i = 0; i < worldSize; i++) {
|
||||
if (i != rank) {
|
||||
epochs[i]->signal();
|
||||
}
|
||||
}
|
||||
@@ -182,13 +203,14 @@ void test_write_with_epochs(int rank, int worldSize, int deviceBufferSize, std::
|
||||
wait_epochs<<<1, worldSize>>>(deviceEpochs, rank, worldSize);
|
||||
CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
if (!test_device_buffer_write_correctness(worldSize, dataCount, devicePtr)){
|
||||
if (!test_device_buffer_write_correctness(worldSize, dataCount, devicePtr)) {
|
||||
throw std::runtime_error("unexpected result.");
|
||||
}
|
||||
|
||||
bootstrap->barrier();
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "--- Testing writes with singal for " << std::to_string(numBuffers) << " buffers passed ---" << std::endl;
|
||||
std::cout << "--- Testing writes with singal for " << std::to_string(numBuffers) << " buffers passed ---"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
void test_communicator(int rank, int worldSize, int nranksPerNode)
|
||||
@@ -213,8 +235,8 @@ void test_communicator(int rank, int worldSize, int nranksPerNode)
|
||||
|
||||
int numBuffers = 10;
|
||||
std::vector<int*> devicePtr(numBuffers);
|
||||
int deviceBufferSize = 1024*1024;
|
||||
|
||||
int deviceBufferSize = 1024 * 1024;
|
||||
|
||||
std::vector<mscclpp::RegisteredMemory> localMemory(numBuffers);
|
||||
std::vector<std::unordered_map<int, mscclpp::RegisteredMemory>> remoteMemory(numBuffers);
|
||||
|
||||
@@ -222,13 +244,15 @@ void test_communicator(int rank, int worldSize, int nranksPerNode)
|
||||
if (n % 100 == 0)
|
||||
std::cout << "Registering memory for " << std::to_string(n) << " buffers" << std::endl;
|
||||
CUDATHROW(cudaMalloc(&devicePtr[n], deviceBufferSize));
|
||||
register_all_memories(communicator, rank, worldSize, devicePtr[n], deviceBufferSize, myIbDevice, localMemory[n], remoteMemory[n]);
|
||||
register_all_memories(communicator, rank, worldSize, devicePtr[n], deviceBufferSize, myIbDevice, localMemory[n],
|
||||
remoteMemory[n]);
|
||||
}
|
||||
bootstrap->barrier();
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "Memory registration for " << std::to_string(numBuffers) << " buffers passed" << std::endl;
|
||||
|
||||
test_write(rank, worldSize, deviceBufferSize, bootstrap, connections, remoteMemory, localMemory, devicePtr, numBuffers);
|
||||
test_write(rank, worldSize, deviceBufferSize, bootstrap, connections, remoteMemory, localMemory, devicePtr,
|
||||
numBuffers);
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "--- Testing vanialla writes passed ---" << std::endl;
|
||||
|
||||
@@ -242,12 +266,13 @@ void test_communicator(int rank, int worldSize, int nranksPerNode)
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "Epochs are created" << std::endl;
|
||||
|
||||
test_write_with_epochs(rank, worldSize, deviceBufferSize, bootstrap, connections, remoteMemory, localMemory, devicePtr, epochs, numBuffers);
|
||||
test_write_with_epochs(rank, worldSize, deviceBufferSize, bootstrap, connections, remoteMemory, localMemory,
|
||||
devicePtr, epochs, numBuffers);
|
||||
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "--- MSCCLPP::Communicator tests passed! ---" << std::endl;
|
||||
|
||||
for (int n = 0; n < numBuffers; n++){
|
||||
for (int n = 0; n < numBuffers; n++) {
|
||||
CUDATHROW(cudaFree(devicePtr[n]));
|
||||
}
|
||||
}
|
||||
@@ -269,4 +294,4 @@ int main(int argc, char** argv)
|
||||
|
||||
MPI_Finalize();
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user