Merge pull request #66 from microsoft/olli/api-extension

Olli/api extension
This commit is contained in:
Saeed Maleki
2023-05-03 19:47:03 -07:00
committed by GitHub
45 changed files with 2689 additions and 1313 deletions

32
CMakeLists.txt Normal file
View File

@@ -0,0 +1,32 @@
# Top-level build script: builds the mscclpp shared library and its tests.
cmake_minimum_required(VERSION 3.26)
project(mscclpp LANGUAGES CUDA CXX)

# Require C++17 / CUDA C++17 outright. Without *_STANDARD_REQUIRED, CMake
# silently decays to an older standard when the compiler lacks support.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)

# IBVerbs, NUMA and GDRCopy have no find modules shipped with CMake, so add
# this directory to the module search path before calling find_package().
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake/modules")

find_package(CUDAToolkit REQUIRED)
find_package(IBVerbs REQUIRED)
find_package(NUMA REQUIRED)
# GDRCopy is optional; it is only linked below when it was found.
find_package(GDRCopy)

option(USE_MPI_FOR_TESTS "Use MPI for tests" ON)
if(USE_MPI_FOR_TESTS)
  find_package(MPI REQUIRED)
  # NOTE(review): directory-scoped, so it also reaches the tests subdirectory
  # added below. Confirm what tests/CMakeLists.txt relies on before narrowing
  # this to target_compile_definitions().
  add_definitions(-DMSCCLPP_USE_MPI_FOR_TESTS)
endif()

# Directory-scoped: applies to every target defined in this directory and below.
include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})

add_library(mscclpp SHARED)
add_subdirectory(src) # This adds the sources to the mscclpp target
target_include_directories(mscclpp PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/src/include")
set_target_properties(mscclpp PROPERTIES LINKER_LANGUAGE CXX)
target_link_libraries(mscclpp PRIVATE MSCCLPP::ibverbs MSCCLPP::numa CUDA::cudart CUDA::cuda_driver)
if(GDRCOPY_FOUND)
  target_link_libraries(mscclpp PRIVATE MSCCLPP::gdrcopy)
endif()

add_subdirectory(tests)

View File

@@ -61,7 +61,7 @@ endif
NVCUFLAGS := -ccbin $(CXX) $(NVCC_GENCODE) -std=c++11 --expt-extended-lambda -Xfatbin -compress-all
# Use addprefix so that we can specify more than one path
NVLDFLAGS := -L$(CUDA_LIB) -lcudart -lrt
NVLDFLAGS := -L$(CUDA_LIB) -lcudart -lrt -lcuda
ifeq ($(DEBUG), 0)
NVCUFLAGS += -O3
@@ -120,7 +120,8 @@ LDFLAGS := $(NVLDFLAGS) $(GDRCOPY_LDFLAGS) -libverbs -lnuma
LIBSRCS := $(addprefix src/,debug.cc utils.cc init.cc proxy.cc ib.cc config.cc)
LIBSRCS += $(addprefix src/bootstrap/,bootstrap.cc socket.cc)
LIBSRCS += $(addprefix src/,communicator.cc fifo.cc host_connection.cc proxy_cpp.cc basic_proxy_handler.cc)
LIBSRCS += $(addprefix src/,communicator.cc connection.cc registered_memory.cc)
LIBSRCS += $(addprefix src/,epoch.cc proxy_cpp.cc fifo.cc channel.cc)
ifneq ($(NPKIT), 0)
LIBSRCS += $(addprefix src/misc/,npkit.cc)
endif
@@ -134,7 +135,7 @@ HEADERS := $(wildcard src/include/*.h)
CPPSOURCES := $(shell find ./ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)' -not -path "./build/*" -not -path "./python/*")
PYTHONCPPSOURCES := $(shell find ./python/src/ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)')
INCEXPORTS := mscclpp.h mscclppfifo.h mscclpp.hpp mscclppfifo.hpp
INCEXPORTS := mscclpp.h mscclppfifo.h mscclpp.hpp mscclppfifo.hpp epoch.hpp
INCTARGETS := $(INCEXPORTS:%=$(BUILDDIR)/$(INCDIR)/%)
LIBNAME := libmscclpp.so
@@ -148,7 +149,7 @@ UTOBJTARGETS := $(UTOBJS:%=$(BUILDDIR)/$(OBJDIR)/%)
UTBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(UTOBJS))
TESTSDIR := tests
TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu allgather_test_cpp.cu bootstrap_test_cpp.cc)
TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu communicator_test_cpp.cu bootstrap_test_cpp.cc allgather_test_cpp.cu)
TESTSOBJS := $(patsubst %.cc,%.o,$(TESTSSRCS)) $(patsubst %.cu,%.o,$(TESTSSRCS))
TESTSOBJTARGETS := $(TESTSOBJS:%=$(BUILDDIR)/$(OBJDIR)/%)
TESTSBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(TESTSOBJS))

8
TODO.md Normal file
View File

@@ -0,0 +1,8 @@
# Core API extraction
- Add a test for host side Communicator/RegisteredMemory/Connection use.
- Implement a standalone "epoch" synchronization construct that can be used as a component in custom proxies. epoch.hpp/cc has the beginnings of this.
- Reimplement the "standard" proxy service + DeviceConnection on top of the new Communicator/RegisteredMemory/Connection core API. Remnants of the old code are in channel.hpp, basic_proxy_handler.hpp/cc and host_connection.hpp/cc. Probably need a manager class to wrap all of this.
- Change the new IBConnection and Communicator to use the new C++ IbCtx and IbQp classes.
- Implement IbQp::~IbQp()
- Fix RegisteredMemory::Impl::Impl to get the IPC handle from the base pointer, not the derived pointer.

View File

@@ -0,0 +1,41 @@
# Find the GDRCopy libraries
#
# The following variables are optionally searched for defaults
# GDRCOPY_ROOT_DIR: Base directory where all GDRCopy components are found
# GDRCOPY_INCLUDE_DIR: Directory where GDRCopy headers are found
# GDRCOPY_LIB_DIR: Directory where GDRCopy libraries are found
# The following are set after configuration is done:
# GDRCOPY_FOUND
# GDRCOPY_INCLUDE_DIRS
# GDRCOPY_LIBRARIES
# An imported target MSCCLPP::gdrcopy is created if the library is found.
# Locate the GDRCopy header. The user-supplied hint variables are tried first.
find_path(GDRCOPY_INCLUDE_DIRS
  NAMES gdrapi.h
  HINTS
  ${GDRCOPY_INCLUDE_DIR}
  ${GDRCOPY_ROOT_DIR}
  ${GDRCOPY_ROOT_DIR}/include)

# Locate the GDRCopy library (libgdrapi).
find_library(GDRCOPY_LIBRARIES
  NAMES gdrapi
  HINTS
  ${GDRCOPY_LIB_DIR}
  ${GDRCOPY_ROOT_DIR}
  ${GDRCOPY_ROOT_DIR}/lib)

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(GDRCopy DEFAULT_MSG GDRCOPY_INCLUDE_DIRS GDRCOPY_LIBRARIES)
# Hide the cache entries created by find_path()/find_library() above.
mark_as_advanced(GDRCOPY_INCLUDE_DIRS GDRCOPY_LIBRARIES)

if(GDRCOPY_FOUND)
  if(NOT TARGET MSCCLPP::gdrcopy)
    add_library(MSCCLPP::gdrcopy UNKNOWN IMPORTED)
  endif()
  # Use the find_path() result (GDRCOPY_INCLUDE_DIRS). GDRCOPY_INCLUDE_DIR is
  # only an optional user hint and is normally empty, which would leave the
  # imported target without any include directory.
  set_target_properties(MSCCLPP::gdrcopy PROPERTIES
    INTERFACE_INCLUDE_DIRECTORIES "${GDRCOPY_INCLUDE_DIRS}"
    IMPORTED_LINK_INTERFACE_LANGUAGES "C"
    IMPORTED_LOCATION "${GDRCOPY_LIBRARIES}")
endif()

View File

@@ -0,0 +1,41 @@
# Find the IB Verbs libraries
#
# The following variables are optionally searched for defaults
# IBVERBS_ROOT_DIR: Base directory where all ibverbs components are found
# IBVERBS_INCLUDE_DIR: Directory where ibverbs headers are found
# IBVERBS_LIB_DIR: Directory where ibverbs libraries are found
# The following are set after configuration is done:
# IBVERBS_FOUND
# IBVERBS_INCLUDE_DIRS
# IBVERBS_LIBRARIES
# An imported target MSCCLPP::ibverbs is created if the library is found.
# Locate the ibverbs header. The user-supplied hint variables are tried first.
find_path(IBVERBS_INCLUDE_DIRS
  NAMES infiniband/verbs.h
  HINTS
  ${IBVERBS_INCLUDE_DIR}
  ${IBVERBS_ROOT_DIR}
  ${IBVERBS_ROOT_DIR}/include)

# Locate the ibverbs library (libibverbs).
find_library(IBVERBS_LIBRARIES
  NAMES ibverbs
  HINTS
  ${IBVERBS_LIB_DIR}
  ${IBVERBS_ROOT_DIR}
  ${IBVERBS_ROOT_DIR}/lib)

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(IBVerbs DEFAULT_MSG IBVERBS_INCLUDE_DIRS IBVERBS_LIBRARIES)
# Hide the cache entries created by find_path()/find_library() above.
mark_as_advanced(IBVERBS_INCLUDE_DIRS IBVERBS_LIBRARIES)

if(IBVERBS_FOUND)
  if(NOT TARGET MSCCLPP::ibverbs)
    add_library(MSCCLPP::ibverbs UNKNOWN IMPORTED)
  endif()
  # Use the find_path() result (IBVERBS_INCLUDE_DIRS). IBVERBS_INCLUDE_DIR is
  # only an optional user hint and is normally empty, which would leave the
  # imported target without any include directory.
  set_target_properties(MSCCLPP::ibverbs PROPERTIES
    INTERFACE_INCLUDE_DIRECTORIES "${IBVERBS_INCLUDE_DIRS}"
    IMPORTED_LINK_INTERFACE_LANGUAGES "C"
    IMPORTED_LOCATION "${IBVERBS_LIBRARIES}")
endif()

View File

@@ -0,0 +1,41 @@
# Find the numa libraries
#
# The following variables are optionally searched for defaults
# NUMA_ROOT_DIR: Base directory where all numa components are found
# NUMA_INCLUDE_DIR: Directory where numa headers are found
# NUMA_LIB_DIR: Directory where numa libraries are found
# The following are set after configuration is done:
# NUMA_FOUND
# NUMA_INCLUDE_DIRS
# NUMA_LIBRARIES
# An imported target MSCCLPP::numa is created if the library is found.
# Locate the libnuma header. The user-supplied hint variables are tried first.
find_path(NUMA_INCLUDE_DIRS
  NAMES numa.h
  HINTS
  ${NUMA_INCLUDE_DIR}
  ${NUMA_ROOT_DIR}
  ${NUMA_ROOT_DIR}/include)

# Locate the numa library (libnuma).
find_library(NUMA_LIBRARIES
  NAMES numa
  HINTS
  ${NUMA_LIB_DIR}
  ${NUMA_ROOT_DIR}
  ${NUMA_ROOT_DIR}/lib)

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NUMA DEFAULT_MSG NUMA_INCLUDE_DIRS NUMA_LIBRARIES)
# Hide the cache entries created by find_path()/find_library() above.
mark_as_advanced(NUMA_INCLUDE_DIRS NUMA_LIBRARIES)

if(NUMA_FOUND)
  if(NOT TARGET MSCCLPP::numa)
    add_library(MSCCLPP::numa UNKNOWN IMPORTED)
  endif()
  # Use the find_path() result (NUMA_INCLUDE_DIRS). NUMA_INCLUDE_DIR is only an
  # optional user hint and is normally empty, which would leave the imported
  # target without any include directory.
  set_target_properties(MSCCLPP::numa PROPERTIES
    INTERFACE_INCLUDE_DIRECTORIES "${NUMA_INCLUDE_DIRS}"
    IMPORTED_LINK_INTERFACE_LANGUAGES "C"
    IMPORTED_LOCATION "${NUMA_LIBRARIES}")
endif()

5
src/CMakeLists.txt Normal file
View File

@@ -0,0 +1,5 @@
# Collect every .cc and .h file under this directory (recursively).
# CONFIGURE_DEPENDS makes the build re-check the glob so added or removed files
# trigger a re-configure.
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cc *.h)
# Resolve gdr.cc in this directory through a glob so that its spelling matches
# the entries in SOURCES, then drop it from the list.
# NOTE(review): gdr.cc is excluded regardless of whether GDRCopy was found - confirm this is intended.
file(GLOB to_remove gdr.cc)
list(REMOVE_ITEM SOURCES ${to_remove})
# Attach the collected files to the mscclpp target declared by the parent directory.
target_sources(mscclpp PRIVATE ${SOURCES})

View File

@@ -1,29 +0,0 @@
#include "basic_proxy_handler.hpp"
namespace mscclpp {
// Builds the proxy handler that services a single trigger on behalf of `comm`.
// The returned lambda captures `comm` by reference, so `comm` must outlive the handler.
ProxyHandler makeBasicProxyHandler(Communicator::Impl &comm) {
return [&comm](ProxyTrigger triggerRaw) {
// View the raw trigger through the ChannelTrigger layout to access its fields.
ChannelTrigger *trigger = reinterpret_cast<ChannelTrigger*>(&triggerRaw);
// Look up the connection addressed by the trigger's connId.
HostConnection& conn = *comm.connections.at(trigger->fields.connId);
auto result = ProxyHandlerResult::Continue;
// `type` is tested bit by bit, so one trigger may request several of the actions below.
if (trigger->fields.type & mscclppData) {
conn.put(trigger->fields.dstBufferHandle, trigger->fields.dstOffset, trigger->fields.srcBufferHandle, trigger->fields.srcOffset, trigger->fields.size);
}
if (trigger->fields.type & mscclppFlag) {
conn.signal();
}
if (trigger->fields.type & mscclppSync) {
conn.flush();
// A sync request additionally asks the caller to flush the FIFO tail.
result = ProxyHandlerResult::FlushFifoTailAndContinue;
}
return result;
};
}
} // namespace mscclpp

View File

@@ -180,9 +180,8 @@ Bootstrap::Impl::~Impl()
}
}
void Bootstrap::Impl::getRemoteAddresses(mscclppSocket* listenSock,
std::vector<mscclppSocketAddress>& rankAddresses,
std::vector<mscclppSocketAddress>& rankAddressesRoot, int& rank)
void Bootstrap::Impl::getRemoteAddresses(mscclppSocket* listenSock, std::vector<mscclppSocketAddress>& rankAddresses,
std::vector<mscclppSocketAddress>& rankAddressesRoot, int& rank)
{
mscclppSocket sock;
ExtInfo info;
@@ -211,7 +210,7 @@ void Bootstrap::Impl::getRemoteAddresses(mscclppSocket* listenSock,
}
void Bootstrap::Impl::sendHandleToPeer(int peer, const std::vector<mscclppSocketAddress>& rankAddresses,
const std::vector<mscclppSocketAddress>& rankAddressesRoot)
const std::vector<mscclppSocketAddress>& rankAddressesRoot)
{
mscclppSocket sock;
int next = (peer + 1) % this->nRanks_;
@@ -226,7 +225,8 @@ void Bootstrap::Impl::bootstrapCreateRoot()
mscclppSocket listenSock;
// mscclppSocket* listenSock = new mscclppSocket(); // TODO(saemal) make this a shared ptr
MSCCLPPTHROW(mscclppSocketInit(&listenSock, &uniqueId_.addr, uniqueId_.magic, mscclppSocketTypeBootstrap, nullptr, 0));
MSCCLPPTHROW(
mscclppSocketInit(&listenSock, &uniqueId_.addr, uniqueId_.magic, mscclppSocketTypeBootstrap, nullptr, 0));
MSCCLPPTHROW(mscclppSocketListen(&listenSock));
MSCCLPPTHROW(mscclppSocketGetAddr(&listenSock, &uniqueId_.addr));
auto lambda = [this, listenSock]() { this->bootstrapRoot(listenSock); };

26
src/channel.cc Normal file
View File

@@ -0,0 +1,26 @@
#include "channel.hpp"
#include "utils.h"
#include "checks.hpp"
#include "api.h"
#include "debug.h"
namespace mscclpp {
namespace channel {
// Creates the service for `communicator`. The proxy is given two callbacks that
// forward to this object: one handling each trigger and one that calls bindThread().
// The NUMA node of the current CUDA device is recorded for bindThread().
MSCCLPP_API_CPP DeviceChannelService::DeviceChannelService(Communicator& communicator)
  : communicator_(communicator),
    proxy_([this](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, [this]() { bindThread(); })
{
  int currentDevice;
  CUDATHROW(cudaGetDevice(&currentDevice));
  MSCCLPPTHROW(getDeviceNumaNode(currentDevice, &deviceNumaNode));
}
// Binds the calling thread to the NUMA node recorded in the constructor.
// Does nothing when the recorded node id is negative.
MSCCLPP_API_CPP void DeviceChannelService::bindThread()
{
  if (deviceNumaNode < 0) {
    return;
  }
  MSCCLPPTHROW(numaBind(deviceNumaNode));
  INFO(MSCCLPP_INIT, "NUMA node of DeviceChannelService proxy thread is set to %d", deviceNumaNode);
}
} // namespace channel
} // namespace mscclpp

View File

@@ -1,81 +1,152 @@
#include "communicator.hpp"
#include "host_connection.hpp"
#include "comm.h"
#include "basic_proxy_handler.hpp"
#include <sstream>
#include "api.h"
#include "checks.hpp"
#include "comm.h"
#include "communicator.hpp"
#include "connection.hpp"
#include "debug.h"
#include "mscclpp.hpp"
#include "registered_memory.hpp"
#include "utils.h"
namespace mscclpp {
Communicator::Impl::Impl() : comm(nullptr), proxy(makeBasicProxyHandler(*this)) {}
// Stores the bootstrap and builds rankToHash_, the per-rank table of host hashes.
Communicator::Impl::Impl(std::shared_ptr<BaseBootstrap> bootstrap) : bootstrap_(bootstrap)
{
// One slot per rank; only this rank's slot is filled in locally.
rankToHash_.resize(bootstrap->getNranks());
auto hostHash = getHostHash();
INFO(MSCCLPP_INIT, "Host hash: %lx", hostHash);
rankToHash_[bootstrap->getRank()] = hostHash;
// Exchange sizeof(uint64_t) bytes per rank through the bootstrap.
// presumably this fills in the other ranks' slots in place - confirm against BaseBootstrap::allGather
bootstrap->allGather(rankToHash_.data(), sizeof(uint64_t));
}
Communicator::Impl::~Impl() {
if (comm) {
mscclppCommDestroy(comm);
Communicator::Impl::~Impl()
{
ibContexts_.clear();
}
// Returns the IbCtx associated with `ibTransport`, creating and caching it on first use.
// The returned pointer is owned by ibContexts_.
IbCtx* Communicator::Impl::getIbContext(Transport ibTransport)
{
  auto cached = ibContexts_.find(ibTransport);
  if (cached != ibContexts_.end()) {
    return cached->second.get();
  }

  // First request for this transport: open a context on the matching IB device.
  auto deviceName = getIBDeviceName(ibTransport);
  ibContexts_[ibTransport] = std::make_unique<IbCtx>(deviceName);
  return ibContexts_[ibTransport].get();
}
MSCCLPP_API_CPP Communicator::~Communicator() = default;
mscclppTransport_t transportTypeToCStyle(TransportType type) {
switch (type) {
case TransportType::IB:
return mscclppTransportIB;
case TransportType::P2P:
return mscclppTransportP2P;
default:
throw std::runtime_error("Unknown transport type");
MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr<BaseBootstrap> bootstrap)
: pimpl(std::make_unique<Impl>(bootstrap))
{
}
MSCCLPP_API_CPP std::shared_ptr<BaseBootstrap> Communicator::bootstrapper()
{
return pimpl->bootstrap_;
}
MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportFlags transports)
{
return RegisteredMemory(
std::make_shared<RegisteredMemory::Impl>(ptr, size, pimpl->bootstrap_->getRank(), transports, *pimpl));
}
// Setup step that sends a serialized RegisteredMemory to one remote rank.
struct MemorySender : public Setuppable
{
// Keeps a copy of `memory` together with the destination rank and message tag.
MemorySender(RegisteredMemory memory, int remoteRank, int tag)
: memory_(memory), remoteRank_(remoteRank), tag_(tag) {}
// Sends the serialized memory to remoteRank_ under tag_ through the bootstrap.
void beginSetup(std::shared_ptr<BaseBootstrap> bootstrap) override
{
bootstrap->send(memory_.serialize(), remoteRank_, tag_);
}
RegisteredMemory memory_; // memory to be serialized and sent
int remoteRank_;          // destination rank
int tag_;                 // tag passed to bootstrap->send()
};
MSCCLPP_API_CPP void Communicator::sendMemoryOnSetup(RegisteredMemory memory, int remoteRank, int tag)
{
addSetup(std::make_shared<MemorySender>(memory, remoteRank, tag));
}
MSCCLPP_API_CPP Communicator::Communicator(int nranks, const char* ipPortPair, int rank) : pimpl(std::make_unique<Impl>()) {
mscclppCommInitRank(&pimpl->comm, nranks, ipPortPair, rank);
// Setup step that receives a serialized RegisteredMemory from one remote rank
// and publishes it through memoryPromise_.
struct MemoryReceiver : public Setuppable
{
// Records the source rank and message tag.
MemoryReceiver(int remoteRank, int tag)
: remoteRank_(remoteRank), tag_(tag) {}
// Receives the serialized bytes, deserializes them and fulfils memoryPromise_.
void endSetup(std::shared_ptr<BaseBootstrap> bootstrap) override
{
std::vector<char> data;
bootstrap->recv(data, remoteRank_, tag_);
memoryPromise_.set_value(RegisteredMemory::deserialize(data));
}
std::promise<RegisteredMemory> memoryPromise_; // fulfilled in endSetup()
int remoteRank_;                               // source rank
int tag_;                                      // tag passed to bootstrap->recv()
};
MSCCLPP_API_CPP NonblockingFuture<RegisteredMemory> Communicator::recvMemoryOnSetup(int remoteRank, int tag)
{
auto memoryReceiver = std::make_shared<MemoryReceiver>(remoteRank, tag);
addSetup(memoryReceiver);
return NonblockingFuture<RegisteredMemory>(memoryReceiver->memoryPromise_.get_future());
}
MSCCLPP_API_CPP Communicator::Communicator(int nranks, UniqueId id, int rank) : pimpl(std::make_unique<Impl>()) {
static_assert(sizeof(mscclppUniqueId) == sizeof(UniqueId), "UniqueId size mismatch");
mscclppUniqueId *cstyle_id = reinterpret_cast<mscclppUniqueId*>(&id);
mscclppCommInitRankFromId(&pimpl->comm, nranks, *cstyle_id, rank);
}
MSCCLPP_API_CPP void Communicator::bootstrapAllGather(void* data, int size) {
mscclppBootstrapAllGather(pimpl->comm, data, size);
}
MSCCLPP_API_CPP void Communicator::bootstrapBarrier() {
mscclppBootstrapBarrier(pimpl->comm);
}
MSCCLPP_API_CPP std::shared_ptr<HostConnection> Communicator::connect(int remoteRank, int tag,
TransportType transportType, const char* ibDev) {
mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportTypeToCStyle(transportType), ibDev);
auto connIdx = pimpl->connections.size();
auto conn = std::make_shared<HostConnection>(std::make_unique<HostConnection::Impl>(this, &pimpl->comm->conns[connIdx]));
pimpl->connections.push_back(conn);
// Creates a connection object towards `remoteRank` and registers it for the next setup().
//
// remoteRank: peer rank; must be a valid index into the rank table.
// tag:        tag forwarded to the connection.
// transport:  Transport::CudaIpc or one of the transports in AllIBTransports.
//
// Returns the new connection. Throws std::runtime_error when the rank is out of
// range, when a Cuda IPC connection is requested to a rank whose host hash
// differs from the local one, or when the transport is unsupported.
MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(int remoteRank, int tag, Transport transport)
{
  const int localRank = pimpl->bootstrap_->getRank();

  // rankToHash_ is indexed by remoteRank below; reject out-of-range ranks up front.
  if (remoteRank < 0 || static_cast<size_t>(remoteRank) >= pimpl->rankToHash_.size()) {
    std::stringstream ss;
    ss << "Invalid remote rank " << remoteRank << " (number of ranks: " << pimpl->rankToHash_.size() << ")";
    throw std::runtime_error(ss.str());
  }

  std::shared_ptr<ConnectionBase> conn;
  if (transport == Transport::CudaIpc) {
    // sanity check: make sure the IPC connection is being made within a node
    if (pimpl->rankToHash_[remoteRank] != pimpl->rankToHash_[localRank]) {
      std::stringstream ss;
      // std::hex is sticky, so switch back to std::dec before printing each rank.
      ss << "Cuda IPC connection can only be made within a node: " << std::dec << remoteRank << "(" << std::hex
         << pimpl->rankToHash_[remoteRank] << ")"
         << " != " << std::dec << localRank << "(" << std::hex << pimpl->rankToHash_[localRank] << ")";
      throw std::runtime_error(ss.str());
    }
    auto cudaIpcConn = std::make_shared<CudaIpcConnection>(remoteRank, tag);
    conn = cudaIpcConn;
    INFO(MSCCLPP_P2P, "Cuda IPC connection between rank %d(%lx) and remoteRank %d(%lx) created", localRank,
         pimpl->rankToHash_[localRank], remoteRank, pimpl->rankToHash_[remoteRank]);
  } else if (AllIBTransports.has(transport)) {
    auto ibConn = std::make_shared<IBConnection>(remoteRank, tag, transport, *pimpl);
    conn = ibConn;
    INFO(MSCCLPP_NET, "IB connection between rank %d(%lx) via %s and remoteRank %d(%lx) created", localRank,
         pimpl->rankToHash_[localRank], getIBDeviceName(transport).c_str(), remoteRank,
         pimpl->rankToHash_[remoteRank]);
  } else {
    throw std::runtime_error("Unsupported transport");
  }

  // Keep the connection alive in the communicator and queue it for setup().
  pimpl->connections_.push_back(conn);
  addSetup(conn);
  return conn;
}
MSCCLPP_API_CPP void Communicator::connectionSetup() {
mscclppConnectionSetup(pimpl->comm);
MSCCLPP_API_CPP void Communicator::addSetup(std::shared_ptr<Setuppable> setuppable)
{
pimpl->toSetup_.push_back(setuppable);
}
MSCCLPP_API_CPP void Communicator::startProxying() {
pimpl->proxy.start();
// Runs all queued setup steps in two phases, then empties the queue.
MSCCLPP_API_CPP void Communicator::setup()
{
// Phase 1: beginSetup() is called on every queued object before any endSetup().
for (auto& setuppable : pimpl->toSetup_) {
setuppable->beginSetup(pimpl->bootstrap_);
}
// Phase 2: endSetup() runs in the same order once all beginSetup() calls are done.
for (auto& setuppable : pimpl->toSetup_) {
setuppable->endSetup(pimpl->bootstrap_);
}
// Each queued object is set up once; a later setup() starts from an empty queue.
pimpl->toSetup_.clear();
}
MSCCLPP_API_CPP void Communicator::stopProxying() {
pimpl->proxy.stop();
}
MSCCLPP_API_CPP int Communicator::rank() {
int result;
mscclppCommRank(pimpl->comm, &result);
return result;
}
MSCCLPP_API_CPP int Communicator::size() {
int result;
mscclppCommSize(pimpl->comm, &result);
return result;
}
} // namespace mscclpp
} // namespace mscclpp

165
src/connection.cc Normal file
View File

@@ -0,0 +1,165 @@
#include <algorithm>
#include "connection.hpp"
#include "checks.hpp"
#include "infiniband/verbs.h"
#include "npkit/npkit.h"
#include "registered_memory.hpp"
#include "utils.hpp"
namespace mscclpp {
// Throws std::runtime_error unless `mem` lists `transport` among its transports.
void validateTransport(RegisteredMemory mem, Transport transport)
{
  const bool supported = mem.transports().has(transport);
  if (supported) {
    return;
  }
  throw std::runtime_error("mem does not support transport");
}
// Connection
// Returns the implementation object behind `mem`.
std::shared_ptr<RegisteredMemory::Impl> Connection::getRegisteredMemoryImpl(RegisteredMemory& mem)
{
return mem.pimpl;
}
// ConnectionBase
// Stores the remote rank and tag that identify this connection.
ConnectionBase::ConnectionBase(int remoteRank, int tag) : remoteRank_(remoteRank), tag_(tag) {}
// Accessors for the values given at construction.
int ConnectionBase::remoteRank() { return remoteRank_; }
int ConnectionBase::tag() { return tag_; }
// CudaIpcConnection
// Creates the non-blocking CUDA stream used by write() and flush().
CudaIpcConnection::CudaIpcConnection(int remoteRank, int tag) : ConnectionBase(remoteRank, tag)
{
CUDATHROW(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
}
// Destroys the stream; the CUDA return code is ignored (no exception leaves the destructor).
CudaIpcConnection::~CudaIpcConnection()
{
cudaStreamDestroy(stream);
}
// Both ends of a CUDA IPC connection report Transport::CudaIpc.
Transport CudaIpcConnection::transport()
{
return Transport::CudaIpc;
}
Transport CudaIpcConnection::remoteTransport()
{
return Transport::CudaIpc;
}
// Enqueues an asynchronous device-to-device copy of `size` bytes from
// src + srcOffset to dst + dstOffset on this connection's stream.
// The copy is only guaranteed to have completed after flush().
// Throws (via validateTransport) if either memory does not support the CudaIpc transport.
void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
uint64_t size)
{
validateTransport(dst, remoteTransport());
validateTransport(src, transport());
// Byte-addressed views of both buffers so the offsets can be applied directly.
char* dstPtr = (char*)dst.data();
char* srcPtr = (char*)src.data();
CUDATHROW(cudaMemcpyAsync(dstPtr + dstOffset, srcPtr + srcOffset, size, cudaMemcpyDeviceToDevice, stream));
INFO(MSCCLPP_P2P, "CudaIpcConnection write: from %p to %p, size %lu", srcPtr + srcOffset, dstPtr + dstOffset, size);
// npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)size);
}
// Blocks until all work enqueued on this connection's stream has finished.
void CudaIpcConnection::flush()
{
CUDATHROW(cudaStreamSynchronize(stream));
// npkitCollectExitEvents(conn, NPKIT_EVENT_DMA_SEND_EXIT);
}
// IBConnection
// Creates a queue pair on the IB context that `commImpl` associates with `transport`.
// The remote transport stays Transport::Unknown until endSetup() receives it.
IBConnection::IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl)
: ConnectionBase(remoteRank, tag), transport_(transport), remoteTransport_(Transport::Unknown), numSignaledSends(0)
{
qp = commImpl.getIbContext(transport)->createQp();
}
// Local IB transport given at construction.
Transport IBConnection::transport()
{
return transport_;
}
// Transport reported by the peer; Transport::Unknown before endSetup() has run.
Transport IBConnection::remoteTransport()
{
return remoteTransport_;
}
// Posts one signaled send on the queue pair that transfers `size` bytes from
// local `src` + srcOffset to remote `dst` + dstOffset. Completion is collected by flush().
// Throws std::runtime_error if a memory lacks the required transport, if `dst`
// is local, or if `src` is remote.
void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
uint64_t size)
{
validateTransport(dst, remoteTransport());
validateTransport(src, transport());
// The destination must be described by remote MR info, not a local registration.
auto dstTransportInfo = getRegisteredMemoryImpl(dst)->getTransportInfo(remoteTransport());
if (dstTransportInfo.ibLocal) {
throw std::runtime_error("dst is local, which is not supported");
}
// The source must be a locally registered MR.
auto srcTransportInfo = getRegisteredMemoryImpl(src)->getTransportInfo(transport());
if (!srcTransportInfo.ibLocal) {
throw std::runtime_error("src is remote, which is not supported");
}
auto dstMrInfo = dstTransportInfo.ibMrInfo;
auto srcMr = srcTransportInfo.ibMr;
// NOTE(review): `size` is narrowed to uint32_t here; sizes above UINT32_MAX would be truncated silently.
qp->stageSend(srcMr, dstMrInfo, (uint32_t)size, /*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset,
/*signaled=*/true);
// Every write is signaled, so flush() expects one completion per call.
numSignaledSends++;
qp->postSend();
INFO(MSCCLPP_NET, "IBConnection write: from %p to %p, size %lu", (uint8_t*)srcMr->getBuff() + srcOffset, (uint8_t*)dstMrInfo.addr + dstOffset, size);
// npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)size);
}
// Polls the completion queue until every signaled send posted by write() has completed.
// Throws std::runtime_error if polling fails or any completion has a non-success status.
void IBConnection::flush()
{
// Loop (without sleeping) until the outstanding-send counter drops to zero.
while (numSignaledSends) {
int wcNum = qp->pollCq();
if (wcNum < 0) {
throw std::runtime_error("pollCq failed: error no " + std::to_string(errno));
}
for (int i = 0; i < wcNum; ++i) {
const struct ibv_wc* wc = reinterpret_cast<const struct ibv_wc*>(qp->getWc(i));
if (wc->status != IBV_WC_SUCCESS) {
throw std::runtime_error("pollCq failed: status " + std::to_string(wc->status));
}
// Only RDMA-write completions are counted against the outstanding sends.
if (wc->opcode == IBV_WC_RDMA_WRITE) {
numSignaledSends--;
}
}
}
// npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT);
}
// Sends this side's queue-pair info followed by its transport to the remote rank.
// The message layout is [raw bytes of qp->getInfo()][raw bytes of transport_],
// which is what endSetup() expects to receive.
void IBConnection::beginSetup(std::shared_ptr<BaseBootstrap> bootstrap)
{
std::vector<char> ibQpTransport;
std::copy_n(reinterpret_cast<char*>(&qp->getInfo()), sizeof(qp->getInfo()), std::back_inserter(ibQpTransport));
std::copy_n(reinterpret_cast<char*>(&transport_), sizeof(transport_), std::back_inserter(ibQpTransport));
bootstrap->send(ibQpTransport.data(), ibQpTransport.size(), remoteRank(), tag());
}
// Receives the peer's queue-pair info and transport (sent by its beginSetup())
// and then calls rtr() and rts() on the local queue pair.
// Expected message layout: [IbQpInfo bytes][Transport bytes].
void IBConnection::endSetup(std::shared_ptr<BaseBootstrap> bootstrap)
{
  std::vector<char> ibQpTransport(sizeof(IbQpInfo) + sizeof(Transport));
  bootstrap->recv(ibQpTransport.data(), ibQpTransport.size(), remoteRank(), tag());

  // Unpack the two fields in the order they were sent.
  IbQpInfo qpInfo;
  auto it = ibQpTransport.begin();
  std::copy_n(it, sizeof(qpInfo), reinterpret_cast<char*>(&qpInfo));
  it += sizeof(qpInfo);
  // Last field: the iterator is not advanced again. Doing so (as before, by
  // sizeof(qpInfo)) would move it past the end of the buffer.
  std::copy_n(it, sizeof(remoteTransport_), reinterpret_cast<char*>(&remoteTransport_));

  qp->rtr(qpInfo);
  qp->rts();
}
} // namespace mscclpp

26
src/epoch.cc Normal file
View File

@@ -0,0 +1,26 @@
#include "epoch.hpp"
#include "checks.hpp"
#include "alloc.h"
#include "api.h"
namespace mscclpp {
// Allocates the device-side epoch state, registers the local EpochIds object
// with `communicator`, and queues the exchange of registered memories with the
// peer of `connection` for the next Communicator::setup().
MSCCLPP_API_CPP Epoch::Epoch(Communicator& communicator, std::shared_ptr<Connection> connection)
  : connection_(connection)
{
  MSCCLPPTHROW(mscclppCudaCalloc(&device_.epochIds_, 1));
  MSCCLPPTHROW(mscclppCudaCalloc(&device_.expectedInboundEpochId_, 1));
  // Register the whole EpochIds object. sizeof(device_.epochIds_) would be the
  // size of the pointer, not of the object it points to.
  localEpochIdsRegMem_ =
    communicator.registerMemory(device_.epochIds_, sizeof(EpochIds), connection->transport());
  communicator.sendMemoryOnSetup(localEpochIdsRegMem_, connection->remoteRank(), connection->tag());
  remoteEpochIdsRegMem_ = communicator.recvMemoryOnSetup(connection->remoteRank(), connection->tag());
}
// Frees the two device allocations made in the constructor.
// The return codes are ignored (no exception leaves the destructor).
MSCCLPP_API_CPP Epoch::~Epoch() {
mscclppCudaFree(device_.epochIds_);
mscclppCudaFree(device_.expectedInboundEpochId_);
}
// Writes the local `outbound_` field into the peer's `inboundReplica_` field
// over the connection.
MSCCLPP_API_CPP void Epoch::signal() {
  // Transfer exactly one epoch id field. sizeof(device_.epochIds_) would be the
  // size of a pointer and only matches the field size by coincidence.
  connection_->write(remoteEpochIdsRegMem_.get(), offsetof(EpochIds, inboundReplica_), localEpochIdsRegMem_,
                     offsetof(EpochIds, outbound_), sizeof(EpochIds::outbound_));
}
} // namespace mscclpp

View File

@@ -1,13 +1,15 @@
#include "mscclppfifo.hpp"
#include "alloc.h"
#include "checks.hpp"
#include "mscclppfifo.hpp"
#include "api.h"
#include <cuda_runtime.h>
#include <stdexcept>
#include <emmintrin.h>
#include <stdexcept>
namespace mscclpp {
struct HostProxyFifo::Impl {
struct HostProxyFifo::Impl
{
DeviceProxyFifo deviceFifo;
// allocated on the host. Only accessed by the host. This is a copy of the
@@ -23,7 +25,8 @@ struct HostProxyFifo::Impl {
cudaStream_t stream;
};
HostProxyFifo::HostProxyFifo() {
MSCCLPP_API_CPP HostProxyFifo::HostProxyFifo()
{
pimpl = std::make_unique<Impl>();
MSCCLPPTHROW(mscclppCudaCalloc(&pimpl->deviceFifo.head, 1));
MSCCLPPTHROW(mscclppCudaHostCalloc(&pimpl->deviceFifo.triggers, MSCCLPP_PROXY_FIFO_SIZE));
@@ -32,35 +35,40 @@ HostProxyFifo::HostProxyFifo() {
pimpl->hostTail = 0;
}
HostProxyFifo::~HostProxyFifo() {
MSCCLPPTHROW(mscclppCudaFree(pimpl->deviceFifo.head));
MSCCLPPTHROW(mscclppCudaHostFree(pimpl->deviceFifo.triggers));
MSCCLPPTHROW(mscclppCudaFree(pimpl->deviceFifo.tailReplica));
CUDATHROW(cudaStreamDestroy(pimpl->stream));
MSCCLPP_API_CPP HostProxyFifo::~HostProxyFifo()
{
mscclppCudaFree(pimpl->deviceFifo.head);
mscclppCudaHostFree(pimpl->deviceFifo.triggers);
mscclppCudaFree(pimpl->deviceFifo.tailReplica);
cudaStreamDestroy(pimpl->stream);
}
void HostProxyFifo::poll(ProxyTrigger *trigger) {
MSCCLPP_API_CPP void HostProxyFifo::poll(ProxyTrigger* trigger)
{
__m128i xmm0 = _mm_load_si128((__m128i*)&pimpl->deviceFifo.triggers[pimpl->hostTail % MSCCLPP_PROXY_FIFO_SIZE]);
_mm_store_si128((__m128i*)trigger, xmm0);
}
void HostProxyFifo::pop() {
MSCCLPP_API_CPP void HostProxyFifo::pop()
{
*(volatile uint64_t*)(&pimpl->deviceFifo.triggers[pimpl->hostTail % MSCCLPP_PROXY_FIFO_SIZE]) = 0;
(pimpl->hostTail)++;
}
void HostProxyFifo::flushTail(bool sync) {
MSCCLPP_API_CPP void HostProxyFifo::flushTail(bool sync)
{
// Flush the tail to device memory. This is either triggered every MSCCLPP_PROXY_FIFO_FLUSH_COUNTER to make sure
// that the fifo can make progress even if there is no request mscclppSync. However, mscclppSync type is for flush
// request.
CUDATHROW(
cudaMemcpyAsync(pimpl->deviceFifo.tailReplica, &pimpl->hostTail, sizeof(uint64_t), cudaMemcpyHostToDevice, pimpl->stream));
CUDATHROW(cudaMemcpyAsync(pimpl->deviceFifo.tailReplica, &pimpl->hostTail, sizeof(uint64_t), cudaMemcpyHostToDevice,
pimpl->stream));
if (sync) {
CUDATHROW(cudaStreamSynchronize(pimpl->stream));
}
}
DeviceProxyFifo HostProxyFifo::toDevice() {
MSCCLPP_API_CPP DeviceProxyFifo HostProxyFifo::deviceFifo()
{
return pimpl->deviceFifo;
}

View File

@@ -1,79 +0,0 @@
#include "host_connection.hpp"
#include "communicator.hpp"
#include "comm.h"
#include "mscclpp.h"
#include "mscclppfifo.h"
#include "api.h"
namespace mscclpp {
// Wraps a C-style mscclppConn; caches its hostConn pointer for the forwarding methods below.
HostConnection::Impl::Impl(Communicator* comm, mscclppConn* conn) : comm(comm), conn(conn) {
this->hostConn = conn->hostConn;
}
HostConnection::Impl::~Impl() {
// TODO: figure out memory ownership. Does this deallocate the mscclppHostConn? Likely not.
}
MSCCLPP_API_CPP HostConnection::~HostConnection() = default;
// Takes ownership of the implementation object.
MSCCLPP_API_CPP HostConnection::HostConnection(std::unique_ptr<Impl> p) : pimpl(std::move(p)) {}
// Returns the connId of the underlying connection.
MSCCLPP_API_CPP int HostConnection::getId() {
return pimpl->conn->connId;
}
// Registers [data, data + size) for this connection and returns the resulting handle.
// NOTE(review): the result of mscclppRegisterBufferForConnection is ignored, so
// `result` may be returned uninitialized if registration fails.
MSCCLPP_API_CPP BufferHandle HostConnection::registerBuffer(void* data, uint64_t size) {
BufferHandle result;
// The handle is written through a reinterpret_cast, so both types must have the same size.
static_assert(sizeof(BufferHandle) == sizeof(mscclppBufferHandle_t));
mscclppRegisterBufferForConnection(pimpl->comm->pimpl->comm, pimpl->conn->connId, data, size, reinterpret_cast<mscclppBufferHandle_t*>(&result));
return result;
}
// Number of local registrations minus one.
// presumably entry 0 is reserved, matching the `index + 1` mapping below - confirm
MSCCLPP_API_CPP int HostConnection::numLocalBuffers() {
return pimpl->conn->bufferRegistrations.size() - 1;
}
// Maps a zero-based index to a buffer handle (handle = index + 1).
MSCCLPP_API_CPP BufferHandle HostConnection::getLocalBuffer(int index) {
return index + 1;
}
// Number of remote registrations minus one (same convention as numLocalBuffers()).
MSCCLPP_API_CPP int HostConnection::numRemoteBuffers() {
return pimpl->conn->remoteBufferRegistrations.size() - 1;
}
// Maps a zero-based index to a remote buffer handle (handle = index + 1).
MSCCLPP_API_CPP BufferHandle HostConnection::getRemoteBuffer(int index) {
return index + 1;
}
// Collects the epoch pointers of the underlying device connection into a ConnectionEpoch.
MSCCLPP_API_CPP ConnectionEpoch HostConnection::getEpoch() {
ConnectionEpoch epoch;
// The signal epoch pointers are reinterpreted between the C and C++ types, so the sizes must match.
static_assert(sizeof(SignalEpochId) == sizeof(mscclppDevConnSignalEpochId));
epoch.localSignalEpochId = reinterpret_cast<SignalEpochId*>(pimpl->conn->devConn->localSignalEpochId);
epoch.remoteSignalEpochId = reinterpret_cast<SignalEpochId*>(pimpl->conn->devConn->remoteSignalEpochId);
epoch.waitEpochId = pimpl->conn->devConn->waitEpochId;
return epoch;
}
// Returns the device-side view of the communicator's proxy FIFO.
MSCCLPP_API_CPP DeviceProxyFifo HostConnection::getDeviceFifo() {
return pimpl->comm->pimpl->proxy.fifo().toDevice();
}
// The remaining methods forward directly to the underlying host connection.
MSCCLPP_API_CPP void HostConnection::put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) {
pimpl->hostConn->put(dst, dstOffset, src, srcOffset, size);
}
MSCCLPP_API_CPP void HostConnection::signal() {
pimpl->hostConn->signal();
}
MSCCLPP_API_CPP void HostConnection::flush() {
pimpl->hostConn->flush();
}
MSCCLPP_API_CPP void HostConnection::wait() {
pimpl->hostConn->wait();
}
} // namespace mscclpp

665
src/ib.cc
View File

@@ -2,323 +2,182 @@
#include <cstdlib>
#include <cstring>
#include <malloc.h>
#include <sstream>
#include <unistd.h>
#include <vector>
#include "alloc.h"
#include "checks.hpp"
#include "comm.h"
#include "debug.h"
#include "ib.h"
#include "ib.hpp"
#include "mscclpp.hpp"
#include "api.h"
#include <infiniband/verbs.h>
#include <string>
// Reads the integer stored in "<ibDevPath>/device/numa_node".
//
// ibDevPath: path of the IB device directory; may be NULL.
// Returns the parsed integer, or -1 when ibDevPath is NULL, the file cannot be
// opened, or its content does not parse as an integer.
static int getIbDevNumaNode(const char* ibDevPath)
{
  if (ibDevPath == NULL) {
    WARN("ibDevPath is NULL");
    return -1;
  }

  // std::string owns the path buffer, so no manual allocation, bounded copy
  // or cleanup path is needed.
  const std::string filePath = std::string(ibDevPath) + "/device/numa_node";

  FILE* fp = fopen(filePath.c_str(), "r");
  if (fp == NULL) {
    WARN("fopen failed (errno %d, path %s)", errno, filePath.c_str());
    return -1;
  }

  int node = -1;
  if (fscanf(fp, "%d", &node) != 1) {
    WARN("fscanf failed (errno %d, path %s)", errno, filePath.c_str());
    node = -1;
  }
  fclose(fp);
  return node;
}
namespace mscclpp {
// Opens the IB device named `ibDevName`, records its usable ports and allocates
// a protection domain.
//
// ctx:       receives the new context on success; set to NULL on failure
//            (left untouched if the context allocation itself fails).
// ibDevName: device name compared against the names in the verbs device list.
//
// Returns mscclppSuccess on success. On failure every resource acquired so far
// (opened device, ports array, context object) is released before returning an
// error code.
mscclppResult_t mscclppIbContextCreate(struct mscclppIbContext** ctx, const char* ibDevName)
{
  struct mscclppIbContext* _ctx;
  MSCCLPPCHECK(mscclppCalloc(&_ctx, 1));

  // Declared before the first `goto fail` so that no jump crosses an initialization.
  mscclppResult_t res = mscclppInternalError;
  std::vector<int> ports;
  int num = 0;
  struct ibv_device** devices = ibv_get_device_list(&num);
  if (devices == NULL) {
    // Without this check the loop below could dereference a NULL list.
    WARN("ibv_get_device_list failed (errno %d)", errno);
    goto fail;
  }
  for (int i = 0; i < num; ++i) {
    if (strncmp(devices[i]->name, ibDevName, IBV_SYSFS_NAME_MAX) == 0) {
      _ctx->ctx = ibv_open_device(devices[i]);
      break;
    }
  }
  ibv_free_device_list(devices);
  if (_ctx->ctx == nullptr) {
    WARN("ibv_open_device failed (errno %d, device name %s)", errno, ibDevName);
    goto fail;
  }

  // Check available ports
  struct ibv_device_attr devAttr;
  if (ibv_query_device(_ctx->ctx, &devAttr) != 0) {
    WARN("ibv_query_device failed (errno %d, device name %s)", errno, ibDevName);
    goto fail;
  }

  for (uint8_t i = 1; i <= devAttr.phys_port_cnt; ++i) {
    struct ibv_port_attr portAttr;
    if (ibv_query_port(_ctx->ctx, i, &portAttr) != 0) {
      WARN("ibv_query_port failed (errno %d, port %d)", errno, i);
      goto fail;
    }
    // Keep only active ports with an InfiniBand or Ethernet link layer.
    if (portAttr.state != IBV_PORT_ACTIVE) {
      continue;
    }
    if (portAttr.link_layer != IBV_LINK_LAYER_INFINIBAND && portAttr.link_layer != IBV_LINK_LAYER_ETHERNET) {
      continue;
    }
    ports.push_back((int)i);
  }
  if (ports.size() == 0) {
    WARN("no active IB port found");
    goto fail;
  }

  {
    // Go through `fail` instead of returning directly so that the context
    // object and the opened device are released; the allocator's error code
    // is still the one returned.
    mscclppResult_t allocRes = mscclppCalloc(&_ctx->ports, ports.size());
    if (allocRes != mscclppSuccess) {
      WARN("mscclppCalloc failed");
      res = allocRes;
      goto fail;
    }
  }
  _ctx->nPorts = (int)ports.size();
  for (int i = 0; i < _ctx->nPorts; ++i) {
    _ctx->ports[i] = ports[i];
  }

  _ctx->pd = ibv_alloc_pd(_ctx->ctx);
  if (_ctx->pd == NULL) {
    WARN("ibv_alloc_pd failed (errno %d)", errno);
    goto fail;
  }

  *ctx = _ctx;
  return mscclppSuccess;

fail:
  *ctx = NULL;
  // Close the device if it was opened; previously it leaked on every failure path.
  if (_ctx->ctx != NULL) {
    ibv_close_device(_ctx->ctx);
  }
  if (_ctx->ports != NULL) {
    free(_ctx->ports);
  }
  free(_ctx);
  return res;
}
// Releases everything owned by `ctx`: memory regions, queue pairs with their
// completion queues and work buffers, the protection domain, the device handle,
// and finally the bookkeeping arrays and the context object itself.
// Always returns mscclppSuccess; return codes of the verbs calls are ignored.
// `ctx` is dereferenced unconditionally and must not be NULL.
mscclppResult_t mscclppIbContextDestroy(struct mscclppIbContext* ctx)
{
// Deregister every memory region that was actually registered.
for (int i = 0; i < ctx->nMrs; ++i) {
if (ctx->mrs[i].mr) {
ibv_dereg_mr(ctx->mrs[i].mr);
}
}
// Destroy each queue pair, then its completion queue and per-QP buffers.
for (int i = 0; i < ctx->nQps; ++i) {
if (ctx->qps[i].qp) {
ibv_destroy_qp(ctx->qps[i].qp);
}
// NOTE(review): unlike the qp above, the cq is destroyed without a NULL check.
ibv_destroy_cq(ctx->qps[i].cq);
free(ctx->qps[i].wcs);
free(ctx->qps[i].sges);
free(ctx->qps[i].wrs);
}
// The protection domain is released after the MRs and QPs above, then the device is closed.
if (ctx->pd != NULL) {
ibv_dealloc_pd(ctx->pd);
}
if (ctx->ctx != NULL) {
ibv_close_device(ctx->ctx);
}
free(ctx->mrs);
free(ctx->qps);
free(ctx->ports);
free(ctx);
return mscclppSuccess;
}
// Creates a reliable-connection queue pair on `ctx`, moves it to the INIT state and
// records it in the next free slot of ctx->qps.
//
// ctx   IB context that owns the device, protection domain and QP table.
// ibQp  receives a pointer to the newly filled slot on success.
// port  IB port to use; a negative value selects the first usable port of the context.
//
// Returns mscclppSuccess on success. On any failure every resource created by this call
// is released again and the QP table is left unchanged.
mscclppResult_t mscclppIbContextCreateQp(struct mscclppIbContext* ctx, struct mscclppIbQp** ibQp, int port /*=-1*/)
{
  if (port < 0) {
    port = ctx->ports[0];
  } else {
    bool found = false;
    for (int i = 0; i < ctx->nPorts; ++i) {
      if (ctx->ports[i] == port) {
        found = true;
        break;
      }
    }
    if (!found) {
      WARN("invalid IB port: %d", port);
      return mscclppInternalError;
    }
  }
  if (ctx->ctx == NULL) {
    WARN("IB context is NULL");
    return mscclppInternalError;
  }
  // Make sure a free slot exists before creating any verbs object, so a full table
  // neither leaks resources nor leaves nQps pointing past the array.
  if (ctx->qps == NULL) {
    MSCCLPPCHECK(mscclppCalloc(&ctx->qps, MAXCONNECTIONS));
    ctx->maxQps = MAXCONNECTIONS;
  }
  if (ctx->nQps >= ctx->maxQps) {
    WARN("too many QPs");
    return mscclppInternalError;
  }

  struct ibv_cq* cq = NULL;
  struct ibv_qp* qp = NULL;
  struct ibv_send_wr* wrs = NULL;
  struct ibv_sge* sges = NULL;
  struct ibv_wc* wcs = NULL;
  // Undoes everything this call has created so far (QP before its CQ).
  auto releaseAll = [&]() {
    free(wcs);
    free(sges);
    free(wrs);
    if (qp != NULL) {
      ibv_destroy_qp(qp);
    }
    if (cq != NULL) {
      ibv_destroy_cq(cq);
    }
  };

  cq = ibv_create_cq(ctx->ctx, MSCCLPP_IB_CQ_SIZE, NULL, NULL, 0);
  if (cq == NULL) {
    WARN("ibv_create_cq failed (errno %d)", errno);
    return mscclppInternalError;
  }

  struct ibv_qp_init_attr qp_init_attr;
  std::memset(&qp_init_attr, 0, sizeof(struct ibv_qp_init_attr));
  qp_init_attr.sq_sig_all = 0;
  qp_init_attr.send_cq = cq;
  qp_init_attr.recv_cq = cq;
  qp_init_attr.qp_type = IBV_QPT_RC;
  qp_init_attr.cap.max_send_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE;
  qp_init_attr.cap.max_recv_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE;
  qp_init_attr.cap.max_send_sge = 1;
  qp_init_attr.cap.max_recv_sge = 1;
  qp_init_attr.cap.max_inline_data = 0;
  qp = ibv_create_qp(ctx->pd, &qp_init_attr);
  if (qp == nullptr) {
    WARN("ibv_create_qp failed (errno %d)", errno);
    releaseAll();
    return mscclppInternalError;
  }

  struct ibv_port_attr port_attr;
  if (ibv_query_port(ctx->ctx, port, &port_attr) != 0) {
    WARN("ibv_query_port failed (errno %d, port %d)", errno, port);
    releaseAll();
    return mscclppInternalError;
  }

  // The subnet prefix is only queried for non-InfiniBand link layers; it stays 0 otherwise.
  uint64_t spn = 0;
  if (port_attr.link_layer != IBV_LINK_LAYER_INFINIBAND) {
    union ibv_gid gid;
    if (ibv_query_gid(ctx->ctx, port, 0, &gid) != 0) {
      WARN("ibv_query_gid failed (errno %d)", errno);
      releaseAll();
      return mscclppInternalError;
    }
    spn = gid.global.subnet_prefix;
  }

  // RESET -> INIT transition.
  struct ibv_qp_attr qp_attr;
  std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr));
  qp_attr.qp_state = IBV_QPS_INIT;
  qp_attr.pkey_index = 0;
  qp_attr.port_num = port;
  qp_attr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ;
  if (ibv_modify_qp(qp, &qp_attr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) {
    WARN("ibv_modify_qp failed (errno %d)", errno);
    releaseAll();
    return mscclppInternalError;
  }

  mscclppResult_t res = mscclppCalloc(&wrs, MSCCLPP_IB_MAX_SENDS);
  if (res == mscclppSuccess) {
    res = mscclppCalloc(&sges, MSCCLPP_IB_MAX_SENDS);
  }
  if (res == mscclppSuccess) {
    res = mscclppCalloc(&wcs, MSCCLPP_IB_CQ_POLL_NUM);
  }
  if (res != mscclppSuccess) {
    releaseAll();
    return res;
  }

  // Everything succeeded: publish the QP in the table and only now bump the counter.
  struct mscclppIbQp* _ibQp = &ctx->qps[ctx->nQps];
  _ibQp->qp = qp;
  _ibQp->info.lid = port_attr.lid;
  _ibQp->info.port = port;
  _ibQp->info.linkLayer = port_attr.link_layer;
  _ibQp->info.qpn = qp->qp_num;
  _ibQp->info.mtu = port_attr.active_mtu;
  _ibQp->info.spn = spn;
  _ibQp->wrs = wrs;
  _ibQp->sges = sges;
  _ibQp->wcs = wcs;
  _ibQp->cq = cq;
  _ibQp->wrn = 0;
  ctx->nQps++;
  *ibQp = _ibQp;
  return mscclppSuccess;
}
mscclppResult_t mscclppIbContextRegisterMr(struct mscclppIbContext* ctx, void* buff, size_t size,
struct mscclppIbMr** ibMr)
IbMr::IbMr(void* pd, void* buff, std::size_t size) : buff(buff)
{
if (size == 0) {
WARN("invalid size: %zu", size);
return mscclppInvalidArgument;
throw std::runtime_error("invalid size: " + std::to_string(size));
}
static __thread uintptr_t pageSize = 0;
if (pageSize == 0) {
pageSize = sysconf(_SC_PAGESIZE);
}
uintptr_t addr = reinterpret_cast<uintptr_t>(buff) & -pageSize;
size_t pages = (size + (reinterpret_cast<uintptr_t>(buff) - addr) + pageSize - 1) / pageSize;
struct ibv_mr* mr =
ibv_reg_mr(ctx->pd, reinterpret_cast<void*>(addr), pages * pageSize,
std::size_t pages = (size + (reinterpret_cast<uintptr_t>(buff) - addr) + pageSize - 1) / pageSize;
struct ibv_pd* _pd = reinterpret_cast<struct ibv_pd*>(pd);
struct ibv_mr* _mr =
ibv_reg_mr(_pd, reinterpret_cast<void*>(addr), pages * pageSize,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_RELAXED_ORDERING);
if (mr == nullptr) {
WARN("ibv_reg_mr failed (errno %d)", errno);
return mscclppInternalError;
if (_mr == nullptr) {
std::stringstream err;
err << "ibv_reg_mr failed (errno " << errno << ")";
throw std::runtime_error(err.str());
}
ctx->nMrs++;
if (ctx->mrs == NULL) {
MSCCLPPCHECK(mscclppCalloc(&ctx->mrs, MAXCONNECTIONS));
ctx->maxMrs = MAXCONNECTIONS;
}
if (ctx->maxMrs < ctx->nMrs) {
WARN("too many MRs");
return mscclppInternalError;
}
struct mscclppIbMr* _ibMr = &ctx->mrs[ctx->nMrs - 1];
_ibMr->mr = mr;
_ibMr->buff = buff;
_ibMr->info.addr = (uint64_t)buff;
_ibMr->info.rkey = mr->rkey;
*ibMr = _ibMr;
return mscclppSuccess;
this->mr = _mr;
this->size = pages * pageSize;
}
//////////////////////////////////////////////////////////////////////////////
// Deregisters the verbs memory region held by this object.
IbMr::~IbMr()
{
  struct ibv_mr* region = reinterpret_cast<struct ibv_mr*>(this->mr);
  ibv_dereg_mr(region);
}
int mscclppIbQp::rtr(const mscclppIbQpInfo* info)
// Returns the (address, rkey) pair describing this region: the address is the buffer
// pointer given at registration and the rkey comes from the underlying ibv_mr.
IbMrInfo IbMr::getInfo() const
{
  const struct ibv_mr* region = reinterpret_cast<const struct ibv_mr*>(this->mr);
  IbMrInfo result;
  result.addr = reinterpret_cast<uint64_t>(this->buff);
  result.rkey = region->rkey;
  return result;
}
// Returns the buffer pointer that was passed to the constructor.
const void* IbMr::getBuff() const
{
  const void* registered = this->buff;
  return registered;
}
uint32_t IbMr::getLkey() const
{
return reinterpret_cast<struct ibv_mr*>(this->mr)->lkey;
}
// Creates a reliable-connection queue pair with its own completion queue on the given
// device context / protection domain, moves it to the INIT state on `port`, fills in
// the QP info to be shared with the remote peer, and allocates the work-request,
// scatter/gather and work-completion arrays.
//
// ctx   opaque pointer to a struct ibv_context.
// pd    opaque pointer to a struct ibv_pd.
// port  IB port number the QP is bound to.
//
// Throws std::runtime_error if any verbs call or allocation fails. Because the
// destructor does not run for a partially constructed object, everything created so
// far is released here before the exception propagates.
IbQp::IbQp(void* ctx, void* pd, int port)
{
  struct ibv_context* _ctx = reinterpret_cast<struct ibv_context*>(ctx);
  struct ibv_pd* _pd = reinterpret_cast<struct ibv_pd*>(pd);

  // Start from a well-defined state so the cleanup path below is always safe and so
  // that staging starts from an empty work-request chain.
  this->qp = nullptr;
  this->cq = nullptr;
  this->wrs = nullptr;
  this->sges = nullptr;
  this->wcs = nullptr;
  this->wrn = 0;
  this->info.spn = 0; // only overwritten for non-InfiniBand link layers

  // Builds the "<call> failed (errno N)" message and throws; errno is read before any cleanup.
  auto fail = [](const char* call) {
    std::stringstream err;
    err << call << " failed (errno " << errno << ")";
    throw std::runtime_error(err.str());
  };

  struct ibv_qp* _qp = nullptr;
  try {
    this->cq = ibv_create_cq(_ctx, MSCCLPP_IB_CQ_SIZE, nullptr, nullptr, 0);
    if (this->cq == nullptr) {
      fail("ibv_create_cq");
    }

    struct ibv_qp_init_attr qpInitAttr;
    std::memset(&qpInitAttr, 0, sizeof(qpInitAttr));
    qpInitAttr.sq_sig_all = 0;
    qpInitAttr.send_cq = reinterpret_cast<struct ibv_cq*>(this->cq);
    qpInitAttr.recv_cq = reinterpret_cast<struct ibv_cq*>(this->cq);
    qpInitAttr.qp_type = IBV_QPT_RC;
    qpInitAttr.cap.max_send_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE;
    qpInitAttr.cap.max_recv_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE;
    qpInitAttr.cap.max_send_sge = 1;
    qpInitAttr.cap.max_recv_sge = 1;
    qpInitAttr.cap.max_inline_data = 0;

    _qp = ibv_create_qp(_pd, &qpInitAttr);
    if (_qp == nullptr) {
      fail("ibv_create_qp");
    }

    struct ibv_port_attr portAttr;
    if (ibv_query_port(_ctx, port, &portAttr) != 0) {
      fail("ibv_query_port");
    }
    this->info.lid = portAttr.lid;
    this->info.port = port;
    this->info.linkLayer = portAttr.link_layer;
    this->info.qpn = _qp->qp_num;
    this->info.mtu = portAttr.active_mtu;
    if (portAttr.link_layer != IBV_LINK_LAYER_INFINIBAND) {
      union ibv_gid gid;
      if (ibv_query_gid(_ctx, port, 0, &gid) != 0) {
        fail("ibv_query_gid");
      }
      this->info.spn = gid.global.subnet_prefix;
    }

    // RESET -> INIT transition.
    struct ibv_qp_attr qpAttr;
    std::memset(&qpAttr, 0, sizeof(qpAttr));
    qpAttr.qp_state = IBV_QPS_INIT;
    qpAttr.pkey_index = 0;
    qpAttr.port_num = port;
    qpAttr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ;
    if (ibv_modify_qp(_qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) {
      fail("ibv_modify_qp");
    }

    MSCCLPPTHROW(mscclppCalloc(reinterpret_cast<struct ibv_send_wr**>(&this->wrs), MSCCLPP_IB_MAX_SENDS));
    MSCCLPPTHROW(mscclppCalloc(reinterpret_cast<struct ibv_sge**>(&this->sges), MSCCLPP_IB_MAX_SENDS));
    MSCCLPPTHROW(mscclppCalloc(reinterpret_cast<struct ibv_wc**>(&this->wcs), MSCCLPP_IB_CQ_POLL_NUM));
  } catch (...) {
    std::free(this->wcs);
    std::free(this->sges);
    std::free(this->wrs);
    this->wcs = nullptr;
    this->sges = nullptr;
    this->wrs = nullptr;
    // The QP must be destroyed before the CQ it is attached to.
    if (_qp != nullptr) {
      ibv_destroy_qp(_qp);
    }
    if (this->cq != nullptr) {
      ibv_destroy_cq(reinterpret_cast<struct ibv_cq*>(this->cq));
      this->cq = nullptr;
    }
    throw;
  }
  this->qp = _qp;
}
// Destroys the queue pair, then its completion queue, then frees the work-request,
// scatter/gather and work-completion arrays.
IbQp::~IbQp()
{
  struct ibv_qp* queuePair = reinterpret_cast<struct ibv_qp*>(this->qp);
  struct ibv_cq* completionQueue = reinterpret_cast<struct ibv_cq*>(this->cq);
  ibv_destroy_qp(queuePair);
  ibv_destroy_cq(completionQueue);
  std::free(this->wrs);
  std::free(this->sges);
  std::free(this->wcs);
}
void IbQp::rtr(const IbQpInfo& info)
{
struct ibv_qp_attr qp_attr;
std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr));
qp_attr.qp_state = IBV_QPS_RTR;
qp_attr.path_mtu = info->mtu;
qp_attr.dest_qp_num = info->qpn;
qp_attr.path_mtu = static_cast<ibv_mtu>(info.mtu);
qp_attr.dest_qp_num = info.qpn;
qp_attr.rq_psn = 0;
qp_attr.max_dest_rd_atomic = 1;
qp_attr.min_rnr_timer = 0x12;
if (info->linkLayer == IBV_LINK_LAYER_ETHERNET) {
if (info.linkLayer == IBV_LINK_LAYER_ETHERNET) {
qp_attr.ah_attr.is_global = 1;
qp_attr.ah_attr.grh.dgid.global.subnet_prefix = info->spn;
qp_attr.ah_attr.grh.dgid.global.interface_id = info->lid;
qp_attr.ah_attr.grh.dgid.global.subnet_prefix = info.spn;
qp_attr.ah_attr.grh.dgid.global.interface_id = info.lid;
qp_attr.ah_attr.grh.flow_label = 0;
qp_attr.ah_attr.grh.sgid_index = 0;
qp_attr.ah_attr.grh.hop_limit = 255;
qp_attr.ah_attr.grh.traffic_class = 0;
} else {
qp_attr.ah_attr.is_global = 0;
qp_attr.ah_attr.dlid = info->lid;
qp_attr.ah_attr.dlid = info.lid;
}
qp_attr.ah_attr.sl = 0;
qp_attr.ah_attr.src_path_bits = 0;
qp_attr.ah_attr.port_num = info->port;
return ibv_modify_qp(this->qp, &qp_attr,
IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER);
qp_attr.ah_attr.port_num = info.port;
int ret = ibv_modify_qp(reinterpret_cast<struct ibv_qp*>(this->qp), &qp_attr,
IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER);
if (ret != 0) {
std::stringstream err;
err << "ibv_modify_qp failed (errno " << errno << ")";
throw std::runtime_error(err.str());
}
}
int mscclppIbQp::rts()
void IbQp::rts()
{
struct ibv_qp_attr qp_attr;
std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr));
@@ -328,75 +187,267 @@ int mscclppIbQp::rts()
qp_attr.rnr_retry = 7;
qp_attr.sq_psn = 0;
qp_attr.max_rd_atomic = 1;
return ibv_modify_qp(this->qp, &qp_attr,
IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
IBV_QP_MAX_QP_RD_ATOMIC);
int ret = ibv_modify_qp(reinterpret_cast<struct ibv_qp*>(this->qp), &qp_attr,
IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
IBV_QP_MAX_QP_RD_ATOMIC);
if (ret != 0) {
std::stringstream err;
err << "ibv_modify_qp failed (errno " << errno << ")";
throw std::runtime_error(err.str());
}
}
int mscclppIbQp::stageSend(struct mscclppIbMr* ibMr, const mscclppIbMrInfo* info, uint32_t size, uint64_t wrId,
uint64_t srcOffset, uint64_t dstOffset, bool signaled)
int IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
uint64_t dstOffset, bool signaled)
{
if (this->wrn >= MSCCLPP_IB_MAX_SENDS) {
return -1;
}
int wrn = this->wrn;
struct ibv_send_wr* wr_ = &this->wrs[wrn];
struct ibv_sge* sge_ = &this->sges[wrn];
// std::memset(wr_, 0, sizeof(struct ibv_send_wr));
// std::memset(sge_, 0, sizeof(struct ibv_sge));
struct ibv_send_wr* wrs_ = reinterpret_cast<struct ibv_send_wr*>(this->wrs);
struct ibv_sge* sges_ = reinterpret_cast<struct ibv_sge*>(this->sges);
struct ibv_send_wr* wr_ = &wrs_[wrn];
struct ibv_sge* sge_ = &sges_[wrn];
wr_->wr_id = wrId;
wr_->sg_list = sge_;
wr_->num_sge = 1;
wr_->opcode = IBV_WR_RDMA_WRITE;
wr_->send_flags = signaled ? IBV_SEND_SIGNALED : 0;
wr_->wr.rdma.remote_addr = (uint64_t)(info->addr) + dstOffset;
wr_->wr.rdma.rkey = info->rkey;
wr_->wr.rdma.remote_addr = (uint64_t)(info.addr) + dstOffset;
wr_->wr.rdma.rkey = info.rkey;
wr_->next = nullptr;
sge_->addr = (uint64_t)(ibMr->buff) + srcOffset;
sge_->addr = (uint64_t)(mr->getBuff()) + srcOffset;
sge_->length = size;
sge_->lkey = ibMr->mr->lkey;
sge_->lkey = mr->getLkey();
if (wrn > 0) {
this->wrs[wrn - 1].next = wr_;
wrs_[wrn - 1].next = wr_;
}
this->wrn++;
return this->wrn;
}
int mscclppIbQp::stageSendWithImm(struct mscclppIbMr* ibMr, const mscclppIbMrInfo* info, uint32_t size, uint64_t wrId,
uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData)
int IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
uint64_t dstOffset, bool signaled, unsigned int immData)
{
int wrn = this->stageSend(ibMr, info, size, wrId, srcOffset, dstOffset, signaled);
this->wrs[wrn - 1].opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
this->wrs[wrn - 1].imm_data = immData;
int wrn = this->stageSend(mr, info, size, wrId, srcOffset, dstOffset, signaled);
struct ibv_send_wr* wrs_ = reinterpret_cast<struct ibv_send_wr*>(this->wrs);
wrs_[wrn - 1].opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
wrs_[wrn - 1].imm_data = immData;
return wrn;
}
int mscclppIbQp::postSend()
void IbQp::postSend()
{
if (this->wrn == 0) {
return 0;
return;
}
struct ibv_send_wr* bad_wr;
int ret = ibv_post_send(this->qp, this->wrs, &bad_wr);
int ret = ibv_post_send(reinterpret_cast<struct ibv_qp*>(this->qp), reinterpret_cast<struct ibv_send_wr*>(this->wrs),
&bad_wr);
if (ret != 0) {
return ret;
std::stringstream err;
err << "ibv_post_send failed (errno " << errno << ")";
throw std::runtime_error(err.str());
}
this->wrn = 0;
return 0;
}
int mscclppIbQp::postRecv(uint64_t wrId)
void IbQp::postRecv(uint64_t wrId)
{
struct ibv_recv_wr wr, *bad_wr;
wr.wr_id = wrId;
wr.sg_list = nullptr;
wr.num_sge = 0;
wr.next = nullptr;
return ibv_post_recv(this->qp, &wr, &bad_wr);
int ret = ibv_post_recv(reinterpret_cast<struct ibv_qp*>(this->qp), &wr, &bad_wr);
if (ret != 0) {
std::stringstream err;
err << "ibv_post_recv failed (errno " << errno << ")";
throw std::runtime_error(err.str());
}
}
int mscclppIbQp::pollCq()
int IbQp::pollCq()
{
return ibv_poll_cq(this->cq, MSCCLPP_IB_CQ_POLL_NUM, this->wcs);
return ibv_poll_cq(reinterpret_cast<struct ibv_cq*>(this->cq), MSCCLPP_IB_CQ_POLL_NUM,
reinterpret_cast<struct ibv_wc*>(this->wcs));
}
// Returns a mutable reference to the QP info filled in by the constructor.
IbQpInfo& IbQp::getInfo()
{
  IbQpInfo& localInfo = this->info;
  return localInfo;
}
const void* IbQp::getWc(int idx) const
{
return &reinterpret_cast<struct ibv_wc*>(this->wcs)[idx];
}
// Opens the IB device whose name equals `devName` and allocates a protection domain
// on it.
//
// Throws std::runtime_error if the device list cannot be obtained, if no device with
// that name can be opened, or if the protection domain cannot be allocated. The
// destructor does not run when the constructor throws, so an already opened device is
// closed here before the exception propagates.
IbCtx::IbCtx(const std::string& devName) : devName(devName)
{
  // Define both handles before they are tested below.
  this->ctx = nullptr;
  this->pd = nullptr;

  int num = 0;
  struct ibv_device** devices = ibv_get_device_list(&num);
  if (devices == nullptr) {
    std::stringstream err;
    err << "ibv_get_device_list failed (errno " << errno << ")";
    throw std::runtime_error(err.str());
  }
  for (int i = 0; i < num; ++i) {
    if (std::string(devices[i]->name) == devName) {
      this->ctx = ibv_open_device(devices[i]);
      break;
    }
  }
  // Capture errno before ibv_free_device_list has a chance to overwrite it.
  const int openErrno = errno;
  ibv_free_device_list(devices);
  if (this->ctx == nullptr) {
    std::stringstream err;
    err << "ibv_open_device failed (errno " << openErrno << ", device name " << devName << ")";
    throw std::runtime_error(err.str());
  }
  this->pd = ibv_alloc_pd(reinterpret_cast<struct ibv_context*>(this->ctx));
  if (this->pd == nullptr) {
    std::stringstream err;
    err << "ibv_alloc_pd failed (errno " << errno << ")";
    ibv_close_device(reinterpret_cast<struct ibv_context*>(this->ctx));
    this->ctx = nullptr;
    throw std::runtime_error(err.str());
  }
}
// Drops all memory regions and queue pairs first, then releases the protection domain
// and closes the device, in that order.
IbCtx::~IbCtx()
{
  this->mrs.clear();
  this->qps.clear();

  struct ibv_pd* domain = reinterpret_cast<struct ibv_pd*>(this->pd);
  if (domain != nullptr) {
    ibv_dealloc_pd(domain);
  }

  struct ibv_context* device = reinterpret_cast<struct ibv_context*>(this->ctx);
  if (device != nullptr) {
    ibv_close_device(device);
  }
}
// Reports whether `port` is active and uses an Ethernet or InfiniBand link layer.
// Throws std::runtime_error if the port cannot be queried.
bool IbCtx::isPortUsable(int port) const
{
  struct ibv_port_attr portAttr;
  if (ibv_query_port(reinterpret_cast<struct ibv_context*>(this->ctx), port, &portAttr) != 0) {
    std::stringstream err;
    // The original message contained a literal "<<" after "port".
    err << "ibv_query_port failed (errno " << errno << ", port " << port << ")";
    throw std::runtime_error(err.str());
  }
  return portAttr.state == IBV_PORT_ACTIVE &&
         (portAttr.link_layer == IBV_LINK_LAYER_ETHERNET || portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND);
}
// Returns the lowest-numbered usable port of the device (ports are numbered from 1),
// or -1 if none is usable. Throws std::runtime_error if the device or a port cannot
// be queried.
int IbCtx::getAnyActivePort() const
{
  struct ibv_device_attr devAttr;
  if (ibv_query_device(reinterpret_cast<struct ibv_context*>(this->ctx), &devAttr) != 0) {
    std::stringstream err;
    err << "ibv_query_device failed (errno " << errno << ")";
    throw std::runtime_error(err.str());
  }
  // Use an int counter: an 8-bit counter would wrap around and never terminate if the
  // device reported the maximum 8-bit port count.
  const int portCount = devAttr.phys_port_cnt;
  for (int port = 1; port <= portCount; ++port) {
    if (this->isPortUsable(port)) {
      return port;
    }
  }
  return -1;
}
// Creates a queue pair owned by this context and returns a non-owning pointer to it.
// With port == -1 the first usable port is chosen; otherwise the given port must be
// usable. Throws std::runtime_error when no suitable port exists.
IbQp* IbCtx::createQp(int port /*=-1*/)
{
  const bool autoSelect = (port == -1);
  if (autoSelect) {
    port = this->getAnyActivePort();
    if (port == -1) {
      throw std::runtime_error("No active port found");
    }
  } else {
    if (!this->isPortUsable(port)) {
      throw std::runtime_error("invalid IB port: " + std::to_string(port));
    }
  }
  IbQp* created = new IbQp(this->ctx, this->pd, port);
  qps.emplace_back(created);
  return qps.back().get();
}
// Registers `buff` (of `size` bytes) as a memory region on this context's protection
// domain. The context keeps ownership; the returned pointer is non-owning.
const IbMr* IbCtx::registerMr(void* buff, std::size_t size)
{
  IbMr* created = new IbMr(this->pd, buff, size);
  mrs.emplace_back(created);
  return mrs.back().get();
}
// Returns the device name this context was constructed with.
const std::string& IbCtx::getDevName() const
{
  const std::string& name = this->devName;
  return name;
}
// Returns the number of IB devices visible to libibverbs, or 0 if the device list
// cannot be obtained.
MSCCLPP_API_CPP int getIBDeviceCount()
{
  int num = 0;
  struct ibv_device** devices = ibv_get_device_list(&num);
  if (devices == nullptr) {
    return 0;
  }
  // The list is only needed for its length; release it so it does not leak.
  ibv_free_device_list(devices);
  return num;
}
// Maps an IB transport (IB0..IB7) to the name of the device at the same index in the
// libibverbs device list.
//
// Throws std::runtime_error("Not an IB transport") for any other transport and
// std::runtime_error("IB transport out of range") when fewer devices are present.
MSCCLPP_API_CPP std::string getIBDeviceName(Transport ibTransport)
{
  // Resolve the index first so an invalid transport never touches the device list.
  int ibTransportIndex;
  switch (ibTransport) { // TODO: get rid of this ugly switch
  case Transport::IB0:
    ibTransportIndex = 0;
    break;
  case Transport::IB1:
    ibTransportIndex = 1;
    break;
  case Transport::IB2:
    ibTransportIndex = 2;
    break;
  case Transport::IB3:
    ibTransportIndex = 3;
    break;
  case Transport::IB4:
    ibTransportIndex = 4;
    break;
  case Transport::IB5:
    ibTransportIndex = 5;
    break;
  case Transport::IB6:
    ibTransportIndex = 6;
    break;
  case Transport::IB7:
    ibTransportIndex = 7;
    break;
  default:
    throw std::runtime_error("Not an IB transport");
  }

  int num = 0;
  // The guard frees the list on every exit path, including exceptions.
  std::unique_ptr<struct ibv_device*, void (*)(struct ibv_device**)> devices(ibv_get_device_list(&num),
                                                                            &ibv_free_device_list);
  if (!devices || ibTransportIndex >= num) {
    throw std::runtime_error("IB transport out of range");
  }
  // Copy the name into a std::string before the list is released.
  return std::string(devices.get()[ibTransportIndex]->name);
}
// Maps an IB device name to the transport (IB0..IB7) matching the index of the first
// device with that name in the libibverbs device list.
//
// Throws std::runtime_error("IB device not found") when no device has that name and
// std::runtime_error("IB device index out of range") when it sits at index 8 or above.
MSCCLPP_API_CPP Transport getIBTransportByDeviceName(const std::string& ibDeviceName)
{
  int found = -1;
  {
    int num = 0;
    // The guard frees the list when this scope ends; only the index is needed afterwards.
    std::unique_ptr<struct ibv_device*, void (*)(struct ibv_device**)> devices(ibv_get_device_list(&num),
                                                                              &ibv_free_device_list);
    if (devices) {
      for (int i = 0; i < num; ++i) {
        if (ibDeviceName == devices.get()[i]->name) {
          found = i;
          break;
        }
      }
    }
  }
  switch (found) { // TODO: get rid of this ugly switch
  case -1:
    throw std::runtime_error("IB device not found");
  case 0:
    return Transport::IB0;
  case 1:
    return Transport::IB1;
  case 2:
    return Transport::IB2;
  case 3:
    return Transport::IB3;
  case 4:
    return Transport::IB4;
  case 5:
    return Transport::IB5;
  case 6:
    return Transport::IB6;
  case 7:
    return Transport::IB7;
  default:
    throw std::runtime_error("IB device index out of range");
  }
}
} // namespace mscclpp

View File

@@ -1,12 +1,12 @@
#ifndef MSCCLPP_BASIC_PROXY_SERVICE_HPP_
#define MSCCLPP_BASIC_PROXY_SERVICE_HPP_
#include "mscclpp.hpp"
#include "communicator.hpp"
#include "mscclpp.hpp"
namespace mscclpp {
ProxyHandler makeBasicProxyHandler(Communicator::Impl &comm);
ProxyHandler makeBasicProxyHandler(Communicator::Impl& comm);
}

305
src/include/channel.hpp Normal file
View File

@@ -0,0 +1,305 @@
#ifndef MSCCLPP_CHANNEL_HPP_
#define MSCCLPP_CHANNEL_HPP_
#include "epoch.hpp"
#include "mscclpp.hpp"
#include "proxy.hpp"
#include "mscclppfifo.hpp"
#include "utils.hpp"
namespace mscclpp {
namespace channel {
// A Channel pairs a Connection with an Epoch
// Bundles a Connection with the Epoch created for it. Both are held through
// shared_ptr, so copies of a Channel refer to the same connection and epoch.
class Channel
{
public:
  // Keeps `connection` and builds a new Epoch for it on `communicator`.
  Channel(Communicator& communicator, std::shared_ptr<Connection> connection)
    : connection_(connection), epoch_(std::make_shared<Epoch>(communicator, connection))
  {
  }

  // The connection this channel writes through.
  Connection& connection()
  {
    return *connection_;
  }

  // The epoch paired with the connection.
  Epoch& epoch()
  {
    return *epoch_;
  }

private:
  std::shared_ptr<Connection> connection_;
  std::shared_ptr<Epoch> epoch_;
};
using ChannelId = uint32_t;
using TriggerType = uint64_t;
const TriggerType TriggerData = 0x1;
const TriggerType TriggerFlag = 0x2;
const TriggerType TriggerSync = 0x4;
// This is just a numeric ID. Each DeviceChannelService keeps an internal array indexed by these handles,
// mapping each one to the actual RegisteredMemory.
using MemoryId = uint32_t;
#define MSCCLPP_BITS_SIZE 32
#define MSCCLPP_BITS_OFFSET 32
#define MSCCLPP_BITS_REGMEM_HANDLE 8
#define MSCCLPP_BITS_TYPE 3
#define MSCCLPP_BITS_CONNID 10
// this is the basic structure of each work element in the fifo
// the summation of number of bits must be 128 or less
// One work element of the proxy FIFO, viewable either as the raw ProxyTrigger or as
// its packed bit-fields. The declared bit-fields of each 64-bit word must add up to
// exactly 64 so that `fields` stays within 128 bits.
union ChannelTrigger {
  ProxyTrigger value;
  struct
  {
    // first 64 bits: value.fst
    uint64_t size : MSCCLPP_BITS_SIZE;
    uint64_t srcOffset : MSCCLPP_BITS_OFFSET;
    uint64_t : (64 - MSCCLPP_BITS_SIZE - MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment
    // second 64 bits: value.snd
    uint64_t dstOffset : MSCCLPP_BITS_OFFSET;
    uint64_t srcMemoryId : MSCCLPP_BITS_REGMEM_HANDLE;
    uint64_t dstMemoryId : MSCCLPP_BITS_REGMEM_HANDLE;
    uint64_t type : MSCCLPP_BITS_TYPE;
    uint64_t chanId : MSCCLPP_BITS_CONNID;
    // The padding must also account for chanId; otherwise the word declares more than
    // 64 bits and the padding spills into a third 64-bit unit.
    uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_REGMEM_HANDLE - MSCCLPP_BITS_REGMEM_HANDLE -
                MSCCLPP_BITS_TYPE - MSCCLPP_BITS_CONNID); // ensure 64-bit alignment
  } fields;

#ifdef __CUDACC__
  // Leaves the trigger uninitialized.
  __device__ ChannelTrigger()
  {
  }
  // Wraps an existing raw trigger.
  __device__ ChannelTrigger(ProxyTrigger value) : value(value)
  {
  }
  // Packs a trigger by hand, mirroring the bit-field layout above:
  //   fst = srcOffset | size
  //   snd = connectionId | type | dst | src | dstOffset   (most to least significant)
  // Arguments are not masked; callers must keep each value within its field width.
  __device__ ChannelTrigger(TriggerType type, MemoryId dst, uint64_t dstOffset, MemoryId src,
                            uint64_t srcOffset, uint64_t size, int connectionId)
  {
    value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + size);
    value.snd = ((((((((connectionId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_REGMEM_HANDLE) + dst)
                    << MSCCLPP_BITS_REGMEM_HANDLE) +
                   src)
                  << MSCCLPP_BITS_OFFSET) +
                 dstOffset);
  }
#endif // __CUDACC__
};
// Device-side handle of a channel: pushes triggers into the proxy FIFO tagged with
// this channel's id and signals/waits through the channel's DeviceEpoch.
struct DeviceChannel
{
  DeviceChannel() = default;

  DeviceChannel(ChannelId channelId, DeviceEpoch epoch, DeviceProxyFifo fifo)
    : channelId_(channelId), epoch_(epoch), fifo_(fifo)
  {
  }

  DeviceChannel(const DeviceChannel& other) = default;

  DeviceChannel& operator=(DeviceChannel& other) = default;

#ifdef __CUDACC__
  // Requests a copy of `size` bytes from (src, srcOffset) to (dst, dstOffset).
  __forceinline__ __device__ void put(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
                                      uint64_t size)
  {
    fifo_.push(ChannelTrigger(TriggerData, dst, dstOffset, src, srcOffset, size, channelId_).value);
  }

  // Same as above with one offset used on both sides.
  __forceinline__ __device__ void put(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size)
  {
    put(dst, offset, src, offset, size);
  }

  // Increments the local epoch and requests that it be signalled.
  __forceinline__ __device__ void signal()
  {
    epochIncrement();
    fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, channelId_).value);
  }

  // put() and signal() combined in a single trigger.
  __forceinline__ __device__ void putWithSignal(MemoryId dst, uint64_t dstOffset, MemoryId src,
                                                uint64_t srcOffset, uint64_t size)
  {
    epochIncrement();
    fifo_.push(
      ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, size, channelId_)
        .value);
  }

  __forceinline__ __device__ void putWithSignal(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size)
  {
    putWithSignal(dst, offset, src, offset, size);
  }

  // putWithSignal() plus a sync request, then spins until the proxy has consumed the trigger.
  __forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, uint64_t dstOffset, MemoryId src,
                                                        uint64_t srcOffset, uint64_t size)
  {
    epochIncrement();
    uint64_t curFifoHead = fifo_.push(ChannelTrigger(TriggerData | TriggerFlag | TriggerSync, dst,
                                                     dstOffset, src, srcOffset, size, channelId_)
                                        .value);
    while (*(volatile uint64_t*)&fifo_.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
           *(volatile uint64_t*)fifo_.tailReplica <= curFifoHead)
      ;
  }

  __forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, MemoryId src, uint64_t offset,
                                                        uint64_t size)
  {
    putWithSignalAndFlush(dst, offset, src, offset, size);
  }

  // Pushes a sync-only trigger and spins until the proxy has consumed it.
  __forceinline__ __device__ void flush()
  {
    // Use this header's TriggerSync constant, as every other trigger built here does.
    uint64_t curFifoHead = fifo_.push(ChannelTrigger(TriggerSync, 0, 0, 0, 0, 1, channelId_).value);
    // Spin while the work element is still non-zero and the tail has not yet moved
    // past curFifoHead; either condition turning false ends the wait.
    while (*(volatile uint64_t*)&fifo_.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
           *(volatile uint64_t*)fifo_.tailReplica <= curFifoHead)
      ;
  }

  // Blocks until the next inbound epoch has arrived.
  __forceinline__ __device__ void wait()
  {
    epoch_.wait();
  }

  __forceinline__ __device__ void epochIncrement()
  {
    epoch_.epochIncrement();
  }
#endif // __CUDACC__

  ChannelId channelId_;

  DeviceEpoch epoch_;

  // A concurrent FIFO: multiple device threads can produce into it and the sole
  // proxy thread consumes from it.
  DeviceProxyFifo fifo_;
};
class DeviceChannelService;
inline ProxyHandler makeChannelProxyHandler(DeviceChannelService& channelService);
// Host-side registry of channels and registered memories, plus the proxy that executes
// the triggers produced by the corresponding DeviceChannel objects.
class DeviceChannelService
{
public:
  DeviceChannelService(Communicator& communicator);

  // Wraps `connection` in a new Channel and returns its index.
  ChannelId addChannel(std::shared_ptr<Connection> connection)
  {
    const ChannelId newId = channels_.size();
    channels_.emplace_back(communicator_, connection);
    return newId;
  }

  // Stores `memory` and returns its index.
  MemoryId addMemory(RegisteredMemory memory)
  {
    const MemoryId newId = memories_.size();
    memories_.push_back(memory);
    return newId;
  }

  // Returns a copy of the channel stored at `id` (no bounds check).
  Channel channel(ChannelId id)
  {
    return channels_[id];
  }

  // Builds the device-side handle for the channel stored at `id`.
  DeviceChannel deviceChannel(ChannelId id)
  {
    DeviceEpoch epoch = channels_[id].epoch().deviceEpoch();
    return DeviceChannel(id, epoch, proxy_.fifo().deviceFifo());
  }

  void startProxy()
  {
    proxy_.start();
  }

  void stopProxy()
  {
    proxy_.stop();
  }

private:
  Communicator& communicator_;
  std::vector<Channel> channels_;
  std::vector<RegisteredMemory> memories_;
  Proxy proxy_;
  int deviceNumaNode;

  void bindThread();

  // Decodes one raw trigger and performs, in this order, the requested write, the
  // epoch signal and the flush. A sync trigger additionally asks the proxy to flush
  // the FIFO tail.
  ProxyHandlerResult handleTrigger(ProxyTrigger triggerRaw)
  {
    const ChannelTrigger* trigger = reinterpret_cast<const ChannelTrigger*>(&triggerRaw);
    const auto& fields = trigger->fields;
    Channel& target = channels_[fields.chanId];
    const TriggerType type = fields.type;

    if (type & TriggerData) {
      target.connection().write(memories_[fields.dstMemoryId], fields.dstOffset, memories_[fields.srcMemoryId],
                                fields.srcOffset, fields.size);
    }

    if (type & TriggerFlag) {
      target.epoch().signal();
    }

    if (!(type & TriggerSync)) {
      return ProxyHandlerResult::Continue;
    }
    target.connection().flush();
    return ProxyHandlerResult::FlushFifoTailAndContinue;
  }
};
// A DeviceChannel bound to one fixed destination memory and one fixed source memory,
// so callers only pass offsets and sizes.
struct SimpleDeviceChannel
{
  SimpleDeviceChannel() = default;

  SimpleDeviceChannel(DeviceChannel devChan, MemoryId dst, MemoryId src)
    : devChan_(devChan), dst_(dst), src_(src)
  {
  }

  SimpleDeviceChannel(const SimpleDeviceChannel& other) = default;

  SimpleDeviceChannel& operator=(SimpleDeviceChannel& other) = default;

#ifdef __CUDACC__
  // Each method forwards to the wrapped DeviceChannel with dst_/src_ filled in; the
  // two-argument overloads use the same offset on both sides.

  __forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
  {
    devChan_.put(dst_, dstOffset, src_, srcOffset, size);
  }

  __forceinline__ __device__ void put(uint64_t offset, uint64_t size)
  {
    put(offset, offset, size);
  }

  __forceinline__ __device__ void signal()
  {
    devChan_.signal();
  }

  __forceinline__ __device__ void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
  {
    devChan_.putWithSignal(dst_, dstOffset, src_, srcOffset, size);
  }

  __forceinline__ __device__ void putWithSignal(uint64_t offset, uint64_t size)
  {
    putWithSignal(offset, offset, size);
  }

  __forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
  {
    devChan_.putWithSignalAndFlush(dst_, dstOffset, src_, srcOffset, size);
  }

  __forceinline__ __device__ void putWithSignalAndFlush(uint64_t offset, uint64_t size)
  {
    putWithSignalAndFlush(offset, offset, size);
  }

  __forceinline__ __device__ void flush()
  {
    devChan_.flush();
  }

  __forceinline__ __device__ void wait()
  {
    devChan_.wait();
  }

  __forceinline__ __device__ void epochIncrement()
  {
    devChan_.epochIncrement();
  }
#endif // __CUDACC__

  DeviceChannel devChan_; // underlying channel handle
  MemoryId dst_;          // destination memory used by every put
  MemoryId src_;          // source memory used by every put
};
} // namespace channel
} // namespace mscclpp
#endif // MSCCLPP_CHANNEL_HPP_

View File

@@ -8,6 +8,7 @@
#define MSCCLPP_CHECKS_HPP_
#include "debug.h"
#include <cuda.h>
#include <cuda_runtime.h>
#define MSCCLPPTHROW(call) \
@@ -16,7 +17,7 @@
if (res != mscclppSuccess && res != mscclppInProgress) { \
throw std::runtime_error(std::string("Call to " #call " failed with error code ") + mscclppGetErrorString(res)); \
} \
} while (0);
} while (false)
#define CUDATHROW(cmd) \
do { \
@@ -26,4 +27,14 @@
} \
} while (false)
// Evaluates a CUDA driver API call and throws std::runtime_error("Cu failure '<msg>'")
// when it does not return CUDA_SUCCESS. If the error code has no description,
// "unknown error" is used instead of passing a null pointer to std::string.
#define CUTHROW(cmd)                                                                                                   \
  do {                                                                                                                 \
    CUresult err = cmd;                                                                                                \
    if (err != CUDA_SUCCESS) {                                                                                         \
      const char* errStr = nullptr;                                                                                    \
      if (cuGetErrorString(err, &errStr) != CUDA_SUCCESS || errStr == nullptr) {                                       \
        errStr = "unknown error";                                                                                      \
      }                                                                                                                \
      throw std::runtime_error(std::string("Cu failure '") + std::string(errStr) + "'");                               \
    }                                                                                                                  \
  } while (false)
#endif

View File

@@ -7,15 +7,16 @@
#ifndef MSCCLPP_COMM_H_
#define MSCCLPP_COMM_H_
#include "ib.h"
#include "ib.hpp"
#include "proxy.h"
#include <memory>
#include <vector>
#define MAXCONNECTIONS 64
struct mscclppBufferRegistration
{
void *data;
void* data;
uint64_t size;
};
@@ -31,7 +32,7 @@ struct mscclppConn
std::vector<mscclppBufferRegistration> bufferRegistrations;
std::vector<mscclppBufferRegistration> remoteBufferRegistrations;
struct mscclppIbContext* ibCtx;
mscclpp::IbCtx* ibCtx;
#if defined(ENABLE_NPKIT)
std::vector<uint64_t> npkitUsedReqIds;
std::vector<uint64_t> npkitFreeReqIds;
@@ -57,7 +58,7 @@ struct mscclppComm
// Flag to ask MSCCLPP kernels to abort
volatile uint32_t* abortFlag;
struct mscclppIbContext* ibContext[MSCCLPP_IB_MAX_DEVS];
std::unique_ptr<mscclpp::IbCtx> ibContext[MSCCLPP_IB_MAX_DEVS];
struct mscclppProxyState* proxyState[MSCCLPP_PROXY_MAX_NUM];
};

View File

@@ -1,23 +1,32 @@
#ifndef MSCCL_COMMUNICATOR_HPP_
#define MSCCL_COMMUNICATOR_HPP_
#include "mscclpp.hpp"
#include "ib.hpp"
#include "mscclpp.h"
#include "mscclpp.hpp"
#include "proxy.hpp"
#include <memory>
#include <unordered_map>
namespace mscclpp {
struct Communicator::Impl {
mscclppComm_t comm;
std::vector<std::shared_ptr<HostConnection>> connections;
Proxy proxy;
class ConnectionBase;
Impl();
struct Communicator::Impl
{
std::vector<std::shared_ptr<ConnectionBase>> connections_;
std::vector<std::shared_ptr<Setuppable>> toSetup_;
std::unordered_map<Transport, std::unique_ptr<IbCtx>> ibContexts_;
std::shared_ptr<BaseBootstrap> bootstrap_;
std::vector<uint64_t> rankToHash_;
~Impl();
Impl(std::shared_ptr<BaseBootstrap> bootstrap);
friend class HostConnection;
~Impl();
IbCtx* getIbContext(Transport ibTransport);
};
} // namespace mscclpp
#endif
#endif // MSCCL_COMMUNICATOR_HPP_

View File

@@ -0,0 +1,69 @@
#ifndef MSCCLPP_CONNECTION_HPP_
#define MSCCLPP_CONNECTION_HPP_
#include "communicator.hpp"
#include "ib.hpp"
#include "mscclpp.hpp"
#include <cuda_runtime.h>
namespace mscclpp {
// TODO: Add functionality to these classes for Communicator to do connectionSetup
// Common base of the concrete connection types: implements the remoteRank()/tag()
// accessors of Connection and also derives from Setuppable.
class ConnectionBase : public Connection, public Setuppable
{
private:
  int remoteRank_; // presumably the value passed to the constructor - confirm in the .cc file
  int tag_;        // presumably the value passed to the constructor - confirm in the .cc file

public:
  ConnectionBase(int remoteRank, int tag);

  int remoteRank() override;

  int tag() override;
};
// Connection implementation that owns a CUDA stream. The behavior of write()/flush()
// is defined in the implementation file.
class CudaIpcConnection : public ConnectionBase
{
private:
  cudaStream_t stream;

public:
  CudaIpcConnection(int remoteRank, int tag);

  ~CudaIpcConnection();

  Transport transport() override;

  Transport remoteTransport() override;

  void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
             uint64_t size) override;

  void flush() override;
};
// Connection implementation backed by an IB queue pair. Unlike CudaIpcConnection it
// also overrides the two-phase setup hooks (beginSetup/endSetup).
class IBConnection : public ConnectionBase
{
private:
  Transport transport_;
  Transport remoteTransport_;
  IbQp* qp;             // non-owning here; NOTE(review): presumably obtained from the communicator's IbCtx - confirm
  int numSignaledSends; // NOTE(review): looks like a count of outstanding signaled sends - confirm

public:
  IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl);

  Transport transport() override;

  Transport remoteTransport() override;

  void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
             uint64_t size) override;

  void flush() override;

  void beginSetup(std::shared_ptr<BaseBootstrap> bootstrap) override;

  void endSetup(std::shared_ptr<BaseBootstrap> bootstrap) override;
};
} // namespace mscclpp
#endif // MSCCLPP_CONNECTION_HPP_

52
src/include/epoch.hpp Normal file
View File

@@ -0,0 +1,52 @@
#ifndef MSCCLPP_EPOCH_HPP_
#define MSCCLPP_EPOCH_HPP_
#include "mscclpp.hpp"
namespace mscclpp {
// Pair of 64-bit epoch counters, aligned to 16 bytes.
struct alignas(16) EpochIds
{
// Incremented locally by DeviceEpoch::epochIncrement().
uint64_t outbound_;
// Polled by DeviceEpoch::wait() until it reaches the expected inbound epoch.
uint64_t inboundReplica_;
};
// Device-side view of an epoch: a pointer to the shared EpochIds pair and a pointer to
// the next inbound epoch id this side expects.
struct DeviceEpoch
{
#ifdef __CUDACC__
  // Advances the expected inbound epoch by one, then spins until the inbound replica
  // has reached it.
  __forceinline__ __device__ void wait()
  {
    (*expectedInboundEpochId_) += 1;
    volatile uint64_t* inbound = (volatile uint64_t*)&(epochIds_->inboundReplica_);
    while (*inbound < (*expectedInboundEpochId_)) {
    }
  }

  // Advances the outbound epoch id by one through a volatile access.
  __forceinline__ __device__ void epochIncrement()
  {
    volatile uint64_t* outbound = (volatile uint64_t*)&(epochIds_->outbound_);
    *outbound += 1;
  }
#endif // __CUDACC__

  EpochIds* epochIds_;
  uint64_t* expectedInboundEpochId_;
};
// Host-side owner of an epoch tied to one connection. Non-copyable. Hands out the
// DeviceEpoch view that device code uses to wait and increment.
class Epoch
{
private:
  std::shared_ptr<Connection> connection_;
  DeviceEpoch device_;
  RegisteredMemory localEpochIdsRegMem_;
  NonblockingFuture<RegisteredMemory> remoteEpochIdsRegMem_;

public:
  Epoch(Communicator& communicator, std::shared_ptr<Connection> connection);

  Epoch(const Epoch&) = delete;

  ~Epoch();

  void signal();

  // Returns a copy of the device-side view.
  DeviceEpoch deviceEpoch()
  {
    return device_;
  }
};
} // namespace mscclpp
#endif // MSCCLPP_EPOCH_HPP_

View File

@@ -1,22 +0,0 @@
#ifndef MSCCLPP_HOST_CONNECTION_HPP_
#define MSCCLPP_HOST_CONNECTION_HPP_
#include "mscclpp.hpp"
#include "mscclpp.h"
#include "comm.h"
namespace mscclpp {
struct HostConnection::Impl {
Communicator* comm;
mscclppConn* conn;
mscclppHostConn_t* hostConn;
Impl(Communicator* comm, mscclppConn* conn);
~Impl();
};
} // namespace mscclpp
#endif // MSCCLPP_HOST_CONNECTION_HPP_

View File

@@ -1,69 +0,0 @@
#ifndef MSCCLPP_IB_H_
#define MSCCLPP_IB_H_
#include "mscclpp.h"
#include <infiniband/verbs.h>
#include <list>
#include <memory>
#include <string>
#define MSCCLPP_IB_CQ_SIZE 1024
#define MSCCLPP_IB_CQ_POLL_NUM 4
#define MSCCLPP_IB_MAX_SENDS 64
#define MSCCLPP_IB_MAX_DEVS 8
// Queue-pair description that is exchanged with the remote peer.
struct mscclppIbQpInfo
{
  uint16_t lid;
  uint8_t port;
  uint8_t linkLayer;
  uint32_t qpn;
  uint64_t spn;
  ibv_mtu mtu;
};
// One IB queue pair together with its completion queue and the staging arrays
// (work requests, scatter/gather entries, work completions) used by the calls below.
struct mscclppIbQp
{
  struct ibv_qp* qp;
  struct mscclppIbQpInfo info; // local QP description, see mscclppIbQpInfo
  struct ibv_send_wr* wrs;
  struct ibv_sge* sges;
  struct ibv_cq* cq;
  struct ibv_wc* wcs;
  int wrn; // NOTE(review): presumably the number of currently staged work requests — confirm

  // State transitions; `rtr` takes the remote peer's QP info.
  int rtr(const mscclppIbQpInfo* info);
  int rts();

  // Stage a send of `size` bytes from ibMr (at srcOffset) to the remote region described
  // by `info` (at dstOffset). The WithImm variant additionally carries `immData`.
  int stageSend(struct mscclppIbMr* ibMr, const mscclppIbMrInfo* info, uint32_t size, uint64_t wrId,
                uint64_t srcOffset, uint64_t dstOffset, bool signaled);
  int stageSendWithImm(struct mscclppIbMr* ibMr, const mscclppIbMrInfo* info, uint32_t size, uint64_t wrId,
                       uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData);

  int postSend();
  int postRecv(uint64_t wrId);
  int pollCq();
};
// Resources that belong to a single IB device: the verbs context, its protection
// domain, the port list, and the arrays of queue pairs and memory regions with
// their current/maximum counts.
struct mscclppIbContext
{
  struct ibv_context* ctx;
  struct ibv_pd* pd;

  int* ports;
  int nPorts;

  struct mscclppIbQp* qps;
  int nQps;
  int maxQps;

  struct mscclppIbMr* mrs;
  int nMrs;
  int maxMrs;
};
// C-style entry points operating on mscclppIbContext. Each one reports its outcome as an
// mscclppResult_t; newly created objects are handed back through the pointer-to-pointer
// argument. NOTE(review): semantics are inferred from the signatures only — the
// implementations are not part of this header.
mscclppResult_t mscclppIbContextCreate(struct mscclppIbContext** ctx, const char* ibDevName);
mscclppResult_t mscclppIbContextDestroy(struct mscclppIbContext* ctx);
// `port` defaults to -1; presumably that lets the implementation choose a port — confirm.
mscclppResult_t mscclppIbContextCreateQp(struct mscclppIbContext* ctx, struct mscclppIbQp** ibQp, int port = -1);
mscclppResult_t mscclppIbContextRegisterMr(struct mscclppIbContext* ctx, void* buff, size_t size,
struct mscclppIbMr** ibMr);
#endif

108
src/include/ib.hpp Normal file
View File

@@ -0,0 +1,108 @@
#ifndef MSCCLPP_IB_HPP_
#define MSCCLPP_IB_HPP_
#include <cstddef>
#include <cstdint>
#include <list>
#include <memory>
#include <string>
#define MSCCLPP_IB_CQ_SIZE 1024
#define MSCCLPP_IB_CQ_POLL_NUM 1
#define MSCCLPP_IB_MAX_SENDS 64
#define MSCCLPP_IB_MAX_DEVS 8
namespace mscclpp {
// Address/remote-key pair describing a memory region (see IbMr::getInfo()).
struct IbMrInfo
{
  uint64_t addr;
  uint32_t rkey;
};
// A registered memory region. Instances can only be constructed by IbCtx (private
// constructor + friend declaration). The verbs handle is stored as an opaque void*,
// so this header does not depend on the verbs headers.
class IbMr
{
public:
  ~IbMr();

  IbMrInfo getInfo() const;
  const void* getBuff() const;
  uint32_t getLkey() const;

private:
  IbMr(void* pd, void* buff, std::size_t size);

  void* mr;
  void* buff;
  std::size_t size;

  friend class IbCtx;
};
// Queue-pair description that is exchanged with the remote peer.
// `mtu` is stored as a plain int.
struct IbQpInfo
{
  uint16_t lid;
  uint8_t port;
  uint8_t linkLayer;
  uint32_t qpn;
  uint64_t spn;
  int mtu;
};
// A queue pair plus its completion queue and staging arrays, all held as opaque void*
// handles. Only IbCtx can construct one (private constructor + friend declaration).
class IbQp
{
public:
  ~IbQp();

  // State transitions; `rtr` takes the remote peer's QP info.
  void rtr(const IbQpInfo& info);
  void rts();

  // Stage a send of `size` bytes from `mr` (at srcOffset) to the remote region described
  // by `info` (at dstOffset). The WithImm variant additionally carries `immData`.
  int stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
                uint64_t dstOffset, bool signaled);
  int stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
                       uint64_t dstOffset, bool signaled, unsigned int immData);

  void postSend();
  void postRecv(uint64_t wrId);
  int pollCq();

  IbQpInfo& getInfo();
  // Returns the idx-th work-completion entry as an opaque pointer.
  const void* getWc(int idx) const;

private:
  IbQp(void* ctx, void* pd, int port);

  IbQpInfo info;
  void* qp;
  void* cq;
  void* wcs;
  void* wrs;
  void* sges;
  int wrn; // NOTE(review): presumably the number of currently staged work requests — confirm

  friend class IbCtx;
};
// Per-device IB context identified by a device name. It keeps every IbQp and IbMr it
// holds in lists of std::unique_ptr; createQp()/registerMr() hand out raw pointers.
class IbCtx
{
public:
  IbCtx(const std::string& devName);
  ~IbCtx();

  // `port` defaults to -1; presumably that selects any active port (see
  // getAnyActivePort()) — confirm against the implementation.
  IbQp* createQp(int port = -1);
  const IbMr* registerMr(void* buff, std::size_t size);
  const std::string& getDevName() const;

private:
  bool isPortUsable(int port) const;
  int getAnyActivePort() const;

  const std::string devName;
  void* ctx;
  void* pd;
  std::list<std::unique_ptr<IbQp>> qps;
  std::list<std::unique_ptr<IbMr>> mrs;
};
} // namespace mscclpp
#endif // MSCCLPP_IB_HPP_

View File

@@ -191,7 +191,8 @@ struct mscclppHostConn
{
virtual ~mscclppHostConn() = default;
virtual void put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize) = 0;
virtual void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset, uint64_t dataSize) = 0;
virtual void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset,
uint64_t dataSize) = 0;
virtual void signal() = 0;
virtual void wait() = 0;
virtual void flush() = 0;
@@ -207,25 +208,10 @@ typedef struct
char internal[MSCCLPP_UNIQUE_ID_BYTES];
} mscclppUniqueId;
// Memory-region description that is exchanged with the remote peer:
// an address and the matching remote key.
struct mscclppIbMrInfo
{
  uint64_t addr;
  uint32_t rkey;
};
// An IB memory region: the verbs handle, the buffer pointer it was created for,
// and the info record describing it.
struct mscclppIbMr
{
  struct ibv_mr* mr;
  void* buff;
  struct mscclppIbMrInfo info;
};
struct mscclppRegisteredMemoryP2P
{
void* remoteBuff;
mscclppIbMr* IbMr;
const void* IbMr;
};
struct mscclppRegisteredMemory
@@ -247,7 +233,6 @@ typedef enum
mscclppNumResults = 8
} mscclppResult_t;
/* Create a unique ID for communication. Only needs to be called by one process.
* Use with mscclppCommInitRankFromId().
* All processes need to provide the same ID to mscclppCommInitRankFromId().
@@ -358,7 +343,8 @@ mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, int tag, void
* transportType: the type of transport to be used (mscclppTransportP2P or mscclppTransportIB)
* ibDev: the name of the IB device to be used. Expects a null for mscclppTransportP2P.
*/
mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int remoteRank, int tag, mscclppTransport_t transportType, const char* ibDev = 0);
mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int remoteRank, int tag,
mscclppTransport_t transportType, const char* ibDev = 0);
/* Register a buffer for use with a connection.
*
@@ -371,7 +357,8 @@ mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int remoteRank,
* Outputs:
* handle: a handle to the buffer registration
*/
mscclppResult_t mscclppRegisterBufferForConnection(mscclppComm_t comm, int connIdx, void* localBuff, uint64_t buffSize, mscclppBufferHandle_t *handle);
mscclppResult_t mscclppRegisterBufferForConnection(mscclppComm_t comm, int connIdx, void* localBuff, uint64_t buffSize,
mscclppBufferHandle_t* handle);
/* Establish all connections declared by mscclppConnect(). This function must be called after all mscclppConnect()
* calls are made. This function ensures that all remote ranks are ready to communicate when it returns.

View File

@@ -6,22 +6,17 @@
#define MSCCLPP_PATCH 0
#define MSCCLPP_VERSION (MSCCLPP_MAJOR * 10000 + MSCCLPP_MINOR * 100 + MSCCLPP_PATCH)
// For every MSCCLPP_PROXY_FIFO_FLUSH_COUNTER, a flush of the tail to device memory is triggered.
// As long as MSCCLPP_PROXY_FIFO_SIZE is large enough, having a stale tail is not a problem.
#define MSCCLPP_PROXY_FIFO_SIZE 128
#define MSCCLPP_PROXY_FIFO_FLUSH_COUNTER 4
#include <vector>
#include <bitset>
#include <memory>
#include <string>
#include <functional>
#include <mscclppfifo.hpp>
#include <vector>
#include <future>
namespace mscclpp {
#define MSCCLPP_UNIQUE_ID_BYTES 128
struct UniqueId {
struct UniqueId
{
char internal[MSCCLPP_UNIQUE_ID_BYTES];
};
@@ -36,6 +31,21 @@ public:
virtual void recv(void* data, int size, int peer, int tag) = 0;
virtual void allGather(void* allData, int size) = 0;
virtual void barrier() = 0;
// TODO: move implementations of these helpers out of this header
void send(const std::vector<char>& data, int peer, int tag)
{
size_t size = data.size();
send((void*)&size, sizeof(size_t), peer, tag);
send((void*)data.data(), data.size(), peer, tag+1);
}
// Receives a byte vector from `peer`: reads the length (a size_t) under `tag`,
// resizes `data` to it, then reads the bytes under `tag + 1`.
void recv(std::vector<char>& data, int peer, int tag)
{
  size_t incomingSize;
  recv(static_cast<void*>(&incomingSize), sizeof(incomingSize), peer, tag);
  data.resize(incomingSize);
  recv(static_cast<void*>(data.data()), data.size(), peer, tag + 1);
}
};
class Bootstrap : public BaseBootstrap
@@ -61,369 +71,6 @@ private:
std::unique_ptr<Impl> pimpl_;
};
// Two 64-bit epoch counters in one 16-byte-aligned record.
struct alignas(16) SignalEpochId
{
  // Every signal() increments this and either:
  // 1) the proxy thread pushes it to the remote peer's localSignalEpochId->proxy, or
  // 2) a gpu thread directly writes it to remoteSignalEpochId->device
  uint64_t device;
  // The signal() function triggers the cpu proxy thread to write to it
  uint64_t proxy;
};
// Trigger kinds are distinct bits of a 64-bit value, so several of them can be OR-ed
// together into one trigger.
using ChannelTriggerType = uint64_t;
const ChannelTriggerType channelTriggerData = 0x1;
const ChannelTriggerType channelTriggerFlag = 0x2;
const ChannelTriggerType channelTriggerSync = 0x4;
// This is just a numeric ID. Each HostConnection will have an internal array indexed by these handles
// mapping to the actual buffer (the original comment was truncated here; presumably the
// registered memory region — confirm).
using BufferHandle = uint32_t;
#define MSCCLPP_BITS_SIZE 32
#define MSCCLPP_BITS_OFFSET 32
#define MSCCLPP_BITS_BUFFER_HANDLE 8
#define MSCCLPP_BITS_TYPE 3
#define MSCCLPP_BITS_CONNID 10
// this is the basic structure of each work element in the fifo
// the summation of number of bits must be 128 or less
union ChannelTrigger {
ProxyTrigger value;
struct
{
// first 64 bits: value[0]
uint64_t size : MSCCLPP_BITS_SIZE;
uint64_t srcOffset : MSCCLPP_BITS_OFFSET;
uint64_t : (64 - MSCCLPP_BITS_SIZE - MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment
// second 64 bits: value[1]
uint64_t dstOffset : MSCCLPP_BITS_OFFSET;
uint64_t srcBufferHandle : MSCCLPP_BITS_BUFFER_HANDLE;
uint64_t dstBufferHandle : MSCCLPP_BITS_BUFFER_HANDLE;
uint64_t type : MSCCLPP_BITS_TYPE;
uint64_t connId : MSCCLPP_BITS_CONNID;
uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_BUFFER_HANDLE - MSCCLPP_BITS_BUFFER_HANDLE - MSCCLPP_BITS_TYPE); // ensure 64-bit alignment
} fields;
#ifdef __CUDACC__
__device__ ChannelTrigger() {}
__device__ ChannelTrigger(ProxyTrigger value) : value(value) {}
__device__ ChannelTrigger(ChannelTriggerType type, BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size, int connectionId) {
value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + size);
value.snd = ((((((((connectionId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_BUFFER_HANDLE) + dst) << MSCCLPP_BITS_BUFFER_HANDLE) + src) << MSCCLPP_BITS_OFFSET) + dstOffset);
}
#endif // __CUDACC__
};
// Pointers to the epoch counters of one connection, plus (when compiled with __CUDACC__)
// the device functions that wait on and advance them.
struct ConnectionEpoch {
#ifdef __CUDACC__
// Bumps *waitEpochId by one, then busy-waits until localSignalEpochId->proxy (read
// through a volatile pointer) is >= that value. Only the `proxy` counter is polled here.
__forceinline__ __device__ void wait()
{
(*waitEpochId) += 1;
while (*(volatile uint64_t*)&(localSignalEpochId->proxy) < (*waitEpochId))
;
}
// Adds one to localSignalEpochId->device through a volatile pointer (plain
// read-modify-write, not an atomic operation).
__forceinline__ __device__ void epochIncrement()
{
*(volatile uint64_t*)&(localSignalEpochId->device) += 1;
}
#endif // __CUDACC__
// counters read and written by wait() and epochIncrement() above
SignalEpochId* localSignalEpochId;
// used by the signal() function directly from gpu
SignalEpochId* remoteSignalEpochId;
// every wait(), increments this and then the gpu waits for either:
// 1) localSignalEpochId->proxy to be >= this in case of a proxy thread
// 2) remoteSignalEpochId->device to be >= this in case of a gpu thread
uint64_t* waitEpochId;
};
// Host-side handle of one connection. The state lives behind a private Impl (pimpl).
class HostConnection
{
  struct Impl;

public:
  /* HostConnection cannot be constructed from user code and must instead be created
   * through Communicator::connect. */
  HostConnection(std::unique_ptr<Impl>);
  ~HostConnection();

  int getId();

  /* Register a region of GPU memory for use with this connection. Must be called before
   * connectionSetup() in the communicator.
   *
   * Inputs:
   *  data: base pointer to the memory
   *  size: size of the memory region in bytes
   *
   * Returns: a handle to the buffer
   */
  BufferHandle registerBuffer(void* data, uint64_t size);

  /* Returns: the number of times registerBuffer(...) was called, i.e. the number of
   * locally registered buffers. */
  int numLocalBuffers();

  /* Look up a handle previously returned by registerBuffer(...).
   *
   * Inputs:
   *  index: the index of the handle to get
   *
   * Returns: a handle to the buffer
   */
  BufferHandle getLocalBuffer(int index);

  /* Returns: the number of times registerBuffer(...) was called on the remote peer,
   * i.e. the number of buffers registered there. */
  int numRemoteBuffers();

  /* Look up a handle returned by registerBuffer(...) on the remote peer.
   *
   * Inputs:
   *  index: the index of the handle to get
   *
   * Returns: a handle to the buffer on the remote peer
   */
  BufferHandle getRemoteBuffer(int index);

  ConnectionEpoch getEpoch();
  DeviceProxyFifo getDeviceFifo();

  void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size);
  void signal();
  void flush();
  void wait();

private:
  std::unique_ptr<Impl> pimpl;

  friend class Communicator;
};
/***************************************************************************************************************
* A mscclppDevConn provides a zero-copy connection between two GPUs connected via P2P NVLink or InfiniBand.
* The communication API is one-sided meaning that for every single data transfer, only one side
* needs to execute unlike a two-sided communication stack such as NCCL where both sides
* need to execute a send and a receive instruction, respectively, for every transfer.
*
* A connection is uniquely identified by the (remoteRank, tag) pair at an endpoint.
* The two endpoints register buffers of the same size with the connection.
*
* The endpoints provide the remoteRank, tag, and the buffer when registering a connection with mscclppConnect().
*
* mscclppConnectionSetup() sets up all the registered connections.
*
***************************************************************************************************************
* A proxy thread running on the CPU is necessary to perform transfers using InfiniBand or the DMA engine.
* The current implementation uses a single proxy thread per context - one IB connection or DMA engine per node.
* Thus multiple threadblocks using different connections might use the same CPU proxy thread.
*
* Before using any functionality of connections, mscclppProxyLaunch needs to be called to spawn the
* proxy threads. There are currently two types of connections:
*
* P2P via NVLink: the DMA engine can perform the copy between the buffers. DMA engine has higher latency
* but has a higher bandwidth and costs no compute cycles on the GPU.
*
* InfiniBand: the RDMA engine copies the data over MLX devices.
*
***************************************************************************************************************
* At the runtime, a GPU kernel has access to a mscclppDevConn object that provides the following functions:
*
* put(): [non-blocking] the sender initiates a data transfer to the receiver.
*
* signal(): [non-blocking] the sender signals the receiver that data is ready to be consumed.
*
* flush(): [blocking] the sender waits for all the data transfers to complete
*
* wait(): [blocking] the receiver waits on the signal() to start reading the data.
*
* The sender should not reuse the buffer till the flush() returns.
* The receiver should only access the data after the wait() returns.
*
* putWithSignal(): the sender initiates a data transfer and signals the receiver that data is ready to be consumed.
* This is an optimized version of a put() followed by a signal().
*
* These functions hide the complexity of synchronization between the two GPUs and the CPU proxy thread.
* Example:
*
* // sender GPU
* devConn.put(data1)
* // not OK to write to data1
* devConn.put(data2)
* // not OK to write to data1, data2
* devConn.put(data3) // receiver GPU
* // not OK to write to data1, data2, data3 // not OK to read data1, data2, data3
* devConn.signal() -------------------------------> devConn.wait()
* // not OK to write to data1, data2, data3 // OK to read data1, data2, data3
* devConn.flush()
* // OK to write to data1, data2, data3
*
*
* The two endpoints can concurrently use the same connection provided they are writing (puts) on different
* indices in the registered buffer.
**************************************************************************************************************/
// Device-side view of a connection: the connection id, its epoch counters and the fifo
// through which device threads hand work to the proxy thread. It can be built from a
// HostConnection on the host; the operations themselves are only compiled under __CUDACC__.
struct DeviceConnection
{
  DeviceConnection() = default;
  DeviceConnection(HostConnection& hostConn)
    : connectionId(hostConn.getId()), epoch(hostConn.getEpoch()), fifo(hostConn.getDeviceFifo())
  {
  }
  DeviceConnection(const DeviceConnection& other) = default;
  // Takes a const reference so that const objects can be assigned from as well.
  DeviceConnection& operator=(const DeviceConnection& other) = default;

#ifdef __CUDACC__
  // Pushes a data trigger: copy `size` bytes from `src` at srcOffset to `dst` at dstOffset.
  __forceinline__ __device__ void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset,
                                      uint64_t size)
  {
    fifo.push(ChannelTrigger(channelTriggerData, dst, dstOffset, src, srcOffset, size, connectionId).value);
  }

  // Same as above with one offset used on both sides.
  __forceinline__ __device__ void put(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size)
  {
    put(dst, offset, src, offset, size);
  }

  // Advances the local epoch and pushes a flag trigger.
  __forceinline__ __device__ void signal()
  {
    epochIncrement();
    fifo.push(ChannelTrigger(channelTriggerFlag, 0, 0, 0, 0, 1, connectionId).value);
  }

  // put() and signal() combined into a single trigger.
  __forceinline__ __device__ void putWithSignal(BufferHandle dst, uint64_t dstOffset, BufferHandle src,
                                                uint64_t srcOffset, uint64_t size)
  {
    epochIncrement();
    fifo.push(
      ChannelTrigger(channelTriggerData | channelTriggerFlag, dst, dstOffset, src, srcOffset, size, connectionId)
        .value);
  }

  __forceinline__ __device__ void putWithSignal(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size)
  {
    putWithSignal(dst, offset, src, offset, size);
  }

  // put(), signal() and flush() combined into a single trigger; spins like flush() does.
  __forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, uint64_t dstOffset, BufferHandle src,
                                                        uint64_t srcOffset, uint64_t size)
  {
    epochIncrement();
    uint64_t curFifoHead = fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag | channelTriggerSync, dst,
                                                    dstOffset, src, srcOffset, size, connectionId)
                                       .value);
    while (*(volatile uint64_t*)&fifo.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
           *(volatile uint64_t*)fifo.tailReplica <= curFifoHead)
      ;
  }

  __forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, BufferHandle src, uint64_t offset,
                                                        uint64_t size)
  {
    putWithSignalAndFlush(dst, offset, src, offset, size);
  }

  __forceinline__ __device__ void flush()
  {
    // Uses channelTriggerSync, the constant declared next to ChannelTrigger and already
    // used by putWithSignalAndFlush(), rather than a name this header never declares.
    uint64_t curFifoHead = fifo.push(ChannelTrigger(channelTriggerSync, 0, 0, 0, 0, 1, connectionId).value);
    // we need to wait for two conditions to be met to ensure the CPU is done flushing. (1) wait for the tail
    // to go pass by curFifoHead (this is safety net) and (2) wait for the work element value to change to 0.
    while (*(volatile uint64_t*)&fifo.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
           *(volatile uint64_t*)fifo.tailReplica <= curFifoHead)
      ;
  }

  __forceinline__ __device__ void wait()
  {
    epoch.wait();
  }

  __forceinline__ __device__ void epochIncrement()
  {
    epoch.epochIncrement();
  }
#endif // __CUDACC__

  int connectionId;
  ConnectionEpoch epoch;
  // this is a concurrent fifo which is multiple threads from the device
  // can produce for and the sole proxy thread consumes it.
  DeviceProxyFifo fifo;
};
// Convenience wrapper around DeviceConnection that fixes the destination to the remote
// buffer with index 0 and the source to the local buffer with index 0, so callers only
// pass offsets and sizes. Every operation forwards to the wrapped DeviceConnection.
struct SimpleDeviceConnection
{
  SimpleDeviceConnection() = default;
  SimpleDeviceConnection(HostConnection& hostConn)
    : devConn(hostConn), dst(hostConn.getRemoteBuffer(0)), src(hostConn.getLocalBuffer(0))
  {
  }
  SimpleDeviceConnection(const SimpleDeviceConnection& other) = default;
  SimpleDeviceConnection& operator=(SimpleDeviceConnection& other) = default;

#ifdef __CUDACC__
  __forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
  {
    devConn.put(dst, dstOffset, src, srcOffset, size);
  }

  // Same offset on both sides.
  __forceinline__ __device__ void put(uint64_t offset, uint64_t size)
  {
    put(offset, offset, size);
  }

  __forceinline__ __device__ void signal()
  {
    devConn.signal();
  }

  __forceinline__ __device__ void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
  {
    devConn.putWithSignal(dst, dstOffset, src, srcOffset, size);
  }

  // Same offset on both sides.
  __forceinline__ __device__ void putWithSignal(uint64_t offset, uint64_t size)
  {
    putWithSignal(offset, offset, size);
  }

  __forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
  {
    devConn.putWithSignalAndFlush(dst, dstOffset, src, srcOffset, size);
  }

  // Same offset on both sides.
  __forceinline__ __device__ void putWithSignalAndFlush(uint64_t offset, uint64_t size)
  {
    putWithSignalAndFlush(offset, offset, size);
  }

  __forceinline__ __device__ void flush()
  {
    devConn.flush();
  }

  __forceinline__ __device__ void wait()
  {
    devConn.wait();
  }

  __forceinline__ __device__ void epochIncrement()
  {
    devConn.epochIncrement();
  }
#endif // __CUDACC__

  DeviceConnection devConn;
  BufferHandle dst; // remote buffer 0
  BufferHandle src; // local buffer 0
};
/* Create a unique ID for communication. Only needs to be called by one process.
* Use with mscclppCommInitRankFromId().
* All processes need to provide the same ID to mscclppCommInitRankFromId().
@@ -433,120 +80,300 @@ struct SimpleDeviceConnection {
*/
std::unique_ptr<UniqueId> getUniqueId();
/* Transport Types */
enum class TransportType : uint8_t {
P2P = 0,
IB = 1,
// Identifies a transport. Enumerators take consecutive values starting at 0, and
// NumTransports, being last, equals the number of enumerators before it.
enum class Transport
{
  Unknown,
  CudaIpc,
  // One enumerator per IB transport slot, IB0..IB7.
  IB0,
  IB1,
  IB2,
  IB3,
  IB4,
  IB5,
  IB6,
  IB7,
  NumTransports
};
class Communicator {
public:
namespace detail {
// Width in bits of the bitset backing TransportFlags: one bit per Transport enumerator.
// The static_assert keeps it in step with the enum.
const size_t TransportFlagsSize = 10;
static_assert(static_cast<size_t>(Transport::NumTransports) == TransportFlagsSize,
              "TransportFlagsSize must match the number of transports");
typedef std::bitset<TransportFlagsSize> TransportFlagsBase;
} // namespace detail
/* Initialize the communicator. nranks processes with rank 0 to nranks-1 need to call this function.
*
* Inputs:
* nranks: number of ranks in the communicator
* ipPortPair: a string of the form "ip:port" that represents the address of the root process
* rank: rank of the calling process
*/
Communicator(int nranks, const char* ipPortPair, int rank);
/* Initialize the communicator from a given UniqueId. Same as mscclppCommInitRank() except that
* id is provided by the user by calling getUniqueId()
*
* Inputs:
* nranks: number of ranks in the communicator
* id: the unique ID to be used for communication
* rank: rank of the calling process
*/
Communicator(int nranks, UniqueId id, int rank);
class TransportFlags : private detail::TransportFlagsBase
{
public:
TransportFlags() = default;
TransportFlags(Transport transport) : detail::TransportFlagsBase(1 << static_cast<size_t>(transport))
{
}
bool has(Transport transport) const
{
return detail::TransportFlagsBase::test(static_cast<size_t>(transport));
}
bool none() const
{
return detail::TransportFlagsBase::none();
}
bool any() const
{
return detail::TransportFlagsBase::any();
}
bool all() const
{
return detail::TransportFlagsBase::all();
}
size_t count() const
{
return detail::TransportFlagsBase::count();
}
TransportFlags& operator|=(TransportFlags other)
{
detail::TransportFlagsBase::operator|=(other);
return *this;
}
TransportFlags operator|(TransportFlags other) const
{
return TransportFlags(*this) |= other;
}
TransportFlags operator|(Transport transport) const
{
return *this | TransportFlags(transport);
}
TransportFlags& operator&=(TransportFlags other)
{
detail::TransportFlagsBase::operator&=(other);
return *this;
}
TransportFlags operator&(TransportFlags other) const
{
return TransportFlags(*this) &= other;
}
TransportFlags operator&(Transport transport) const
{
return *this & TransportFlags(transport);
}
TransportFlags& operator^=(TransportFlags other)
{
detail::TransportFlagsBase::operator^=(other);
return *this;
}
TransportFlags operator^(TransportFlags other) const
{
return TransportFlags(*this) ^= other;
}
TransportFlags operator^(Transport transport) const
{
return *this ^ TransportFlags(transport);
}
TransportFlags operator~() const
{
return TransportFlags(*this).flip();
}
bool operator==(TransportFlags other) const
{
return detail::TransportFlagsBase::operator==(other);
}
bool operator!=(TransportFlags other) const
{
return detail::TransportFlagsBase::operator!=(other);
}
detail::TransportFlagsBase toBitset() const
{
return *this;
}
private:
TransportFlags(detail::TransportFlagsBase bitset) : detail::TransportFlagsBase(bitset)
{
}
};
// Union of two single transports, yielding a flag set.
inline TransportFlags operator|(Transport transport1, Transport transport2)
{
  const TransportFlags lhs(transport1);
  return lhs | transport2;
}
// Intersection of two single transports, yielding a flag set.
inline TransportFlags operator&(Transport transport1, Transport transport2)
{
  const TransportFlags lhs(transport1);
  return lhs & transport2;
}
// Symmetric difference of two single transports, yielding a flag set.
inline TransportFlags operator^(Transport transport1, Transport transport2)
{
  const TransportFlags lhs(transport1);
  return lhs ^ transport2;
}
// Predefined flag sets.
// Empty set.
const TransportFlags NoTransports = TransportFlags();
// IB0 through IB7.
const TransportFlags AllIBTransports = Transport::IB0 | Transport::IB1 | Transport::IB2 | Transport::IB3 |
Transport::IB4 | Transport::IB5 | Transport::IB6 | Transport::IB7;
// All IB transports plus CudaIpc. Transport::Unknown is not part of this set.
const TransportFlags AllTransports = AllIBTransports | Transport::CudaIpc;
// IB device helpers, implemented out of line.
// NOTE(review): presumably these map between the Transport::IB<n> enumerators and IB
// device names — confirm against the implementation.
int getIBDeviceCount();
std::string getIBDeviceName(Transport ibTransport);
Transport getIBTransportByDeviceName(const std::string& ibDeviceName);
class Communicator;
class Connection;
// Value-type handle to a registered memory region. Copies share one Impl.
class RegisteredMemory
{
  struct Impl;
  // A shared_ptr is used since RegisteredMemory is functionally immutable, although
  // internally some state is populated lazily.
  std::shared_ptr<Impl> pimpl;

public:
  RegisteredMemory() = default;
  RegisteredMemory(std::shared_ptr<Impl> pimpl);
  ~RegisteredMemory();

  // Accessors for the attributes of the region.
  void* data();
  size_t size();
  int rank();
  TransportFlags transports();

  // Conversion to and from a byte vector.
  std::vector<char> serialize();
  static RegisteredMemory deserialize(const std::vector<char>& data);

  friend class Connection;
  friend class Communicator;
};
// Abstract interface of a connection to a remote rank. Concrete transports derive from
// this class and are handled through std::shared_ptr<Connection>.
class Connection
{
public:
  // Virtual destructor: this is a polymorphic base, so destroying a derived connection
  // through a Connection pointer must run the derived destructor.
  virtual ~Connection() = default;

  // Write `size` bytes from `src` at srcOffset into `dst` at dstOffset.
  virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
                     uint64_t size) = 0;

  virtual void flush() = 0;

  virtual int remoteRank() = 0;

  virtual int tag() = 0;

  virtual Transport transport() = 0;

  virtual Transport remoteTransport() = 0;

protected:
  // Gives derived classes access to the private Impl of a RegisteredMemory
  // (Connection is a friend of RegisteredMemory).
  static std::shared_ptr<RegisteredMemory::Impl> getRegisteredMemoryImpl(RegisteredMemory&);
};
// Interface for objects that take part in a two-phase setup. Both hooks receive the
// bootstrap object and default to doing nothing, so implementers override only what
// they need.
struct Setuppable
{
  // Virtual destructor: objects are handled through std::shared_ptr<Setuppable>, so
  // destruction through the base type must reach the derived destructor.
  virtual ~Setuppable() = default;
  virtual void beginSetup(std::shared_ptr<BaseBootstrap>) {}
  virtual void endSetup(std::shared_ptr<BaseBootstrap>) {}
};
// Thin wrapper around std::shared_future<T> whose get() never blocks: it either returns
// the value immediately or throws.
template <typename T> class NonblockingFuture
{
  std::shared_future<T> future;

public:
  // Creates an empty wrapper with no shared state; ready() reports false for it.
  NonblockingFuture() = default;
  NonblockingFuture(std::shared_future<T>&& future) : future(std::move(future)) {}
  NonblockingFuture(const NonblockingFuture&) = default;

  // True when a value (or exception) is available right now.
  // valid() is checked first because calling wait_for() on a shared_future without
  // shared state (e.g. a default-constructed one) is undefined behaviour.
  bool ready() const
  {
    return future.valid() && future.wait_for(std::chrono::seconds(0)) == std::future_status::ready;
  }

  // Returns the stored value.
  // Throws std::runtime_error when the result is not ready yet (including the case of
  // an empty wrapper) instead of blocking.
  T get()
  {
    if (!ready())
      throw std::runtime_error("NonblockingFuture::get() called before ready");
    return future.get();
  }
};
// Entry point of the C++ API. All state lives behind a private Impl (pimpl).
class Communicator
{
public:
/* Initialize the communicator.
*
* Inputs:
* bootstrap: an implementation of BaseBootstrap that the communicator will use
*/
Communicator(std::shared_ptr<BaseBootstrap> bootstrap);
~Communicator();
/* Ring-based AllGather through the bootstrap socket.
*
* Inputs:
* data: data array to be gathered where `[r*size, (r+1)*size)` is the data for rank `r`
* size: data size per rank
*/
void bootstrapAllGather(void* data, int size);
/* A no-op function that is used to synchronize all processes via a bootstrap allgather*/
void bootstrapBarrier();
/* Return the bootstrapper held by this communicator. */
std::shared_ptr<BaseBootstrap> bootstrapper();
/* Register a region of GPU memory for use in this communicator.
*
* Inputs:
* ptr: base pointer to the memory
* size: size of the memory region in bytes
* transports: the set of transports the memory is registered for
*
* Returns: a handle to the buffer
*/
RegisteredMemory registerMemory(void* ptr, size_t size, TransportFlags transports);
// Exchange of a RegisteredMemory with `remoteRank` under `tag`. The receiving side gets a
// NonblockingFuture. NOTE(review): presumably the transfer is completed by setup() — confirm.
void sendMemoryOnSetup(RegisteredMemory memory, int remoteRank, int tag);
NonblockingFuture<RegisteredMemory> recvMemoryOnSetup(int remoteRank, int tag);
/* Connect to a remote rank. This function only prepares metadata for connection. The actual connection
* is made by a following call of mscclppConnectionSetup(). Note that this function is two-way and a connection
* from rank i to remote rank j needs to have a counterpart from rank j to rank i.
* Note that with IB, buffers are registered at a page level and if a buffer is spread through multiple pages
* and do not fully utilize all of them, IB's QP has to register for all involved pages. This potentially has
* security risks if the devConn's accesses are given to a malicious process.
*
* Inputs:
* remoteRank: the rank of the remote process
* tag: the tag of the connection. tag is copied into the corresponding mscclppDevConn_t, which can be
* used to identify the connection inside a GPU kernel.
* transportType: the type of transport to be used (mscclppTransportP2P or mscclppTransportIB)
* ibDev: the name of the IB device to be used. Expects a null for mscclppTransportP2P.
*/
std::shared_ptr<HostConnection> connect(int remoteRank, int tag, TransportType transportType, const char* ibDev = 0);
* is made by a following call of mscclppConnectionSetup(). Note that this function is two-way and a connection
* from rank i to remote rank j needs to have a counterpart from rank j to rank i.
* Note that with IB, buffers are registered at a page level and if a buffer is spread through multiple pages
* and do not fully utilize all of them, IB's QP has to register for all involved pages. This potentially has
* security risks if the devConn's accesses are given to a malicious process.
*
* Inputs:
* remoteRank: the rank of the remote process
* tag: the tag of the connection. tag is copied into the corresponding mscclppDevConn_t, which can be
* used to identify the connection inside a GPU kernel.
* transportType: the type of transport to be used (mscclppTransportP2P or mscclppTransportIB)
* ibDev: the name of the IB device to be used. Expects a null for mscclppTransportP2P.
*/
std::shared_ptr<Connection> connectOnSetup(int remoteRank, int tag, Transport transport);
/* Establish all connections created by mscclppConnect(). This function must be called after all mscclppConnect()
* calls are made. This function ensures that all remote ranks are ready to communicate when it returns.
*/
void connectionSetup();
/* Launch proxy thread(s). This function is supposed to be called before starting a kernel that uses DeviceConnection. */
void startProxying();
/* Add a custom Setuppable object to a list of objects to be setup later, when setup() is called. */
void addSetup(std::shared_ptr<Setuppable> setuppable);
/* Stop proxy thread(s). */
void stopProxying();
/* Return the rank of the calling process.
*
* Outputs:
* rank: the rank of the calling process
*/
int rank();
/* Return the number of ranks of the communicator.
*
* Outputs:
* size: the number of ranks of the communicator
*/
int size();
/* Setup all objects that have registered for setup. This includes any connections created by connect(). */
void setup();
struct Impl;
private:
std::unique_ptr<Impl> pimpl;
friend class HostConnection;
};
// Value a ProxyHandler returns for each trigger it is given.
enum class ProxyHandlerResult
{
  Continue,
  FlushFifoTailAndContinue,
  Stop,
};
class Proxy;
using ProxyHandler = std::function<ProxyHandlerResult(ProxyTrigger)>;
// Proxy built around a ProxyHandler callback and a HostProxyFifo; start()/stop() control
// it. All state lives behind a private Impl (pimpl).
class Proxy
{
public:
  Proxy(ProxyHandler handler);
  ~Proxy();

  void start();
  void stop();

  // Access to the host side of the fifo owned by this proxy.
  HostProxyFifo& fifo();

private:
  struct Impl;
  std::unique_ptr<Impl> pimpl;
};
} // namespace mscclpp
namespace std {
// Makes TransportFlags usable as a key in unordered containers by hashing the
// underlying bitset.
template <> struct hash<mscclpp::TransportFlags>
{
  size_t operator()(const mscclpp::TransportFlags& flags) const
  {
    hash<mscclpp::detail::TransportFlagsBase> bitsetHasher{};
    return bitsetHasher(flags.toBitset());
  }
};
} // namespace std
#endif // MSCCLPP_H_

View File

@@ -1,13 +1,19 @@
#ifndef MSCCLPPFIFO_HPP_
#define MSCCLPPFIFO_HPP_
#include <stdint.h>
#include <functional>
#include <memory>
#include <stdint.h>
namespace mscclpp {
struct alignas(16) ProxyTrigger {
// For every MSCCLPP_PROXY_FIFO_FLUSH_COUNTER, a flush of the tail to device memory is triggered.
// As long as MSCCLPP_PROXY_FIFO_SIZE is large enough, having a stale tail is not a problem.
#define MSCCLPP_PROXY_FIFO_SIZE 128
#define MSCCLPP_PROXY_FIFO_FLUSH_COUNTER 4
struct alignas(16) ProxyTrigger
{
uint64_t fst, snd;
};
@@ -24,7 +30,8 @@ struct alignas(16) ProxyTrigger {
* Why is duplicating the tail a good idea? The fifo is large enough and we do not need frequent updates
* for the tail as there is usually enough space for device threads to push their work into.
*/
struct DeviceProxyFifo {
struct DeviceProxyFifo
{
#ifdef __CUDACC__
__forceinline__ __device__ uint64_t push(ProxyTrigger trigger)
{
@@ -34,32 +41,31 @@ struct DeviceProxyFifo {
while (*(volatile uint64_t*)&this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0)
;
ProxyTrigger* triggerPtr = (ProxyTrigger*)&(this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE]);
asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr),
"l"(trigger.fst), "l"(trigger.snd));
asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr), "l"(trigger.fst), "l"(trigger.snd));
return curFifoHead;
}
#endif // __CUDACC__
ProxyTrigger* triggers; // Allocate on host via cudaHostAlloc. This space is used for pushing the workelements
uint64_t* tailReplica; // Allocated on device. proxyState->fifoTailHost is the true tail on host and pused
// occasionally to device
uint64_t* head; // Allocated on device. Only accessed by device
uint64_t* tailReplica; // Allocated on device. proxyState->fifoTailHost is the true tail on host and pused
// occasionally to device
uint64_t* head; // Allocated on device. Only accessed by device
};
class HostProxyFifo
{
public:
HostProxyFifo();
~HostProxyFifo();
void poll(ProxyTrigger *trigger);
void poll(ProxyTrigger* trigger);
void pop();
void flushTail(bool sync = false);
DeviceProxyFifo toDevice();
DeviceProxyFifo deviceFifo();
private:
struct Impl;

View File

@@ -59,8 +59,8 @@ struct mscclppProxyState
mscclppProxyRunState_t run;
int numaNodeToBind;
struct mscclppIbContext* ibContext; // For IB connection only
cudaStream_t p2pStream; // for P2P DMA engine only
mscclpp::IbCtx* ibContext; // For IB connection only
cudaStream_t p2pStream; // for P2P DMA engine only
struct mscclppProxyFifo fifo;
};

40
src/include/proxy.hpp Normal file
View File

@@ -0,0 +1,40 @@
#ifndef MSCCLPP_PROXY_HPP_
#define MSCCLPP_PROXY_HPP_
#include "mscclppfifo.hpp"
#include <functional>
#include <memory>
namespace mscclpp {
// Value a ProxyHandler returns for each trigger it is given; the proxy service
// thread inspects it after every handler call.
// NOTE(review): the per-enumerator descriptions below are inferred from the
// names; confirm against the result handling in Proxy::start().
enum class ProxyHandlerResult
{
// Presumably: keep polling the FIFO for further triggers.
Continue,
// Presumably: flush the FIFO tail before continuing to poll.
FlushFifoTailAndContinue,
// Presumably: end the service loop.
Stop,
};
class Proxy;
// Callback invoked by the proxy service thread with each trigger it reads from
// the FIFO; its return value is a ProxyHandlerResult.
using ProxyHandler = std::function<ProxyHandlerResult(ProxyTrigger)>;
// Owns a HostProxyFifo and a service thread that polls the FIFO and passes
// triggers to a user-supplied handler. Implementation details are hidden behind
// a pimpl.
class Proxy
{
public:
// Stores `handler` and `threadInit`. `threadInit` is called on the service
// thread at the start of its body, before it begins polling.
Proxy(ProxyHandler handler, std::function<void()> threadInit);
// Same as the two-argument constructor with a no-op `threadInit`.
Proxy(ProxyHandler handler);
// Calls stop() if the implementation object is still present.
~Proxy();
// Marks the proxy as running and launches the service thread.
void start();
// Clears the running flag and joins the service thread if it is joinable.
void stop();
// Returns a reference to the FIFO owned by this proxy.
HostProxyFifo& fifo();
private:
struct Impl;
std::unique_ptr<Impl> pimpl;
};
} // namespace mscclpp
#endif // MSCCLPP_PROXY_HPP_

View File

@@ -0,0 +1,55 @@
#ifndef MSCCLPP_REGISTERED_MEMORY_HPP_
#define MSCCLPP_REGISTERED_MEMORY_HPP_
#include "communicator.hpp"
#include "ib.hpp"
#include "mscclpp.h"
#include "mscclpp.hpp"
#include <cuda_runtime.h>
namespace mscclpp {
// Per-transport bookkeeping attached to a registered memory region. Exactly one
// of the two anonymous structs in the union is meaningful, selected by
// `transport`: the CUDA IPC fields for Transport::CudaIpc, the IB fields for
// the IB transports.
struct TransportInfo
{
  Transport transport;

  // TODO: rewrite this using std::variant or something
  // Set to true when the entry is created by registering local memory with an
  // IB context and to false when it is reconstructed from a serialized blob.
  // Default-initialized so entries that never assign it (CUDA IPC) do not carry
  // an indeterminate value when the struct is copied.
  bool ibLocal = false;
  union {
    struct {
      // IPC handle of the base allocation that contains the registered range.
      cudaIpcMemHandle_t cudaIpcBaseHandle;
      // Byte offset of the registered pointer from the base allocation.
      size_t cudaIpcOffsetFromBase;
    };
    struct {
      // Memory region object returned by the IB context's registerMr().
      const IbMr* ibMr;
      // MR description that is exchanged with peers via serialization.
      IbMrInfo ibMrInfo;
    };
  };
};
// Private state behind RegisteredMemory: the memory range, the rank and host
// hash it belongs to, and one TransportInfo per transport it is registered for.
struct RegisteredMemory::Impl
{
  void* data;
  size_t size;
  int rank;
  uint64_t hostHash;
  TransportFlags transports;
  std::vector<TransportInfo> transportInfos;

  // Registers memory for the requested transports (defined in registered_memory.cc).
  Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl);
  // Rebuilds the state from a serialized byte buffer (defined in registered_memory.cc).
  Impl(const std::vector<char>& data);

  // Returns the entry recorded for `transport`.
  // Throws std::runtime_error when no entry with that transport exists.
  TransportInfo& getTransportInfo(Transport transport)
  {
    for (auto it = transportInfos.begin(); it != transportInfos.end(); ++it) {
      if (it->transport != transport) {
        continue;
      }
      return *it;
    }
    throw std::runtime_error("Transport data not found");
  }
};
} // namespace mscclpp
#endif // MSCCLPP_REGISTERED_MEMORY_HPP_

View File

@@ -0,0 +1,52 @@
#ifndef MSCCLPP_REGISTERED_PTR_HPP_
#define MSCCLPP_REGISTERED_PTR_HPP_
namespace mscclpp {
template <typename T> class RegisteredPtr
{
RegisteredMemory memory;
size_t offset;
public:
RegisteredPtr(RegisteredMemory memory, size_t offset) : memory(memory), offset(offset)
{
}
RegisteredPtr(RegisteredMemory memory) : RegisteredPtr(memory, 0)
{
}
~RegisteredPtr()
{
}
RegisteredMemory memory()
{
return memory;
}
T* data()
{
return reinterpret_cast<T*>(memory.data());
}
size_t size()
{
return memory.size() / sizeof(T);
}
size_t offset()
{
return offset;
}
RegisteredPtr<T> operator+(size_t offset)
{
return RegisteredPtr<T>(memory, this->offset + offset);
}
// TODO: all other relevant overloads
};
} // namespace mscclpp
#endif // MSCCLPP_REGISTERED_PTR_HPP_

54
src/include/utils.hpp Normal file
View File

@@ -0,0 +1,54 @@
#ifndef MSCCLPP_UTILS_HPP_
#define MSCCLPP_UTILS_HPP_
#include <chrono>
#include <stdio.h>
namespace mscclpp {
// Stopwatch based on std::chrono::steady_clock with microsecond resolution.
struct Timer
{
  // Time point the measurement started from; set on construction and by reset().
  std::chrono::steady_clock::time_point start;

  // Starts timing immediately.
  Timer() : start(std::chrono::steady_clock::now())
  {
  }

  // Returns the number of microseconds since construction or the last reset().
  int64_t elapsed() const
  {
    auto end = std::chrono::steady_clock::now();
    return std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
  }

  // Restarts the measurement from now.
  void reset()
  {
    start = std::chrono::steady_clock::now();
  }

  // Prints "<name>: <elapsed> us" to stdout.
  void print(const char* name) const
  {
    // The chrono rep is not `long` on every platform, so cast explicitly to a
    // type that matches the format specifier.
    printf("%s: %lld us\n", name, static_cast<long long>(elapsed()));
  }
};
struct ScopedTimer
{
Timer timer;
const char* name;
ScopedTimer(const char* name) : name(name)
{
}
~ScopedTimer()
{
timer.print(name);
}
};
} // namespace mscclpp
#endif // MSCCLPP_UTILS_HPP_

View File

@@ -6,6 +6,7 @@
#if defined(MSCCLPP_USE_GDRCOPY)
#include "gdr.h"
#endif
#include "infiniband/verbs.h"
#include "mscclpp.h"
#include <map>
#include <sstream>
@@ -191,7 +192,7 @@ MSCCLPP_API mscclppResult_t mscclppCommDestroy(mscclppComm_t comm)
for (int i = 0; i < MSCCLPP_IB_MAX_DEVS; ++i) {
if (comm->ibContext[i]) {
MSCCLPPCHECK(mscclppIbContextDestroy(comm->ibContext[i]));
comm->ibContext[i].reset(nullptr);
}
}
@@ -326,7 +327,8 @@ struct mscclppHostP2PConn : mscclppHostConn
{
put(1, dstDataOffset, 1, srcDataOffset, dataSize);
}
void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset, uint64_t dataSize)
void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset,
uint64_t dataSize)
{
void* srcBuff = (void*)((char*)conn->bufferRegistrations[src].data + srcDataOffset);
void* dstBuff = (void*)((char*)conn->remoteBufferRegistrations[dst].data + dstDataOffset);
@@ -364,26 +366,20 @@ struct mscclppHostIBConn : mscclppHostConn
{
put(1, dstDataOffset, 1, srcDataOffset, dataSize);
}
void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset, uint64_t dataSize)
void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset,
uint64_t dataSize)
{
this->ibQp->stageSend(this->ibMrs[src], &this->remoteIbMrInfos[dst], (uint32_t)dataSize,
this->ibQp->stageSend(this->ibMrs[src], this->remoteIbMrInfos[dst], (uint32_t)dataSize,
/*wrId=*/0, /*srcOffset=*/srcDataOffset, /*dstOffset=*/dstDataOffset, /*signaled=*/false);
int ret = this->ibQp->postSend();
if (ret != 0) {
// Return value is errno.
WARN("data postSend failed: errno %d", ret);
}
this->ibQp->postSend();
npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)dataSize);
}
void signal()
{
// My local device flag is copied to the remote's proxy flag
this->ibQp->stageSend(this->ibMrs[0], &this->remoteIbMrInfos[0], sizeof(uint64_t),
this->ibQp->stageSend(this->ibMrs[0], this->remoteIbMrInfos[0], sizeof(uint64_t),
/*wrId=*/0, /*srcOffset=*/0, /*dstOffset=*/sizeof(uint64_t), /*signaled=*/true);
int ret = this->ibQp->postSend();
if (ret != 0) {
WARN("flag postSend failed: errno %d", ret);
}
this->ibQp->postSend();
npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_FLAG_ENTRY, (uint32_t)sizeof(uint64_t));
}
void wait()
@@ -399,15 +395,11 @@ struct mscclppHostIBConn : mscclppHostConn
continue;
}
for (int i = 0; i < wcNum; ++i) {
struct ibv_wc* wc = &this->ibQp->wcs[i];
struct ibv_wc* wc = (struct ibv_wc*)this->ibQp->getWc(i);
if (wc->status != IBV_WC_SUCCESS) {
WARN("wc status %d", wc->status);
continue;
}
if (wc->qp_num != this->ibQp->qp->qp_num) {
WARN("got wc of unknown qp_num %d", wc->qp_num);
continue;
}
if (wc->opcode == IBV_WC_RDMA_WRITE) {
isWaiting = false;
break;
@@ -418,12 +410,13 @@ struct mscclppHostIBConn : mscclppHostConn
}
mscclppConn* conn;
struct mscclppIbQp* ibQp;
std::vector<mscclppIbMr*> ibMrs;
std::vector<mscclppIbMrInfo> remoteIbMrInfos;
mscclpp::IbQp* ibQp;
std::vector<const mscclpp::IbMr*> ibMrs;
std::vector<mscclpp::IbMrInfo> remoteIbMrInfos;
};
MSCCLPP_API mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int remoteRank, int tag, mscclppTransport_t transportType, const char* ibDev)
MSCCLPP_API mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int remoteRank, int tag,
mscclppTransport_t transportType, const char* ibDev)
{
// save this processes numa binding and set it to the one closest to the device
// so that all the allocation are close to the device
@@ -458,7 +451,7 @@ MSCCLPP_API mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int
if (firstNullIdx == -1) {
firstNullIdx = i;
}
} else if (strncmp(comm->ibContext[i]->ctx->device->name, ibDev, IBV_SYSFS_NAME_MAX) == 0) {
} else if (strncmp(comm->ibContext[i]->getDevName().c_str(), ibDev, IBV_SYSFS_NAME_MAX) == 0) {
ibDevIdx = i;
break;
}
@@ -468,13 +461,10 @@ MSCCLPP_API mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int
if (ibDevIdx == -1) {
// Create a new context.
ibDevIdx = firstNullIdx;
if (mscclppIbContextCreate(&comm->ibContext[ibDevIdx], ibDev) != mscclppSuccess) {
WARN("Failed to create IB context");
return mscclppInternalError;
}
comm->ibContext[ibDevIdx].reset(new mscclpp::IbCtx(std::string(ibDev)));
}
// Set the ib context for this conn
conn->ibCtx = comm->ibContext[ibDevIdx];
conn->ibCtx = comm->ibContext[ibDevIdx].get();
} else if (transportType == mscclppTransportP2P) {
// do the rest of the initialization later
@@ -563,7 +553,8 @@ MSCCLPP_API mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int
MSCCLPPCHECK(setNumaState(curProcessState));
mscclppBufferHandle_t signalHandle = -1;
MSCCLPPCHECK(mscclppRegisterBufferForConnection(comm, connId, conn->devConn->localSignalEpochId, sizeof(mscclppDevConnSignalEpochId), &signalHandle));
MSCCLPPCHECK(mscclppRegisterBufferForConnection(comm, connId, conn->devConn->localSignalEpochId,
sizeof(mscclppDevConnSignalEpochId), &signalHandle));
if (signalHandle != 0) {
WARN("signal handle should be 0");
return mscclppInternalError;
@@ -592,7 +583,9 @@ MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, i
return mscclppSuccess;
}
MSCCLPP_API mscclppResult_t mscclppRegisterBufferForConnection(mscclppComm_t comm, int connIdx, void* localBuff, uint64_t buffSize, mscclppBufferHandle_t *handle) {
MSCCLPP_API mscclppResult_t mscclppRegisterBufferForConnection(mscclppComm_t comm, int connIdx, void* localBuff,
uint64_t buffSize, mscclppBufferHandle_t* handle)
{
if (connIdx >= comm->nConns) {
WARN("connIdx out of range");
return mscclppInvalidArgument;
@@ -609,35 +602,40 @@ MSCCLPP_API mscclppResult_t mscclppRegisterBufferForConnection(mscclppComm_t com
struct mscclppBufferRegistrationInfo
{
cudaIpcMemHandle_t cudaHandle;
mscclppIbMrInfo ibMrInfo;
mscclpp::IbMrInfo ibMrInfo;
uint64_t size;
};
struct connInfo
{
mscclppIbQpInfo infoQp;
mscclpp::IbQpInfo infoQp;
std::vector<mscclppBufferRegistrationInfo> bufferInfos;
struct header {
mscclppIbQpInfo infoQp;
struct header
{
mscclpp::IbQpInfo infoQp;
int numBufferInfos;
};
mscclppResult_t sendOverBootstrap(void* bootstrap, int remoteRank, int tag) {
mscclppResult_t sendOverBootstrap(void* bootstrap, int remoteRank, int tag)
{
header h;
h.infoQp = infoQp;
h.numBufferInfos = bufferInfos.size();
MSCCLPPCHECK(bootstrapSend(bootstrap, remoteRank, tag, &h, sizeof(header)));
MSCCLPPCHECK(bootstrapSend(bootstrap, remoteRank, tag, bufferInfos.data(), bufferInfos.size() * sizeof(mscclppBufferRegistrationInfo)));
MSCCLPPCHECK(bootstrapSend(bootstrap, remoteRank, tag, bufferInfos.data(),
bufferInfos.size() * sizeof(mscclppBufferRegistrationInfo)));
return mscclppSuccess;
}
mscclppResult_t recvOverBootstrap(void* bootstrap, int remoteRank, int tag) {
mscclppResult_t recvOverBootstrap(void* bootstrap, int remoteRank, int tag)
{
header h;
MSCCLPPCHECK(bootstrapRecv(bootstrap, remoteRank, tag, &h, sizeof(header)));
infoQp = h.infoQp;
bufferInfos.resize(h.numBufferInfos);
MSCCLPPCHECK(bootstrapRecv(bootstrap, remoteRank, tag, bufferInfos.data(), bufferInfos.size() * sizeof(mscclppBufferRegistrationInfo)));
MSCCLPPCHECK(bootstrapRecv(bootstrap, remoteRank, tag, bufferInfos.data(),
bufferInfos.size() * sizeof(mscclppBufferRegistrationInfo)));
return mscclppSuccess;
}
};
@@ -650,7 +648,7 @@ mscclppResult_t mscclppP2pConnectionSetupStart(struct connInfo* connInfo /*input
}
// Add all registered buffers
for (const auto &bufReg : conn->bufferRegistrations) {
for (const auto& bufReg : conn->bufferRegistrations) {
connInfo->bufferInfos.emplace_back();
CUDACHECK(cudaIpcGetMemHandle(&connInfo->bufferInfos.back().cudaHandle, bufReg.data));
connInfo->bufferInfos.back().size = bufReg.size;
@@ -672,7 +670,8 @@ mscclppResult_t mscclppP2pConnectionSetupEnd(struct connInfo* connInfo /*input*/
// Open all remote registered buffers
for (size_t i = 0; i < connInfo->bufferInfos.size(); i++) {
mscclppBufferRegistration newBufReg;
CUDACHECK(cudaIpcOpenMemHandle(&newBufReg.data, connInfo->bufferInfos[i].cudaHandle, cudaIpcMemLazyEnablePeerAccess));
CUDACHECK(
cudaIpcOpenMemHandle(&newBufReg.data, connInfo->bufferInfos[i].cudaHandle, cudaIpcMemLazyEnablePeerAccess));
newBufReg.size = connInfo->bufferInfos[i].size;
conn->remoteBufferRegistrations.push_back(newBufReg);
}
@@ -683,8 +682,8 @@ mscclppResult_t mscclppP2pConnectionSetupEnd(struct connInfo* connInfo /*input*/
}
conn->devConn->remoteSignalEpochId = (mscclppDevConnSignalEpochId*)conn->remoteBufferRegistrations[0].data;
// For backwards compatibility with the previous API that assumed one data buffer per connection, set the remote buffer
// to the first remote data buffer
// For backwards compatibility with the previous API that assumed one data buffer per connection, set the remote
// buffer to the first remote data buffer
if (conn->remoteBufferRegistrations.size() > 1) {
conn->devConn->remoteBuff = conn->remoteBufferRegistrations[1].data;
}
@@ -702,22 +701,20 @@ mscclppResult_t mscclppIbConnectionSetupStart(struct connInfo* connInfo /*output
devConn->remoteBuff = NULL;
devConn->remoteSignalEpochId = NULL;
struct mscclppIbContext* ibCtx = conn->ibCtx;
mscclpp::IbCtx* ibCtx = conn->ibCtx;
if (hostConn->ibQp == NULL) {
MSCCLPPCHECK(mscclppIbContextCreateQp(ibCtx, &hostConn->ibQp));
hostConn->ibQp = ibCtx->createQp();
}
// Add all registered buffers
for (const auto &bufReg : conn->bufferRegistrations) {
hostConn->ibMrs.emplace_back();
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, bufReg.data,
sizeof(struct mscclppDevConnSignalEpochId), &hostConn->ibMrs.back()));
for (const auto& bufReg : conn->bufferRegistrations) {
hostConn->ibMrs.emplace_back(ibCtx->registerMr(bufReg.data, sizeof(struct mscclppDevConnSignalEpochId)));
connInfo->bufferInfos.emplace_back();
connInfo->bufferInfos.back().ibMrInfo = hostConn->ibMrs.back()->info;
connInfo->bufferInfos.back().ibMrInfo = hostConn->ibMrs.back()->getInfo();
connInfo->bufferInfos.back().size = bufReg.size;
}
connInfo->infoQp = hostConn->ibQp->info;
connInfo->infoQp = hostConn->ibQp->getInfo();
return mscclppSuccess;
}
@@ -728,14 +725,8 @@ mscclppResult_t mscclppIbConnectionSetupEnd(struct connInfo* connInfo /*input*/,
return mscclppInternalError;
}
struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn;
if (hostConn->ibQp->rtr(&connInfo->infoQp) != 0) {
WARN("Failed to transition QP to RTR");
return mscclppInvalidUsage;
}
if (hostConn->ibQp->rts() != 0) {
WARN("Failed to transition QP to RTS");
return mscclppInvalidUsage;
}
hostConn->ibQp->rtr(connInfo->infoQp);
hostConn->ibQp->rts();
// No remote pointers to set with IB, so we just set the Mrs
@@ -764,7 +755,8 @@ MSCCLPP_API mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm)
MSCCLPPCHECK(mscclppIbConnectionSetupStart(&cInfo, conn));
}
// TODO: from saemal: do we possibly deadlock if there are too many outstanding sends?
// MSCCLPPCHECK(bootstrapSend(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag, &cInfo, sizeof(cInfo)));
// MSCCLPPCHECK(bootstrapSend(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag, &cInfo,
// sizeof(cInfo)));
MSCCLPPCHECK(cInfo.sendOverBootstrap(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag));
}
@@ -788,25 +780,25 @@ MSCCLPP_API mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm)
struct bufferInfo
{
cudaIpcMemHandle_t handleBuff;
mscclppIbMrInfo infoBuffMr;
mscclpp::IbMrInfo infoBuffMr;
};
MSCCLPP_API mscclppResult_t mscclppRegisterBuffer(mscclppComm_t comm, void* local_memory, size_t size,
mscclppRegisteredMemory* regMem)
{
std::vector<struct mscclppIbMr*> ibMrs;
std::vector<const mscclpp::IbMr*> ibMrs;
for (int i = 0; i < comm->nConns; ++i) {
struct mscclppConn* conn = &comm->conns[i];
struct bufferInfo bInfo;
struct mscclppIbMr* ibBuffMr;
const mscclpp::IbMr* ibBuffMr;
// TODO: (conn->transport & mscclppTransportP2P) to support both P2P and IB
if (conn->transport == mscclppTransportP2P) {
CUDACHECK(cudaIpcGetMemHandle(&bInfo.handleBuff, local_memory));
} else if (conn->transport == mscclppTransportIB) {
MSCCLPPCHECK(mscclppIbContextRegisterMr(conn->ibCtx, local_memory, size, &ibBuffMr));
bInfo.infoBuffMr = ibBuffMr->info;
ibMrs.push_back(ibBuffMr);
ibBuffMr = conn->ibCtx->registerMr(local_memory, size);
bInfo.infoBuffMr = ibBuffMr->getInfo();
ibMrs.emplace_back(ibBuffMr);
}
MSCCLPPCHECK(bootstrapSend(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag, &bInfo, sizeof(bInfo)));

View File

@@ -2,7 +2,7 @@
#include "checks.h"
#include "comm.h"
#include "debug.h"
#include "ib.h"
#include "ib.hpp"
#include "socket.h"
#include <emmintrin.h>

View File

@@ -1,8 +1,10 @@
#include "proxy.hpp"
#include "api.h"
#include "mscclpp.hpp"
#include "utils.h"
#include "api.h"
#include <thread>
#include "utils.hpp"
#include <atomic>
#include <thread>
namespace mscclpp {
@@ -10,30 +12,41 @@ const int ProxyStopCheckPeriod = 1000;
const int ProxyFlushPeriod = 4;
struct Proxy::Impl {
struct Proxy::Impl
{
ProxyHandler handler;
std::function<void()> threadInit;
HostProxyFifo fifo;
std::thread service;
std::atomic_bool running;
Impl(ProxyHandler handler) : handler(handler), running(false) {}
Impl(ProxyHandler handler, std::function<void()> threadInit) : handler(handler), threadInit(threadInit), running(false)
{
}
};
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler) {
pimpl = std::make_unique<Impl>(handler);
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler, std::function<void()> threadInit)
{
pimpl = std::make_unique<Impl>(handler, threadInit);
}
MSCCLPP_API_CPP Proxy::~Proxy() {
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler) : Proxy(handler, [] {})
{
}
MSCCLPP_API_CPP Proxy::~Proxy()
{
if (pimpl) {
stop();
}
}
MSCCLPP_API_CPP void Proxy::start() {
MSCCLPP_API_CPP void Proxy::start()
{
pimpl->running = true;
pimpl->service = std::thread([this] {
// from this point on, proxy thread will stay close to the device
// PROXYMSCCLPPCHECK(numaBind(pimpl->comm->devNumaNode)); // TODO: reenable this
pimpl->threadInit();
ProxyHandler handler = this->pimpl->handler;
HostProxyFifo& fifo = this->pimpl->fifo;
@@ -52,7 +65,7 @@ MSCCLPP_API_CPP void Proxy::start() {
// Poll to see if we are ready to send anything
fifo.poll(&trigger);
if (trigger.fst == 0) { // TODO: this check is a potential pitfall for custom triggers
continue; // there is one in progress
continue; // there is one in progress
}
ProxyHandlerResult result = handler(trigger);
@@ -83,14 +96,16 @@ MSCCLPP_API_CPP void Proxy::start() {
});
}
MSCCLPP_API_CPP void Proxy::stop() {
MSCCLPP_API_CPP void Proxy::stop()
{
pimpl->running = false;
if (pimpl->service.joinable()) {
pimpl->service.join();
}
}
MSCCLPP_API_CPP HostProxyFifo& Proxy::fifo() {
MSCCLPP_API_CPP HostProxyFifo& Proxy::fifo()
{
return pimpl->fifo;
}

164
src/registered_memory.cc Normal file
View File

@@ -0,0 +1,164 @@
#include "registered_memory.hpp"
#include "api.h"
#include "checks.hpp"
#include "utils.h"
#include <algorithm>
#include <cuda.h>
namespace mscclpp {
// Registers the local range [data, data + size) for every transport requested
// in `transports` and records one TransportInfo per transport.
//
// - CudaIpc: looks up the base of the allocation containing `data`, obtains an
//   IPC handle for that base, and stores the handle together with the byte
//   offset of `data` from the base.
// - IB0..IB7: registers the range with the matching IB context obtained from
//   `commImpl` and stores both the MR pointer and its info.
//
// `hostHash` is looked up in commImpl.rankToHash_ for `rank`; .at() throws if
// the rank is not present.
RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl)
  : data(data), size(size), rank(rank), hostHash(commImpl.rankToHash_.at(rank)), transports(transports)
{
  if (transports.has(Transport::CudaIpc)) {
    TransportInfo transportInfo;
    transportInfo.transport = Transport::CudaIpc;
    // Not meaningful for CUDA IPC, but set so the struct is never copied with
    // an indeterminate bool.
    transportInfo.ibLocal = false;
    cudaIpcMemHandle_t handle;
    void* baseDataPtr;
    size_t baseDataSize; // dummy
    CUTHROW(cuMemGetAddressRange((CUdeviceptr*)&baseDataPtr, &baseDataSize, (CUdeviceptr)data));
    CUDATHROW(cudaIpcGetMemHandle(&handle, baseDataPtr));
    // TODO: bug with offset of base?
    transportInfo.cudaIpcBaseHandle = handle;
    transportInfo.cudaIpcOffsetFromBase = (char*)data - (char*)baseDataPtr;
    this->transportInfos.push_back(transportInfo);
  }
  if ((transports & AllIBTransports).any()) {
    auto addIb = [&](Transport ibTransport) {
      TransportInfo transportInfo;
      transportInfo.transport = ibTransport;
      const IbMr* mr = commImpl.getIbContext(ibTransport)->registerMr(data, size);
      transportInfo.ibMr = mr;
      transportInfo.ibLocal = true;
      transportInfo.ibMrInfo = mr->getInfo();
      this->transportInfos.push_back(transportInfo);
      INFO(MSCCLPP_NET, "IB mr for address %p with size %zu is registered", data, size);
    };
    // Same order as the individual checks this loop replaces, so the entries
    // are appended to transportInfos in ascending IB index.
    const Transport ibTransports[] = {Transport::IB0, Transport::IB1, Transport::IB2, Transport::IB3,
                                      Transport::IB4, Transport::IB5, Transport::IB6, Transport::IB7};
    for (Transport ibTransport : ibTransports) {
      if (transports.has(ibTransport)) {
        addIb(ibTransport);
      }
    }
  }
}
// Wraps an existing implementation object; copies of RegisteredMemory share it
// through the shared_ptr.
MSCCLPP_API_CPP RegisteredMemory::RegisteredMemory(std::shared_ptr<Impl> pimpl) : pimpl(pimpl)
{
}
MSCCLPP_API_CPP RegisteredMemory::~RegisteredMemory() = default;
// Returns the data pointer stored in the implementation. For an object built by
// deserialize(), that pointer is only assigned when the CUDA IPC handle was
// opened (CudaIpc transport present and matching host hash).
MSCCLPP_API_CPP void* RegisteredMemory::data()
{
return pimpl->data;
}
// Returns the size stored in the implementation.
MSCCLPP_API_CPP size_t RegisteredMemory::size()
{
return pimpl->size;
}
// Returns the rank stored in the implementation.
MSCCLPP_API_CPP int RegisteredMemory::rank()
{
return pimpl->rank;
}
// Returns the set of transports stored in the implementation.
MSCCLPP_API_CPP TransportFlags RegisteredMemory::transports()
{
return pimpl->transports;
}
// Flattens this registered memory into a byte buffer that
// RegisteredMemory::deserialize() can parse.
//
// Layout, all fields as raw in-memory bytes:
//   size | rank | hostHash | transports | int8 transportCount |
//   per entry: transport, then either (cudaIpcBaseHandle, cudaIpcOffsetFromBase)
//   for CudaIpc or ibMrInfo for an IB transport.
//
// Throws std::runtime_error if there are more entries than fit in an int8_t or
// if an entry has a transport that is neither CudaIpc nor an IB transport.
MSCCLPP_API_CPP std::vector<char> RegisteredMemory::serialize()
{
  std::vector<char> bytes;
  // Appends the raw bytes of one field to the output buffer.
  auto append = [&bytes](const void* field, size_t length) {
    const char* first = static_cast<const char*>(field);
    bytes.insert(bytes.end(), first, first + length);
  };

  append(&pimpl->size, sizeof(pimpl->size));
  append(&pimpl->rank, sizeof(pimpl->rank));
  append(&pimpl->hostHash, sizeof(pimpl->hostHash));
  append(&pimpl->transports, sizeof(pimpl->transports));

  if (pimpl->transportInfos.size() > std::numeric_limits<int8_t>::max()) {
    throw std::runtime_error("Too many transport info entries");
  }
  int8_t transportCount = pimpl->transportInfos.size();
  append(&transportCount, sizeof(transportCount));

  for (auto& info : pimpl->transportInfos) {
    append(&info.transport, sizeof(info.transport));
    if (info.transport == Transport::CudaIpc) {
      append(&info.cudaIpcBaseHandle, sizeof(info.cudaIpcBaseHandle));
      append(&info.cudaIpcOffsetFromBase, sizeof(info.cudaIpcOffsetFromBase));
      continue;
    }
    if (!AllIBTransports.has(info.transport)) {
      throw std::runtime_error("Unknown transport");
    }
    append(&info.ibMrInfo, sizeof(info.ibMrInfo));
  }
  return bytes;
}
// Builds a RegisteredMemory from a byte buffer by constructing a new Impl from
// `data`; any exception thrown by that Impl constructor propagates to the caller.
MSCCLPP_API_CPP RegisteredMemory RegisteredMemory::deserialize(const std::vector<char>& data)
{
return RegisteredMemory(std::make_shared<Impl>(data));
}
// Rebuilds the state from a buffer produced by RegisteredMemory::serialize().
//
// Every field read is bounds-checked, so a truncated or malformed buffer raises
// std::runtime_error("Deserialization failed") instead of reading past the end.
// The same error is raised when bytes are left over after parsing, and
// "Unknown transport" when an entry is neither CudaIpc nor an IB transport.
//
// `data` starts as nullptr and is only set when the CudaIpc transport is
// present and the serialized host hash equals this process's host hash; in that
// case the IPC handle is opened and `data` points into the mapped allocation.
RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) : data(nullptr)
{
  auto it = serialization.begin();
  // Copies `length` bytes from the cursor into `dst` and advances the cursor,
  // refusing to read beyond the end of the buffer.
  auto readField = [&](void* dst, size_t length) {
    if (static_cast<size_t>(serialization.end() - it) < length) {
      throw std::runtime_error("Deserialization failed");
    }
    std::copy_n(it, length, static_cast<char*>(dst));
    it += length;
  };

  readField(&this->size, sizeof(this->size));
  readField(&this->rank, sizeof(this->rank));
  readField(&this->hostHash, sizeof(this->hostHash));
  readField(&this->transports, sizeof(this->transports));
  int8_t transportCount;
  readField(&transportCount, sizeof(transportCount));

  for (int i = 0; i < transportCount; ++i) {
    TransportInfo transportInfo;
    // Entries rebuilt from a serialized blob are never local IB registrations.
    transportInfo.ibLocal = false;
    readField(&transportInfo.transport, sizeof(transportInfo.transport));
    if (transportInfo.transport == Transport::CudaIpc) {
      readField(&transportInfo.cudaIpcBaseHandle, sizeof(transportInfo.cudaIpcBaseHandle));
      readField(&transportInfo.cudaIpcOffsetFromBase, sizeof(transportInfo.cudaIpcOffsetFromBase));
    } else if (AllIBTransports.has(transportInfo.transport)) {
      // The MR pointer is not part of the serialized form; keep it null rather
      // than indeterminate.
      transportInfo.ibMr = nullptr;
      readField(&transportInfo.ibMrInfo, sizeof(transportInfo.ibMrInfo));
    } else {
      throw std::runtime_error("Unknown transport");
    }
    this->transportInfos.push_back(transportInfo);
  }
  if (it != serialization.end()) {
    throw std::runtime_error("Deserialization failed");
  }

  if (transports.has(Transport::CudaIpc)) {
    uint64_t localHostHash = getHostHash();
    if (localHostHash == this->hostHash) {
      const TransportInfo& entry = getTransportInfo(Transport::CudaIpc);
      void* base;
      CUDATHROW(cudaIpcOpenMemHandle(&base, entry.cudaIpcBaseHandle, cudaIpcMemLazyEnablePeerAccess));
      data = static_cast<char*>(base) + entry.cudaIpcOffsetFromBase;
      INFO(MSCCLPP_P2P, "Opened CUDA IPC handle at pointer %p", data);
    }
  }
}
} // namespace mscclpp

View File

@@ -9,6 +9,7 @@
#include <numa.h>
#include <stdlib.h>
#include <string>
#include <memory>
// Get current Compute Capability
// int mscclppCudaCompCap() {
@@ -112,7 +113,7 @@ uint64_t getHash(const char* string, int n)
* This string can be overridden by using the MSCCLPP_HOSTID env var.
*/
#define HOSTID_FILE "/proc/sys/kernel/random/boot_id"
uint64_t getHostHash(void)
uint64_t computeHostHash(void)
{
char hostHash[1024];
char* hostId;
@@ -144,6 +145,12 @@ uint64_t getHostHash(void)
return getHash(hostHash, strlen(hostHash));
}
uint64_t getHostHash(void)
{
thread_local std::unique_ptr<uint64_t> hostHash = std::make_unique<uint64_t>(computeHostHash());
return *hostHash;
}
/* Generate a hash of the unique identifying string for this process
* that will be unique for both bare-metal and container instances
* Equivalent of a hash of;

10
tests/CMakeLists.txt Normal file
View File

@@ -0,0 +1,10 @@
# Adds one test executable built from a single source file and links it against
# the mscclpp library.
#
# MPI::MPI_CXX is linked only when USE_MPI_FOR_TESTS is ON, because the
# top-level CMakeLists.txt calls find_package(MPI) only in that case; linking
# the imported target unconditionally would fail when MPI was never searched for.
# NOTE(review): confirm that every test source builds without MPI when the
# option is OFF.
#
# Usage: mscclpp_add_test(<target-name> <source-file>)
function(mscclpp_add_test name source)
  add_executable(${name} ${source})
  target_link_libraries(${name} PRIVATE mscclpp)
  if(USE_MPI_FOR_TESTS)
    target_link_libraries(${name} PRIVATE MPI::MPI_CXX)
  endif()
endfunction()

mscclpp_add_test(bootstrap_test_cpp bootstrap_test_cpp.cc)
mscclpp_add_test(communicator_test_cpp communicator_test_cpp.cu)
mscclpp_add_test(allgather_test_cpp allgather_test_cpp.cu)

add_subdirectory(unittests)

View File

@@ -1,17 +1,18 @@
#include "mscclpp.h"
#include "mscclpp.hpp"
#include "channel.hpp"
#ifdef MSCCLPP_USE_MPI_FOR_TESTS
#include "mpi.h"
#endif // MSCCLPP_USE_MPI_FOR_TESTS
#include <algorithm>
#include <cassert>
#include <iostream>
#include <stdio.h>
#include <stdlib.h>
#include <string>
#include <unistd.h>
#include <unordered_map>
#include <cassert>
#include <algorithm>
static int nranksPerNode = 8;
@@ -48,29 +49,30 @@ static double getTime(void)
return (tspec.tv_nsec / 1.0e9) + tspec.tv_sec;
}
__constant__ mscclpp::SimpleDeviceConnection constDevConns[16];
__constant__ mscclpp::channel::SimpleDeviceChannel constDevChans[16];
__device__ void allgather0(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int remoteRank, size_t nelemsPerGPU)
__device__ void allgather0(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int world_size, int remoteRank,
size_t nelemsPerGPU)
{
// this allgather is really simple and implemented as an alltoall
// this thread's role is a sender role
// put your data asynchronously
if ((threadIdx.x % 32) == 0)
devConn.putWithSignal(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
devChan.putWithSignal(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
// make sure everyone is put their data before some thread randomly blocks everyone else in signal
__syncthreads();
// push with flag and sync to make sure the data is received
if ((threadIdx.x % 32) == 0)
devConn.flush();
devChan.flush();
// this thread's role is a receiver role. wait on the semaphore to make sure the data is ready
if ((threadIdx.x % 32) == 0)
devConn.wait();
devChan.wait();
}
__device__ void localAllGather(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
uint64_t offset, uint64_t size)
__device__ void localAllGather(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int world_size, int nranksPerNode,
int remoteRank, uint64_t offset, uint64_t size)
{
// this allgather algorithm works as follows:
// Step 1: GPU rank i sends data to GPU rank (i+1) % nranksPerNode
@@ -82,26 +84,29 @@ __device__ void localAllGather(mscclpp::SimpleDeviceConnection devConn, int rank
if ((remoteRank % nranksPerNode) == ((rank + i) % nranksPerNode)) {
// put your data to GPU (rank+i) % nranksPerNode and signal in one call
if ((threadIdx.x % 32) == 0)
devConn.putWithSignalAndFlush(offset, size);
devChan.putWithSignal(offset, size);
}
// wait for the data from GPU (rank-i) % nranksPerNode to arrive
if ((remoteRank % nranksPerNode) == ((rank - i + nranksPerNode) % nranksPerNode)) {
if ((threadIdx.x % 32) == 0)
devConn.wait();
devChan.wait();
}
asm volatile("bar.sync %0, %1;" ::"r"(11), "r"((nranksPerNode - 1) * 32) : "memory");
}
}
__device__ void allgather1(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
size_t nelemsPerGPU)
__device__ void allgather1(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int world_size, int nranksPerNode,
int remoteRank, size_t nelemsPerGPU)
{
localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int),
localAllGather(devChan, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int),
nelemsPerGPU * sizeof(int));
if (remoteRank / nranksPerNode == rank / nranksPerNode)
if ((threadIdx.x % 32) == 0)
devChan.flush();
}
__device__ void allgather2(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
size_t nelemsPerGPU)
__device__ void allgather2(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int world_size, int nranksPerNode,
int remoteRank, size_t nelemsPerGPU)
{
// this allgather is a pipelined and hierarchical one and only works for two nodes
// it is implemented as follows:
@@ -118,17 +123,17 @@ __device__ void allgather2(mscclpp::SimpleDeviceConnection devConn, int rank, in
// Step 1
// local allgather
if (remoteRank / nranksPerNode == rank / nranksPerNode) {
localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int),
localAllGather(devChan, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int),
nelemsPerGPU * sizeof(int));
}
// cross-node exchange
if (remoteRank % nranksPerNode == rank % nranksPerNode) {
// opposite side
if ((threadIdx.x % 32) == 0)
devConn.putWithSignalAndFlush(rank * nelemsPerGPU * sizeof(int),
devChan.putWithSignal(rank * nelemsPerGPU * sizeof(int),
(nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int));
if ((threadIdx.x % 32) == 0)
devConn.wait();
devChan.wait();
}
__syncthreads();
@@ -137,7 +142,7 @@ __device__ void allgather2(mscclpp::SimpleDeviceConnection devConn, int rank, in
// local allgather
int otherNghr = (rank + nranksPerNode) % world_size;
if (remoteRank / nranksPerNode == rank / nranksPerNode) {
localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank, otherNghr * nelemsPerGPU * sizeof(int),
localAllGather(devChan, rank, world_size, nranksPerNode, remoteRank, otherNghr * nelemsPerGPU * sizeof(int),
(nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int));
}
@@ -145,11 +150,11 @@ __device__ void allgather2(mscclpp::SimpleDeviceConnection devConn, int rank, in
if (remoteRank % nranksPerNode == rank % nranksPerNode) {
// opposite side
if ((threadIdx.x % 32) == 0)
devConn.putWithSignalAndFlush((rank * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) *
devChan.putWithSignal((rank * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) *
sizeof(int),
nelemsPerGPU / pipelineSize * sizeof(int));
if ((threadIdx.x % 32) == 0)
devConn.wait();
devChan.wait();
}
__syncthreads();
@@ -157,26 +162,31 @@ __device__ void allgather2(mscclpp::SimpleDeviceConnection devConn, int rank, in
// Step 3
// local allgather
if (remoteRank / nranksPerNode == rank / nranksPerNode) {
localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank,
localAllGather(devChan, rank, world_size, nranksPerNode, remoteRank,
(otherNghr * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * sizeof(int),
nelemsPerGPU / pipelineSize * sizeof(int));
}
if (remoteRank / nranksPerNode == rank / nranksPerNode || remoteRank % nranksPerNode == rank % nranksPerNode) {
if ((threadIdx.x % 32) == 0)
devChan.flush();
}
}
__global__ void kernel(int rank, int world_size, int nranksPerNode, size_t nelemsPerGPU, int kernel)
{
// find the mapping between remoteRank and devConns
// find the mapping between remoteRank and devChans
int warpId = threadIdx.x / 32;
int remoteRank = (warpId < rank) ? warpId : warpId + 1;
// Each warp is responsible for one of the remote ranks
mscclpp::SimpleDeviceConnection devConn = constDevConns[warpId];
mscclpp::channel::SimpleDeviceChannel devChan = constDevChans[warpId];
if (kernel == 0)
allgather0(devConn, rank, world_size, remoteRank, nelemsPerGPU);
allgather0(devChan, rank, world_size, remoteRank, nelemsPerGPU);
else if (kernel == 1)
allgather1(devConn, rank, world_size, nranksPerNode, remoteRank, nelemsPerGPU);
allgather1(devChan, rank, world_size, nranksPerNode, remoteRank, nelemsPerGPU);
else if (kernel == 2)
allgather2(devConn, rank, world_size, nranksPerNode, remoteRank, nelemsPerGPU);
allgather2(devChan, rank, world_size, nranksPerNode, remoteRank, nelemsPerGPU);
}
int rankToLocalRank(int rank)
@@ -216,40 +226,44 @@ void initializeAndAllocateAllGatherData(int rank, int world_size, size_t dataSiz
CUDACHECK(cudaMemcpy(*data_d, *data_h, dataSize, cudaMemcpyHostToDevice));
}
void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& comm, int* data_d, size_t dataSize)
void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& comm, mscclpp::channel::DeviceChannelService& channelService, int* data_d, size_t dataSize)
{
int thisNode = rankToNode(rank);
int cudaNum = rankToLocalRank(rank);
std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum);
std::vector<std::shared_ptr<mscclpp::HostConnection>> hostConns;
mscclpp::Transport ibTransport = mscclpp::getIBTransportByDeviceName(ibDevStr);
std::vector<mscclpp::channel::ChannelId> channelIds;
std::vector<mscclpp::RegisteredMemory> localMemories;
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteMemories;
for (int r = 0; r < world_size; ++r) {
if (r == rank)
continue;
mscclpp::TransportType transportType;
const char* ibDev = ibDevStr.c_str();
mscclpp::Transport transport;
if (rankToNode(r) == thisNode) {
ibDev = NULL;
transportType = mscclpp::TransportType::P2P;
transport = mscclpp::Transport::CudaIpc;
} else {
transportType = mscclpp::TransportType::IB;
transport = ibTransport;
}
// Connect with all other ranks
auto hostConn = comm.connect(r, 0, transportType, ibDev);
hostConn->registerBuffer(data_d, dataSize);
hostConns.push_back(hostConn);
channelIds.push_back(channelService.addChannel(comm.connectOnSetup(r, 0, transport)));
auto memory = comm.registerMemory(data_d, dataSize, mscclpp::Transport::CudaIpc | ibTransport);
localMemories.push_back(memory);
comm.sendMemoryOnSetup(memory, r, 0);
remoteMemories.push_back(comm.recvMemoryOnSetup(r, 0));
}
comm.connectionSetup();
comm.setup();
std::vector<mscclpp::SimpleDeviceConnection> devConns;
std::transform(hostConns.begin(), hostConns.end(), std::back_inserter(devConns),
[](std::shared_ptr<mscclpp::HostConnection>& hostConn) {
return mscclpp::SimpleDeviceConnection(*hostConn);
});
std::vector<mscclpp::channel::SimpleDeviceChannel> devChannels;
for (size_t i = 0; i < channelIds.size(); ++i) {
devChannels.push_back(mscclpp::channel::SimpleDeviceChannel(channelService.deviceChannel(channelIds[i]),
channelService.addMemory(remoteMemories[i].get()), channelService.addMemory(localMemories[i])));
}
assert(devConns.size() < sizeof(constDevConns) / sizeof(mscclpp::SimpleDeviceConnection));
CUDACHECK(cudaMemcpyToSymbol(constDevConns, devConns.data(), sizeof(mscclpp::SimpleDeviceConnection) * devConns.size() ));
assert(devChannels.size() < sizeof(constDevChans) / sizeof(mscclpp::channel::SimpleDeviceChannel));
CUDACHECK(
cudaMemcpyToSymbol(constDevChans, devChannels.data(), sizeof(mscclpp::channel::SimpleDeviceChannel) * devChannels.size()));
}
void printUsage(const char* prog, bool isMpi)
@@ -399,22 +413,25 @@ int main(int argc, const char* argv[])
}
size_t nelemsPerGPU = dataSize / sizeof(int) / world_size;
try{
try {
if (rank == 0)
printf("Initializing MSCCL++\n");
mscclpp::Communicator comm(world_size, ip_port, rank);
printf("Initializing MSCCL++\n");
auto bootstrapper = std::make_shared<mscclpp::Bootstrap>(rank, world_size);
bootstrapper->initialize(ip_port);
mscclpp::Communicator comm(bootstrapper);
mscclpp::channel::DeviceChannelService channelService(comm);
if (rank == 0)
printf("Initializing data for allgather test\n");
printf("Initializing data for allgather test\n");
initializeAndAllocateAllGatherData(rank, world_size, dataSize, nelemsPerGPU, &data_h, &data_d);
if (rank == 0)
printf("Setting up the connection in MSCCL++\n");
setupMscclppConnections(rank, world_size, comm, data_d, dataSize);
printf("Setting up the connection in MSCCL++\n");
setupMscclppConnections(rank, world_size, comm, channelService, data_d, dataSize);
if (rank == 0)
printf("Launching MSCCL++ proxy threads\n");
comm.startProxying();
channelService.startProxy();
if (rank == 0)
printf("Testing the correctness of AllGather implementation\n");
@@ -434,7 +451,7 @@ int main(int argc, const char* argv[])
}
int tmp[16];
// A simple barrier
comm.bootstrapAllGather(tmp, sizeof(int));
bootstrapper->allGather(tmp, sizeof(int));
if (rank == 0)
printf("Successfully checked the correctness\n");
@@ -443,12 +460,12 @@ int main(int argc, const char* argv[])
if (rank == 0)
printf("Running %d iterations of the kernel without CUDA graph\n", iterwithoutcudagraph);
CUDACHECK(cudaStreamSynchronize(stream));
comm.bootstrapAllGather(tmp, sizeof(int));
bootstrapper->allGather(tmp, sizeof(int));
for (int i = 0; i < iterwithoutcudagraph; ++i) {
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
}
CUDACHECK(cudaStreamSynchronize(stream));
comm.bootstrapAllGather(tmp, sizeof(int));
bootstrapper->allGather(tmp, sizeof(int));
// cudaGraph Capture
int cudagraphiter = 10;
@@ -466,7 +483,7 @@ int main(int argc, const char* argv[])
int cudagraphwarmup = 10;
if (rank == 0)
printf("Warming up %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphwarmup,
cudagraphiter);
cudagraphiter);
for (int i = 0; i < cudagraphwarmup; ++i) {
cudaGraphLaunch(instance, stream);
}
@@ -476,8 +493,8 @@ int main(int argc, const char* argv[])
int cudagraphlaunch = 10;
if (rank == 0)
printf("Running %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphlaunch,
cudagraphiter);
comm.bootstrapAllGather(tmp, sizeof(int));
cudagraphiter);
bootstrapper->allGather(tmp, sizeof(int));
double t0, t1, ms, time_in_us;
t0 = getTime();
for (int i = 0; i < cudagraphlaunch; ++i) {
@@ -489,12 +506,12 @@ int main(int argc, const char* argv[])
ms = (t1 - t0) * 1000.0;
time_in_us = ms * 1000. / (float)cudagraphlaunch / (float)cudagraphiter;
printf("Rank %d report: size %lu time: %f us/iter algBW %f GBps\n", rank, dataSize, time_in_us,
(double)(dataSize) / 1e9 / (time_in_us / 1e6));
comm.bootstrapAllGather(tmp, sizeof(int));
(double)(dataSize) / 1e9 / (time_in_us / 1e6));
bootstrapper->allGather(tmp, sizeof(int));
if (rank == 0)
printf("Stopping MSCCL++ proxy threads\n");
comm.stopProxying();
channelService.stopProxy();
} catch (std::exception& e) {
// todo: throw exceptions in the implementation and process them here

View File

@@ -1,11 +1,12 @@
#include "mscclpp.hpp"
#include <memory>
#include <cassert>
#include <iostream>
#include <memory>
#include <mpi.h>
void test_allgather(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap){
void test_allgather(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap)
{
std::vector<int> tmp(bootstrap->getNranks(), 0);
tmp[bootstrap->getRank()] = bootstrap->getRank() + 1;
bootstrap->allGather(tmp.data(), sizeof(int));
@@ -16,15 +17,17 @@ void test_allgather(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap){
std::cout << "AllGather test passed!" << std::endl;
}
void test_barrier(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap){
void test_barrier(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap)
{
bootstrap->barrier();
if (bootstrap->getRank() == 0)
std::cout << "Barrier test passed!" << std::endl;
}
void test_sendrecv(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap){
void test_sendrecv(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap)
{
for (int i = 0; i < bootstrap->getNranks(); i++) {
if (bootstrap->getRank() == 0)
if (bootstrap->getRank() == i)
continue;
int msg1 = (bootstrap->getRank() + 1) * 3;
int msg2 = (bootstrap->getRank() + 1) * 3 + 1;
@@ -35,7 +38,7 @@ void test_sendrecv(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap){
}
for (int i = 0; i < bootstrap->getNranks(); i++) {
if (i == bootstrap->getRank())
if (bootstrap->getRank() == i)
continue;
int msg1 = 0;
int msg2 = 0;
@@ -52,14 +55,16 @@ void test_sendrecv(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap){
std::cout << "Send/Recv test passed!" << std::endl;
}
void test_all(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap){
void test_all(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap)
{
test_allgather(bootstrap);
test_barrier(bootstrap);
// test_sendrecv(bootstrap);
test_sendrecv(bootstrap);
}
void test_mscclpp_bootstrap_with_id(int rank, int worldSize){
std::shared_ptr<mscclpp::Bootstrap> bootstrap(new mscclpp::Bootstrap(rank, worldSize));
void test_mscclpp_bootstrap_with_id(int rank, int worldSize)
{
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(rank, worldSize);
mscclpp::UniqueId id;
if (bootstrap->getRank() == 0)
id = bootstrap->createUniqueId();
@@ -71,7 +76,8 @@ void test_mscclpp_bootstrap_with_id(int rank, int worldSize){
std::cout << "--- MSCCLPP::Bootstrap test with unique id passed! ---" << std::endl;
}
void test_mscclpp_bootstrap_with_ip_port_pair(int rank, int worldSize, char* ipPortPiar){
void test_mscclpp_bootstrap_with_ip_port_pair(int rank, int worldSize, char* ipPortPiar)
{
std::shared_ptr<mscclpp::Bootstrap> bootstrap(new mscclpp::Bootstrap(rank, worldSize));
bootstrap->initialize(ipPortPiar);
@@ -80,47 +86,57 @@ void test_mscclpp_bootstrap_with_ip_port_pair(int rank, int worldSize, char* ipP
std::cout << "--- MSCCLPP::Bootstrap test with ip_port pair passed! ---" << std::endl;
}
class MPIBootstrap : public mscclpp::BaseBootstrap {
class MPIBootstrap : public mscclpp::BaseBootstrap
{
public:
MPIBootstrap() : BaseBootstrap() {}
int getRank() override {
MPIBootstrap() : BaseBootstrap()
{
}
int getRank() override
{
int rank;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
return rank;
}
int getNranks() override {
int getNranks() override
{
int worldSize;
MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
return worldSize;
}
void allGather(void *sendbuf, int size) override {
void allGather(void* sendbuf, int size) override
{
MPI_Allgather(MPI_IN_PLACE, 0, MPI_BYTE, sendbuf, size, MPI_BYTE, MPI_COMM_WORLD);
}
void barrier() override {
void barrier() override
{
MPI_Barrier(MPI_COMM_WORLD);
}
void send(void *sendbuf, int size, int dest, int tag) override {
void send(void* sendbuf, int size, int dest, int tag) override
{
MPI_Send(sendbuf, size, MPI_BYTE, dest, tag, MPI_COMM_WORLD);
}
void recv(void *recvbuf, int size, int source, int tag) override {
void recv(void* recvbuf, int size, int source, int tag) override
{
MPI_Recv(recvbuf, size, MPI_BYTE, source, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
}
};
void test_mpi_bootstrap(){
void test_mpi_bootstrap()
{
std::shared_ptr<mscclpp::BaseBootstrap> bootstrap(new MPIBootstrap());
test_all(bootstrap);
if (bootstrap->getRank() == 0)
std::cout << "--- MPI Bootstrap test passed! ---" << std::endl;
}
int main(int argc, char **argv)
int main(int argc, char** argv)
{
int rank, worldSize;
MPI_Init(&argc, &argv);
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
if (argc > 2){
if (argc > 2) {
if (rank == 0)
std::cout << "Usage: " << argv[0] << " [ip:port]" << std::endl;
MPI_Finalize();

View File

@@ -0,0 +1,272 @@
#include "mscclpp.hpp"
#include "epoch.hpp"
#include <cassert>
#include <cuda_runtime.h>
#include <iostream>
#include <memory>
#include <mpi.h>
#include <unordered_map>
// Evaluates a CUDA runtime call exactly once and throws std::runtime_error,
// carrying the cudaGetErrorString() text, when the result is not cudaSuccess.
// The do/while(false) wrapper lets the macro be used like a single statement.
#define CUDATHROW(cmd)                                                              \
  do {                                                                              \
    const cudaError_t cudaThrowStatus = cmd;                                        \
    if (cudaSuccess != cudaThrowStatus) {                                           \
      throw std::runtime_error(std::string("Cuda failure '") +                      \
                               cudaGetErrorString(cudaThrowStatus) + "'");          \
    }                                                                               \
  } while (false)
// Maps a node-local rank to one of the eight IB transports (IB0..IB7).
//
// Throws std::runtime_error when `localRank` has no entry in the table
// (negative, or 8 and above) instead of indexing past the end of the array.
mscclpp::Transport findIb(int localRank)
{
  mscclpp::Transport IBs[] = {mscclpp::Transport::IB0, mscclpp::Transport::IB1, mscclpp::Transport::IB2,
                              mscclpp::Transport::IB3, mscclpp::Transport::IB4, mscclpp::Transport::IB5,
                              mscclpp::Transport::IB6, mscclpp::Transport::IB7};
  const int numIbs = static_cast<int>(sizeof(IBs) / sizeof(IBs[0]));
  if (localRank < 0 || localRank >= numIbs) {
    throw std::runtime_error("findIb: local rank " + std::to_string(localRank) +
                             " has no IB transport (expected 0.." + std::to_string(numIbs - 1) + ")");
  }
  return IBs[localRank];
}
// Registers `devicePtr` (deviceBufferSize bytes) with the communicator for the
// CudaIpc transport and `myIbDevice`, then exchanges that registration with
// every other rank. On return `localMemory` holds this rank's registration and
// `remoteMemory[peer]` holds the one received from each peer.
void register_all_memories(mscclpp::Communicator& communicator, int rank, int worldSize, void* devicePtr,
                           size_t deviceBufferSize, mscclpp::Transport myIbDevice,
                           mscclpp::RegisteredMemory& localMemory,
                           std::unordered_map<int, mscclpp::RegisteredMemory>& remoteMemory)
{
  localMemory = communicator.registerMemory(devicePtr, deviceBufferSize, mscclpp::Transport::CudaIpc | myIbDevice);

  // Request the send/receive pair for each peer; the futures are read after setup().
  std::unordered_map<int, mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> pendingRemotes;
  for (int peer = 0; peer < worldSize; ++peer) {
    if (peer == rank) {
      continue;
    }
    communicator.sendMemoryOnSetup(localMemory, peer, 0);
    pendingRemotes[peer] = communicator.recvMemoryOnSetup(peer, 0);
  }

  communicator.setup();

  for (int peer = 0; peer < worldSize; ++peer) {
    if (peer == rank) {
      continue;
    }
    remoteMemory[peer] = pendingRemotes[peer].get();
  }
}
// Requests a connection to every other rank and stores it in `connections`
// keyed by peer rank. Peers whose rank / nRanksPerNode equals ours use the
// CudaIpc transport; all others use `myIbDevice`. Calls communicator.setup()
// once after all connections have been requested.
void make_connections(mscclpp::Communicator& communicator, int rank, int worldSize, int nRanksPerNode,
                      mscclpp::Transport myIbDevice,
                      std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections)
{
  const int myNode = rank / nRanksPerNode;
  for (int peer = 0; peer < worldSize; ++peer) {
    if (peer == rank) {
      continue;
    }
    const bool sameNode = (peer / nRanksPerNode == myNode);
    const mscclpp::Transport transport = sameNode ? mscclpp::Transport::CudaIpc : myIbDevice;
    connections[peer] = communicator.connectOnSetup(peer, 0, transport);
  }
  communicator.setup();
}
void write_remote(int rank, int worldSize, std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
std::unordered_map<int, mscclpp::RegisteredMemory>& remoteRegisteredMemories, mscclpp::RegisteredMemory& registeredMemory, int dataCountPerRank){
for (int i = 0; i < worldSize; i++) {
if (i != rank) {
auto& conn = connections.at(i);
auto& peerMemory = remoteRegisteredMemories.at(i);
conn->write(peerMemory, rank * dataCountPerRank * sizeof(int), registeredMemory, rank * dataCountPerRank*sizeof(int), dataCountPerRank*sizeof(int));
conn->flush();
}
}
}
// Fills each device buffer n in `devicePtr` with `dataCount` copies of the
// value rank + n * worldSize, then synchronizes the device.
void device_buffer_init(int rank, int worldSize, int dataCount, std::vector<int*>& devicePtr)
{
  const int numBuffers = (int)devicePtr.size();
  for (int bufIdx = 0; bufIdx < numBuffers; ++bufIdx) {
    // Every element of buffer bufIdx carries the same tag value.
    const int fillValue = rank + bufIdx * worldSize;
    std::vector<int> staging(dataCount, fillValue);
    CUDATHROW(cudaMemcpy(devicePtr[bufIdx], staging.data(), dataCount * sizeof(int), cudaMemcpyHostToDevice));
  }
  CUDATHROW(cudaDeviceSynchronize());
}
// Copies every device buffer back to the host and checks that, in buffer n,
// the i-th of `worldSize` equal slices contains only the value
// i + n * worldSize. Returns false at the first mismatch, true otherwise.
bool test_device_buffer_write_correctness(int worldSize, int dataCount, std::vector<int*>& devicePtr)
{
  const int numBuffers = (int)devicePtr.size();
  for (int bufIdx = 0; bufIdx < numBuffers; ++bufIdx) {
    std::vector<int> hostCopy(dataCount, 0);
    CUDATHROW(cudaMemcpy(hostCopy.data(), devicePtr[bufIdx], dataCount * sizeof(int), cudaMemcpyDeviceToHost));

    for (int srcRank = 0; srcRank < worldSize; ++srcRank) {
      const int expected = srcRank + bufIdx * worldSize;
      const int sliceBegin = srcRank * dataCount / worldSize;
      const int sliceEnd = (srcRank + 1) * dataCount / worldSize;
      for (int idx = sliceBegin; idx < sliceEnd; ++idx) {
        if (hostCopy[idx] != expected) {
          return false;
        }
      }
    }
  }
  return true;
}
void test_write(int rank, int worldSize, int deviceBufferSize, std::shared_ptr<mscclpp::BaseBootstrap> bootstrap, std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
std::vector<std::unordered_map<int, mscclpp::RegisteredMemory>>& remoteMemory, std::vector<mscclpp::RegisteredMemory>& localMemory, std::vector<int*>& devicePtr, int numBuffers){
assert((deviceBufferSize / sizeof(int)) % worldSize == 0);
size_t dataCount = deviceBufferSize / sizeof(int);
device_buffer_init(rank, worldSize, dataCount, devicePtr);
bootstrap->barrier();
if (bootstrap->getRank() == 0)
std::cout << "CUDA memory initialization passed" << std::endl;
for (int n = 0; n < numBuffers; n++){
write_remote(rank, worldSize, connections, remoteMemory[n], localMemory[n], dataCount / worldSize);
}
bootstrap->barrier();
if (bootstrap->getRank() == 0)
std::cout << "RDMA write for " << std::to_string(numBuffers) << " buffers passed" << std::endl;
// polling until it becomes ready
bool ready = false;
int niter = 0;
do {
ready = test_device_buffer_write_correctness(worldSize, dataCount, devicePtr);
niter++;
if (niter == 10000){
throw std::runtime_error("Polling is stuck.");
}
} while (!ready);
bootstrap->barrier();
if (bootstrap->getRank() == 0)
std::cout << "Polling for " << std::to_string(numBuffers) << " buffers passed" << std::endl;
}
// Kernel: thread t (for t < worldSize and t != rank) calls epochIncrement()
// on deviceEpochs[t]. All other threads do nothing.
__global__ void increament_epochs(mscclpp::DeviceEpoch* deviceEpochs, int rank, int worldSize)
{
  const int peer = threadIdx.x;
  if (peer == rank || peer >= worldSize) {
    return;
  }
  deviceEpochs[peer].epochIncrement();
}
// Kernel: thread t (for t < worldSize and t != rank) calls wait() on
// deviceEpochs[t]. All other threads do nothing.
__global__ void wait_epochs(mscclpp::DeviceEpoch* deviceEpochs, int rank, int worldSize)
{
  const int peer = threadIdx.x;
  if (peer == rank || peer >= worldSize) {
    return;
  }
  deviceEpochs[peer].wait();
}
void test_write_with_epochs(int rank, int worldSize, int deviceBufferSize, std::shared_ptr<mscclpp::BaseBootstrap> bootstrap, std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
std::vector<std::unordered_map<int, mscclpp::RegisteredMemory>>& remoteMemory, std::vector<mscclpp::RegisteredMemory>& localMemory, std::vector<int*>& devicePtr, std::unordered_map<int, std::shared_ptr<mscclpp::Epoch>> epochs, int numBuffers){
assert((deviceBufferSize / sizeof(int)) % worldSize == 0);
size_t dataCount = deviceBufferSize / sizeof(int);
device_buffer_init(rank, worldSize, dataCount, devicePtr);
bootstrap->barrier();
if (bootstrap->getRank() == 0)
std::cout << "CUDA memory initialization passed" << std::endl;
mscclpp::DeviceEpoch* deviceEpochs;
CUDATHROW(cudaMalloc(&deviceEpochs, sizeof(mscclpp::DeviceEpoch) * worldSize));
for (int i = 0; i < worldSize; i++){
if (i != rank){
mscclpp::DeviceEpoch deviceEpoch = epochs[i]->deviceEpoch();
CUDATHROW(cudaMemcpy(&deviceEpochs[i], &deviceEpoch, sizeof(mscclpp::DeviceEpoch), cudaMemcpyHostToDevice));
}
}
CUDATHROW(cudaDeviceSynchronize());
bootstrap->barrier();
if (bootstrap->getRank() == 0)
std::cout << "CUDA device epochs are created" << std::endl;
for (int n = 0; n < numBuffers; n++){
write_remote(rank, worldSize, connections, remoteMemory[n], localMemory[n], dataCount / worldSize);
}
increament_epochs<<<1, worldSize>>>(deviceEpochs, rank, worldSize);
CUDATHROW(cudaDeviceSynchronize());
for (int i = 0; i < worldSize; i++){
if (i != rank){
epochs[i]->signal();
}
}
wait_epochs<<<1, worldSize>>>(deviceEpochs, rank, worldSize);
CUDATHROW(cudaDeviceSynchronize());
if (!test_device_buffer_write_correctness(worldSize, dataCount, devicePtr)){
throw std::runtime_error("unexpected result.");
}
bootstrap->barrier();
if (bootstrap->getRank() == 0)
std::cout << "--- Testing writes with singal for " << std::to_string(numBuffers) << " buffers passed ---" << std::endl;
}
void test_communicator(int rank, int worldSize, int nranksPerNode)
{
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(rank, worldSize);
mscclpp::UniqueId id;
if (bootstrap->getRank() == 0)
id = bootstrap->createUniqueId();
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
bootstrap->initialize(id);
mscclpp::Communicator communicator(bootstrap);
if (bootstrap->getRank() == 0)
std::cout << "Communicator initialization passed" << std::endl;
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> connections;
auto myIbDevice = findIb(rank % nranksPerNode);
make_connections(communicator, rank, worldSize, nranksPerNode, myIbDevice, connections);
if (bootstrap->getRank() == 0)
std::cout << "Connection setup passed" << std::endl;
int numBuffers = 10;
std::vector<int*> devicePtr(numBuffers);
int deviceBufferSize = 1024*1024;
std::vector<mscclpp::RegisteredMemory> localMemory(numBuffers);
std::vector<std::unordered_map<int, mscclpp::RegisteredMemory>> remoteMemory(numBuffers);
for (int n = 0; n < numBuffers; n++) {
if (n % 100 == 0)
std::cout << "Registering memory for " << std::to_string(n) << " buffers" << std::endl;
CUDATHROW(cudaMalloc(&devicePtr[n], deviceBufferSize));
register_all_memories(communicator, rank, worldSize, devicePtr[n], deviceBufferSize, myIbDevice, localMemory[n], remoteMemory[n]);
}
bootstrap->barrier();
if (bootstrap->getRank() == 0)
std::cout << "Memory registration for " << std::to_string(numBuffers) << " buffers passed" << std::endl;
test_write(rank, worldSize, deviceBufferSize, bootstrap, connections, remoteMemory, localMemory, devicePtr, numBuffers);
if (bootstrap->getRank() == 0)
std::cout << "--- Testing vanialla writes passed ---" << std::endl;
std::unordered_map<int, std::shared_ptr<mscclpp::Epoch>> epochs;
for (auto entry : connections) {
auto& conn = entry.second;
epochs.insert({entry.first, std::make_shared<mscclpp::Epoch>(communicator, conn)});
}
communicator.setup();
bootstrap->barrier();
if (bootstrap->getRank() == 0)
std::cout << "Epochs are created" << std::endl;
test_write_with_epochs(rank, worldSize, deviceBufferSize, bootstrap, connections, remoteMemory, localMemory, devicePtr, epochs, numBuffers);
if (bootstrap->getRank() == 0)
std::cout << "--- MSCCLPP::Communicator tests passed! ---" << std::endl;
for (int n = 0; n < numBuffers; n++){
CUDATHROW(cudaFree(devicePtr[n]));
}
}
int main(int argc, char** argv)
{
int rank, worldSize;
MPI_Init(&argc, &argv);
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
MPI_Comm shmcomm;
MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &shmcomm);
int shmWorldSize;
MPI_Comm_size(shmcomm, &shmWorldSize);
int nranksPerNode = shmWorldSize;
MPI_Comm_free(&shmcomm);
test_communicator(rank, worldSize, nranksPerNode);
MPI_Finalize();
return 0;
}

View File

@@ -0,0 +1,2 @@
# Standalone InfiniBand test executable.
add_executable(ib_test ib_test.cc)
# Use the keyword signature: these libraries are only needed to build ib_test
# itself, and the keyword-less form has legacy transitive-link semantics.
target_link_libraries(ib_test
  PRIVATE
    mscclpp
    MPI::MPI_CXX
    CUDA::cudart
)

View File

@@ -1,7 +1,9 @@
#include "alloc.h"
#include "checks.h"
#include "ib.h"
#include <set>
#include "ib.hpp"
#include "infiniband/verbs.h"
#include "mscclpp.hpp"
#include <array>
#include <string>
// Measure current time in second.
@@ -24,8 +26,8 @@ int main(int argc, const char* argv[])
printf("Usage: %s <ip:port> <0(recv)/1(send)> <gpu id> <ib id>\n", argv[0]);
return 1;
}
const char* ip_port = argv[1];
int is_send = atoi(argv[2]);
const char* ipPortPair = argv[1];
int isSend = atoi(argv[2]);
int cudaDevId = atoi(argv[3]);
std::string ibDevName = "mlx5_ib" + std::string(argv[4]);
@@ -35,51 +37,40 @@ int main(int argc, const char* argv[])
int nelem = 1;
MSCCLPPCHECK(mscclppCudaCalloc(&data, nelem));
mscclppComm_t comm;
MSCCLPPCHECK(mscclppCommInitRank(&comm, 2, ip_port, is_send));
std::shared_ptr<mscclpp::Bootstrap> bootstrap(new mscclpp::Bootstrap(isSend, 2));
bootstrap->initialize(ipPortPair);
struct mscclppIbContext* ctx;
struct mscclppIbQp* qp;
struct mscclppIbMr* mr;
MSCCLPPCHECK(mscclppIbContextCreate(&ctx, ibDevName.c_str()));
MSCCLPPCHECK(mscclppIbContextCreateQp(ctx, &qp));
MSCCLPPCHECK(mscclppIbContextRegisterMr(ctx, data, sizeof(int) * nelem, &mr));
mscclpp::IbCtx ctx(ibDevName);
mscclpp::IbQp* qp = ctx.createQp();
const mscclpp::IbMr* mr = ctx.registerMr(data, sizeof(int) * nelem);
struct mscclppIbQpInfo* qpInfo;
MSCCLPPCHECK(mscclppCalloc(&qpInfo, 2));
qpInfo[is_send] = qp->info;
std::array<mscclpp::IbQpInfo, 2> qpInfo;
qpInfo[isSend] = qp->getInfo();
struct mscclppIbMrInfo* mrInfo;
MSCCLPPCHECK(mscclppCalloc(&mrInfo, 2));
mrInfo[is_send] = mr->info;
std::array<mscclpp::IbMrInfo, 2> mrInfo;
mrInfo[isSend] = mr->getInfo();
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, qpInfo, sizeof(struct mscclppIbQpInfo)));
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, mrInfo, sizeof(struct mscclppIbMrInfo)));
bootstrap->allGather(qpInfo.data(), sizeof(mscclpp::IbQpInfo));
bootstrap->allGather(mrInfo.data(), sizeof(mscclpp::IbMrInfo));
for (int i = 0; i < 2; ++i) {
if (i == is_send)
for (int i = 0; i < bootstrap->getNranks(); ++i) {
if (i == isSend)
continue;
qp->rtr(&qpInfo[i]);
qp->rtr(qpInfo[i]);
qp->rts();
break;
}
printf("connection succeed\n");
// A simple barrier
int* tmp;
MSCCLPPCHECK(mscclppCalloc(&tmp, 2));
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int)));
bootstrap->barrier();
if (is_send) {
if (isSend) {
int maxIter = 100000;
double start = getTime();
for (int iter = 0; iter < maxIter; ++iter) {
qp->stageSend(mr, &mrInfo[0], sizeof(int) * nelem, 0, 0, 0, true);
if (qp->postSend() != 0) {
WARN("postSend failed");
return 1;
}
qp->stageSend(mr, mrInfo[0], sizeof(int) * nelem, 0, 0, 0, true);
qp->postSend();
bool waiting = true;
while (waiting) {
int wcNum = qp->pollCq();
@@ -88,7 +79,7 @@ int main(int argc, const char* argv[])
return 1;
}
for (int i = 0; i < wcNum; ++i) {
struct ibv_wc* wc = &qp->wcs[i];
const struct ibv_wc* wc = reinterpret_cast<const struct ibv_wc*>(qp->getWc(i));
if (wc->status != IBV_WC_SUCCESS) {
WARN("wc status %d", wc->status);
return 1;
@@ -103,10 +94,7 @@ int main(int argc, const char* argv[])
}
// A simple barrier
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int)));
MSCCLPPCHECK(mscclppIbContextDestroy(ctx));
MSCCLPPCHECK(mscclppCommDestroy(comm));
bootstrap->barrier();
return 0;
}