mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-13 01:36:10 +00:00
Merge pull request #66 from microsoft/olli/api-extension
Olli/api extension
This commit is contained in:
32
CMakeLists.txt
Normal file
32
CMakeLists.txt
Normal file
@@ -0,0 +1,32 @@
|
||||
cmake_minimum_required(VERSION 3.26)
|
||||
|
||||
project(mscclpp LANGUAGES CUDA CXX)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CUDA_STANDARD 17)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/modules)
|
||||
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
find_package(IBVerbs REQUIRED)
|
||||
find_package(NUMA REQUIRED)
|
||||
find_package(GDRCopy)
|
||||
|
||||
option(USE_MPI_FOR_TESTS "Use MPI for tests" ON)
|
||||
if(USE_MPI_FOR_TESTS)
|
||||
find_package(MPI REQUIRED)
|
||||
add_definitions(-DMSCCLPP_USE_MPI_FOR_TESTS)
|
||||
endif()
|
||||
|
||||
include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
|
||||
|
||||
add_library(mscclpp SHARED)
|
||||
add_subdirectory(src) # This adds the sources to the mscclpp target
|
||||
target_include_directories(mscclpp PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/src/include)
|
||||
set_target_properties(mscclpp PROPERTIES LINKER_LANGUAGE CXX)
|
||||
target_link_libraries(mscclpp PRIVATE MSCCLPP::ibverbs MSCCLPP::numa CUDA::cudart CUDA::cuda_driver)
|
||||
if(GDRCOPY_FOUND)
|
||||
target_link_libraries(mscclpp PRIVATE MSCCLPP::gdrcopy)
|
||||
endif()
|
||||
|
||||
add_subdirectory(tests)
|
||||
9
Makefile
9
Makefile
@@ -61,7 +61,7 @@ endif
|
||||
|
||||
NVCUFLAGS := -ccbin $(CXX) $(NVCC_GENCODE) -std=c++11 --expt-extended-lambda -Xfatbin -compress-all
|
||||
# Use addprefix so that we can specify more than one path
|
||||
NVLDFLAGS := -L$(CUDA_LIB) -lcudart -lrt
|
||||
NVLDFLAGS := -L$(CUDA_LIB) -lcudart -lrt -lcuda
|
||||
|
||||
ifeq ($(DEBUG), 0)
|
||||
NVCUFLAGS += -O3
|
||||
@@ -120,7 +120,8 @@ LDFLAGS := $(NVLDFLAGS) $(GDRCOPY_LDFLAGS) -libverbs -lnuma
|
||||
|
||||
LIBSRCS := $(addprefix src/,debug.cc utils.cc init.cc proxy.cc ib.cc config.cc)
|
||||
LIBSRCS += $(addprefix src/bootstrap/,bootstrap.cc socket.cc)
|
||||
LIBSRCS += $(addprefix src/,communicator.cc fifo.cc host_connection.cc proxy_cpp.cc basic_proxy_handler.cc)
|
||||
LIBSRCS += $(addprefix src/,communicator.cc connection.cc registered_memory.cc)
|
||||
LIBSRCS += $(addprefix src/,epoch.cc proxy_cpp.cc fifo.cc channel.cc)
|
||||
ifneq ($(NPKIT), 0)
|
||||
LIBSRCS += $(addprefix src/misc/,npkit.cc)
|
||||
endif
|
||||
@@ -134,7 +135,7 @@ HEADERS := $(wildcard src/include/*.h)
|
||||
CPPSOURCES := $(shell find ./ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)' -not -path "./build/*" -not -path "./python/*")
|
||||
PYTHONCPPSOURCES := $(shell find ./python/src/ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)')
|
||||
|
||||
INCEXPORTS := mscclpp.h mscclppfifo.h mscclpp.hpp mscclppfifo.hpp
|
||||
INCEXPORTS := mscclpp.h mscclppfifo.h mscclpp.hpp mscclppfifo.hpp epoch.hpp
|
||||
INCTARGETS := $(INCEXPORTS:%=$(BUILDDIR)/$(INCDIR)/%)
|
||||
|
||||
LIBNAME := libmscclpp.so
|
||||
@@ -148,7 +149,7 @@ UTOBJTARGETS := $(UTOBJS:%=$(BUILDDIR)/$(OBJDIR)/%)
|
||||
UTBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(UTOBJS))
|
||||
|
||||
TESTSDIR := tests
|
||||
TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu allgather_test_cpp.cu bootstrap_test_cpp.cc)
|
||||
TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu communicator_test_cpp.cu bootstrap_test_cpp.cc allgather_test_cpp.cu)
|
||||
TESTSOBJS := $(patsubst %.cc,%.o,$(TESTSSRCS)) $(patsubst %.cu,%.o,$(TESTSSRCS))
|
||||
TESTSOBJTARGETS := $(TESTSOBJS:%=$(BUILDDIR)/$(OBJDIR)/%)
|
||||
TESTSBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(TESTSOBJS))
|
||||
|
||||
8
TODO.md
Normal file
8
TODO.md
Normal file
@@ -0,0 +1,8 @@
|
||||
# Core API extraction
|
||||
|
||||
- Add a test for host side Communicator/RegisteredMemory/Connection use.
|
||||
- Implement a standalone "epoch" synchronization construct that can be used as a component in custom proxies. epoch.hpp/cc has the beginnings of this.
|
||||
- Reimplement the "standard" proxy service + DeviceConnection on top of the new Communicator/RegisteredMemory/Connection core API. Remnants of the old code are in channel.hpp, basic_proxy_handler.hpp/cc and host_connection.hpp/cc. Probably need a manager class to wrap all of this.
|
||||
- Change the new IBConnection and Communicator to use the new C++ IbCtx and IbQp classes.
|
||||
- Implement IbQp::~IbQp()
|
||||
- Fix RegisteredMemory::Impl::Impl to get the IPC handle from the base pointer, not the derived pointer.
|
||||
41
cmake/modules/FindGDRCopy.cmake
Normal file
41
cmake/modules/FindGDRCopy.cmake
Normal file
@@ -0,0 +1,41 @@
|
||||
# Find the GDRCopy libraries
|
||||
#
|
||||
# The following variables are optionally searched for defaults
|
||||
# GDRCOPY_ROOT_DIR: Base directory where all GDRCopy components are found
|
||||
# GDRCOPY_INCLUDE_DIR: Directory where GDRCopy headers are found
|
||||
# GDRCOPY_LIB_DIR: Directory where GDRCopy libraries are found
|
||||
|
||||
# The following are set after configuration is done:
|
||||
# GDRCOPY_FOUND
|
||||
# GDRCOPY_INCLUDE_DIRS
|
||||
# GDRCOPY_LIBRARIES
|
||||
|
||||
# An imported target MSCCLPP::gdrcopy is created if the library is found.
|
||||
|
||||
find_path(GDRCOPY_INCLUDE_DIRS
|
||||
NAMES gdrapi.h
|
||||
HINTS
|
||||
${GDRCOPY_INCLUDE_DIR}
|
||||
${GDRCOPY_ROOT_DIR}
|
||||
${GDRCOPY_ROOT_DIR}/include)
|
||||
|
||||
find_library(GDRCOPY_LIBRARIES
|
||||
NAMES gdrapi
|
||||
HINTS
|
||||
${GDRCOPY_LIB_DIR}
|
||||
${GDRCOPY_ROOT_DIR}
|
||||
${GDRCOPY_ROOT_DIR}/lib)
|
||||
|
||||
include(FindPackageHandleStandardArgs)
|
||||
find_package_handle_standard_args(GDRCopy DEFAULT_MSG GDRCOPY_INCLUDE_DIRS GDRCOPY_LIBRARIES)
|
||||
# Hide the cache entries created by find_path()/find_library() above from the
# default cmake-gui/ccmake view.
mark_as_advanced(GDRCOPY_INCLUDE_DIRS GDRCOPY_LIBRARIES)

# Expose GDRCopy as the imported target MSCCLPP::gdrcopy so that consumers get
# both the library and its header search path through target_link_libraries().
if(GDRCOPY_FOUND)
  if(NOT TARGET MSCCLPP::gdrcopy)
    add_library(MSCCLPP::gdrcopy UNKNOWN IMPORTED)
  endif()
  # The include path must come from GDRCOPY_INCLUDE_DIRS (the find_path()
  # result). GDRCOPY_INCLUDE_DIR is only an optional user-supplied hint and is
  # empty unless the user set it explicitly.
  set_target_properties(MSCCLPP::gdrcopy PROPERTIES
    INTERFACE_INCLUDE_DIRECTORIES "${GDRCOPY_INCLUDE_DIRS}"
    IMPORTED_LINK_INTERFACE_LANGUAGES "C"
    IMPORTED_LOCATION "${GDRCOPY_LIBRARIES}")
endif()
|
||||
41
cmake/modules/FindIBVerbs.cmake
Normal file
41
cmake/modules/FindIBVerbs.cmake
Normal file
@@ -0,0 +1,41 @@
|
||||
# Find the IB Verbs libraries
|
||||
#
|
||||
# The following variables are optionally searched for defaults
|
||||
# IBVERBS_ROOT_DIR: Base directory where all ibverbs components are found
|
||||
# IBVERBS_INCLUDE_DIR: Directory where ibverbs headers are found
|
||||
# IBVERBS_LIB_DIR: Directory where ibverbs libraries are found
|
||||
|
||||
# The following are set after configuration is done:
|
||||
# IBVERBS_FOUND
|
||||
# IBVERBS_INCLUDE_DIRS
|
||||
# IBVERBS_LIBRARIES
|
||||
|
||||
# An imported target MSCCLPP::ibverbs is created if the library is found.
|
||||
|
||||
find_path(IBVERBS_INCLUDE_DIRS
|
||||
NAMES infiniband/verbs.h
|
||||
HINTS
|
||||
${IBVERBS_INCLUDE_DIR}
|
||||
${IBVERBS_ROOT_DIR}
|
||||
${IBVERBS_ROOT_DIR}/include)
|
||||
|
||||
find_library(IBVERBS_LIBRARIES
|
||||
NAMES ibverbs
|
||||
HINTS
|
||||
${IBVERBS_LIB_DIR}
|
||||
${IBVERBS_ROOT_DIR}
|
||||
${IBVERBS_ROOT_DIR}/lib)
|
||||
|
||||
include(FindPackageHandleStandardArgs)
|
||||
find_package_handle_standard_args(IBVerbs DEFAULT_MSG IBVERBS_INCLUDE_DIRS IBVERBS_LIBRARIES)
|
||||
# Hide the cache entries created by find_path()/find_library() above from the
# default cmake-gui/ccmake view.
mark_as_advanced(IBVERBS_INCLUDE_DIRS IBVERBS_LIBRARIES)

# Expose libibverbs as the imported target MSCCLPP::ibverbs so that consumers
# get both the library and its header search path through
# target_link_libraries().
if(IBVERBS_FOUND)
  if(NOT TARGET MSCCLPP::ibverbs)
    add_library(MSCCLPP::ibverbs UNKNOWN IMPORTED)
  endif()
  # The include path must come from IBVERBS_INCLUDE_DIRS (the find_path()
  # result). IBVERBS_INCLUDE_DIR is only an optional user-supplied hint and is
  # empty unless the user set it explicitly.
  set_target_properties(MSCCLPP::ibverbs PROPERTIES
    INTERFACE_INCLUDE_DIRECTORIES "${IBVERBS_INCLUDE_DIRS}"
    IMPORTED_LINK_INTERFACE_LANGUAGES "C"
    IMPORTED_LOCATION "${IBVERBS_LIBRARIES}")
endif()
|
||||
41
cmake/modules/FindNUMA.cmake
Normal file
41
cmake/modules/FindNUMA.cmake
Normal file
@@ -0,0 +1,41 @@
|
||||
# Find the numa libraries
|
||||
#
|
||||
# The following variables are optionally searched for defaults
|
||||
# NUMA_ROOT_DIR: Base directory where all numa components are found
|
||||
# NUMA_INCLUDE_DIR: Directory where numa headers are found
|
||||
# NUMA_LIB_DIR: Directory where numa libraries are found
|
||||
|
||||
# The following are set after configuration is done:
|
||||
# NUMA_FOUND
|
||||
# NUMA_INCLUDE_DIRS
|
||||
# NUMA_LIBRARIES
|
||||
|
||||
# An imported target MSCCLPP::numa is created if the library is found.
|
||||
|
||||
find_path(NUMA_INCLUDE_DIRS
|
||||
NAMES numa.h
|
||||
HINTS
|
||||
${NUMA_INCLUDE_DIR}
|
||||
${NUMA_ROOT_DIR}
|
||||
${NUMA_ROOT_DIR}/include)
|
||||
|
||||
find_library(NUMA_LIBRARIES
|
||||
NAMES numa
|
||||
HINTS
|
||||
${NUMA_LIB_DIR}
|
||||
${NUMA_ROOT_DIR}
|
||||
${NUMA_ROOT_DIR}/lib)
|
||||
|
||||
include(FindPackageHandleStandardArgs)
|
||||
find_package_handle_standard_args(NUMA DEFAULT_MSG NUMA_INCLUDE_DIRS NUMA_LIBRARIES)
|
||||
# Hide the cache entries created by find_path()/find_library() above from the
# default cmake-gui/ccmake view.
mark_as_advanced(NUMA_INCLUDE_DIRS NUMA_LIBRARIES)

# Expose libnuma as the imported target MSCCLPP::numa so that consumers get
# both the library and its header search path through target_link_libraries().
if(NUMA_FOUND)
  if(NOT TARGET MSCCLPP::numa)
    add_library(MSCCLPP::numa UNKNOWN IMPORTED)
  endif()
  # The include path must come from NUMA_INCLUDE_DIRS (the find_path() result).
  # NUMA_INCLUDE_DIR is only an optional user-supplied hint and is empty unless
  # the user set it explicitly.
  set_target_properties(MSCCLPP::numa PROPERTIES
    INTERFACE_INCLUDE_DIRECTORIES "${NUMA_INCLUDE_DIRS}"
    IMPORTED_LINK_INTERFACE_LANGUAGES "C"
    IMPORTED_LOCATION "${NUMA_LIBRARIES}")
endif()
|
||||
5
src/CMakeLists.txt
Normal file
5
src/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cc *.h)
|
||||
file(GLOB to_remove gdr.cc)
|
||||
list(REMOVE_ITEM SOURCES ${to_remove})
|
||||
|
||||
target_sources(mscclpp PRIVATE ${SOURCES})
|
||||
@@ -1,29 +0,0 @@
|
||||
#include "basic_proxy_handler.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
ProxyHandler makeBasicProxyHandler(Communicator::Impl &comm) {
|
||||
return [&comm](ProxyTrigger triggerRaw) {
|
||||
ChannelTrigger *trigger = reinterpret_cast<ChannelTrigger*>(&triggerRaw);
|
||||
HostConnection& conn = *comm.connections.at(trigger->fields.connId);
|
||||
|
||||
auto result = ProxyHandlerResult::Continue;
|
||||
|
||||
if (trigger->fields.type & mscclppData) {
|
||||
conn.put(trigger->fields.dstBufferHandle, trigger->fields.dstOffset, trigger->fields.srcBufferHandle, trigger->fields.srcOffset, trigger->fields.size);
|
||||
}
|
||||
|
||||
if (trigger->fields.type & mscclppFlag) {
|
||||
conn.signal();
|
||||
}
|
||||
|
||||
if (trigger->fields.type & mscclppSync) {
|
||||
conn.flush();
|
||||
result = ProxyHandlerResult::FlushFifoTailAndContinue;
|
||||
}
|
||||
|
||||
return result;
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
@@ -180,9 +180,8 @@ Bootstrap::Impl::~Impl()
|
||||
}
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::getRemoteAddresses(mscclppSocket* listenSock,
|
||||
std::vector<mscclppSocketAddress>& rankAddresses,
|
||||
std::vector<mscclppSocketAddress>& rankAddressesRoot, int& rank)
|
||||
void Bootstrap::Impl::getRemoteAddresses(mscclppSocket* listenSock, std::vector<mscclppSocketAddress>& rankAddresses,
|
||||
std::vector<mscclppSocketAddress>& rankAddressesRoot, int& rank)
|
||||
{
|
||||
mscclppSocket sock;
|
||||
ExtInfo info;
|
||||
@@ -211,7 +210,7 @@ void Bootstrap::Impl::getRemoteAddresses(mscclppSocket* listenSock,
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::sendHandleToPeer(int peer, const std::vector<mscclppSocketAddress>& rankAddresses,
|
||||
const std::vector<mscclppSocketAddress>& rankAddressesRoot)
|
||||
const std::vector<mscclppSocketAddress>& rankAddressesRoot)
|
||||
{
|
||||
mscclppSocket sock;
|
||||
int next = (peer + 1) % this->nRanks_;
|
||||
@@ -226,7 +225,8 @@ void Bootstrap::Impl::bootstrapCreateRoot()
|
||||
mscclppSocket listenSock;
|
||||
|
||||
// mscclppSocket* listenSock = new mscclppSocket(); // TODO(saemal) make this a shared ptr
|
||||
MSCCLPPTHROW(mscclppSocketInit(&listenSock, &uniqueId_.addr, uniqueId_.magic, mscclppSocketTypeBootstrap, nullptr, 0));
|
||||
MSCCLPPTHROW(
|
||||
mscclppSocketInit(&listenSock, &uniqueId_.addr, uniqueId_.magic, mscclppSocketTypeBootstrap, nullptr, 0));
|
||||
MSCCLPPTHROW(mscclppSocketListen(&listenSock));
|
||||
MSCCLPPTHROW(mscclppSocketGetAddr(&listenSock, &uniqueId_.addr));
|
||||
auto lambda = [this, listenSock]() { this->bootstrapRoot(listenSock); };
|
||||
|
||||
26
src/channel.cc
Normal file
26
src/channel.cc
Normal file
@@ -0,0 +1,26 @@
|
||||
#include "channel.hpp"
|
||||
#include "utils.h"
|
||||
#include "checks.hpp"
|
||||
#include "api.h"
|
||||
#include "debug.h"
|
||||
|
||||
namespace mscclpp {
|
||||
namespace channel {
|
||||
|
||||
MSCCLPP_API_CPP DeviceChannelService::DeviceChannelService(Communicator& communicator) : communicator_(communicator),
|
||||
proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, [&]() { bindThread(); }) {
|
||||
int cudaDevice;
|
||||
CUDATHROW(cudaGetDevice(&cudaDevice));
|
||||
MSCCLPPTHROW(getDeviceNumaNode(cudaDevice, &deviceNumaNode));
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void DeviceChannelService::bindThread()
|
||||
{
|
||||
if (deviceNumaNode >= 0) {
|
||||
MSCCLPPTHROW(numaBind(deviceNumaNode));
|
||||
INFO(MSCCLPP_INIT, "NUMA node of DeviceChannelService proxy thread is set to %d", deviceNumaNode);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace channel
|
||||
} // namespace mscclpp
|
||||
@@ -1,81 +1,152 @@
|
||||
#include "communicator.hpp"
|
||||
#include "host_connection.hpp"
|
||||
#include "comm.h"
|
||||
#include "basic_proxy_handler.hpp"
|
||||
#include <sstream>
|
||||
|
||||
#include "api.h"
|
||||
#include "checks.hpp"
|
||||
#include "comm.h"
|
||||
#include "communicator.hpp"
|
||||
#include "connection.hpp"
|
||||
#include "debug.h"
|
||||
#include "mscclpp.hpp"
|
||||
#include "registered_memory.hpp"
|
||||
#include "utils.h"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
Communicator::Impl::Impl() : comm(nullptr), proxy(makeBasicProxyHandler(*this)) {}
|
||||
Communicator::Impl::Impl(std::shared_ptr<BaseBootstrap> bootstrap) : bootstrap_(bootstrap)
|
||||
{
|
||||
rankToHash_.resize(bootstrap->getNranks());
|
||||
auto hostHash = getHostHash();
|
||||
INFO(MSCCLPP_INIT, "Host hash: %lx", hostHash);
|
||||
rankToHash_[bootstrap->getRank()] = hostHash;
|
||||
bootstrap->allGather(rankToHash_.data(), sizeof(uint64_t));
|
||||
}
|
||||
|
||||
Communicator::Impl::~Impl() {
|
||||
if (comm) {
|
||||
mscclppCommDestroy(comm);
|
||||
Communicator::Impl::~Impl()
|
||||
{
|
||||
ibContexts_.clear();
|
||||
}
|
||||
|
||||
IbCtx* Communicator::Impl::getIbContext(Transport ibTransport)
|
||||
{
|
||||
// Find IB context or create it
|
||||
auto it = ibContexts_.find(ibTransport);
|
||||
if (it == ibContexts_.end()) {
|
||||
auto ibDev = getIBDeviceName(ibTransport);
|
||||
ibContexts_[ibTransport] = std::make_unique<IbCtx>(ibDev);
|
||||
return ibContexts_[ibTransport].get();
|
||||
} else {
|
||||
return it->second.get();
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Communicator::~Communicator() = default;
|
||||
|
||||
mscclppTransport_t transportTypeToCStyle(TransportType type) {
|
||||
switch (type) {
|
||||
case TransportType::IB:
|
||||
return mscclppTransportIB;
|
||||
case TransportType::P2P:
|
||||
return mscclppTransportP2P;
|
||||
default:
|
||||
throw std::runtime_error("Unknown transport type");
|
||||
MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr<BaseBootstrap> bootstrap)
|
||||
: pimpl(std::make_unique<Impl>(bootstrap))
|
||||
{
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_ptr<BaseBootstrap> Communicator::bootstrapper()
|
||||
{
|
||||
return pimpl->bootstrap_;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportFlags transports)
|
||||
{
|
||||
return RegisteredMemory(
|
||||
std::make_shared<RegisteredMemory::Impl>(ptr, size, pimpl->bootstrap_->getRank(), transports, *pimpl));
|
||||
}
|
||||
|
||||
struct MemorySender : public Setuppable
|
||||
{
|
||||
MemorySender(RegisteredMemory memory, int remoteRank, int tag)
|
||||
: memory_(memory), remoteRank_(remoteRank), tag_(tag) {}
|
||||
|
||||
void beginSetup(std::shared_ptr<BaseBootstrap> bootstrap) override
|
||||
{
|
||||
bootstrap->send(memory_.serialize(), remoteRank_, tag_);
|
||||
}
|
||||
|
||||
RegisteredMemory memory_;
|
||||
int remoteRank_;
|
||||
int tag_;
|
||||
};
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::sendMemoryOnSetup(RegisteredMemory memory, int remoteRank, int tag)
|
||||
{
|
||||
addSetup(std::make_shared<MemorySender>(memory, remoteRank, tag));
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Communicator::Communicator(int nranks, const char* ipPortPair, int rank) : pimpl(std::make_unique<Impl>()) {
|
||||
mscclppCommInitRank(&pimpl->comm, nranks, ipPortPair, rank);
|
||||
struct MemoryReceiver : public Setuppable
|
||||
{
|
||||
MemoryReceiver(int remoteRank, int tag)
|
||||
: remoteRank_(remoteRank), tag_(tag) {}
|
||||
|
||||
void endSetup(std::shared_ptr<BaseBootstrap> bootstrap) override
|
||||
{
|
||||
std::vector<char> data;
|
||||
bootstrap->recv(data, remoteRank_, tag_);
|
||||
memoryPromise_.set_value(RegisteredMemory::deserialize(data));
|
||||
}
|
||||
|
||||
std::promise<RegisteredMemory> memoryPromise_;
|
||||
int remoteRank_;
|
||||
int tag_;
|
||||
};
|
||||
|
||||
MSCCLPP_API_CPP NonblockingFuture<RegisteredMemory> Communicator::recvMemoryOnSetup(int remoteRank, int tag)
|
||||
{
|
||||
auto memoryReceiver = std::make_shared<MemoryReceiver>(remoteRank, tag);
|
||||
addSetup(memoryReceiver);
|
||||
return NonblockingFuture<RegisteredMemory>(memoryReceiver->memoryPromise_.get_future());
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Communicator::Communicator(int nranks, UniqueId id, int rank) : pimpl(std::make_unique<Impl>()) {
|
||||
static_assert(sizeof(mscclppUniqueId) == sizeof(UniqueId), "UniqueId size mismatch");
|
||||
mscclppUniqueId *cstyle_id = reinterpret_cast<mscclppUniqueId*>(&id);
|
||||
mscclppCommInitRankFromId(&pimpl->comm, nranks, *cstyle_id, rank);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::bootstrapAllGather(void* data, int size) {
|
||||
mscclppBootstrapAllGather(pimpl->comm, data, size);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::bootstrapBarrier() {
|
||||
mscclppBootstrapBarrier(pimpl->comm);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_ptr<HostConnection> Communicator::connect(int remoteRank, int tag,
|
||||
TransportType transportType, const char* ibDev) {
|
||||
mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportTypeToCStyle(transportType), ibDev);
|
||||
auto connIdx = pimpl->connections.size();
|
||||
auto conn = std::make_shared<HostConnection>(std::make_unique<HostConnection::Impl>(this, &pimpl->comm->conns[connIdx]));
|
||||
pimpl->connections.push_back(conn);
|
||||
// Creates a connection object to `remoteRank` for the requested `transport`,
// records it in the communicator and queues it (via addSetup) so that its
// beginSetup()/endSetup() hooks run on the next Communicator::setup() call.
//
// - Transport::CudaIpc: only allowed when both ranks report the same host hash;
//   otherwise std::runtime_error is thrown.
// - Any transport contained in AllIBTransports: an IBConnection is created.
// - Anything else: std::runtime_error("Unsupported transport").
//
// Returns the newly created connection.
MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(int remoteRank, int tag, Transport transport)
{
  std::shared_ptr<ConnectionBase> conn;
  if (transport == Transport::CudaIpc) {
    // sanity check: make sure the IPC connection is being made within a node
    if (pimpl->rankToHash_[remoteRank] != pimpl->rankToHash_[pimpl->bootstrap_->getRank()]) {
      std::stringstream ss;
      // Print each rank in decimal followed by *its own* host hash in hex.
      // std::hex is sticky, so switch back to std::dec before the second rank.
      ss << "Cuda IPC connection can only be made within a node: " << remoteRank << "(" << std::hex
         << pimpl->rankToHash_[remoteRank] << ")"
         << " != " << std::dec << pimpl->bootstrap_->getRank() << "(" << std::hex
         << pimpl->rankToHash_[pimpl->bootstrap_->getRank()] << ")";
      throw std::runtime_error(ss.str());
    }
    auto cudaIpcConn = std::make_shared<CudaIpcConnection>(remoteRank, tag);
    conn = cudaIpcConn;
    INFO(MSCCLPP_P2P, "Cuda IPC connection between rank %d(%lx) and remoteRank %d(%lx) created",
         pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()], remoteRank,
         pimpl->rankToHash_[remoteRank]);
  } else if (AllIBTransports.has(transport)) {
    auto ibConn = std::make_shared<IBConnection>(remoteRank, tag, transport, *pimpl);
    conn = ibConn;
    INFO(MSCCLPP_NET, "IB connection between rank %d(%lx) via %s and remoteRank %d(%lx) created",
         pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()],
         getIBDeviceName(transport).c_str(), remoteRank, pimpl->rankToHash_[remoteRank]);
  } else {
    throw std::runtime_error("Unsupported transport");
  }
  // Keep the connection alive for the lifetime of the communicator and make
  // sure it takes part in the next setup round.
  pimpl->connections_.push_back(conn);
  addSetup(conn);
  return conn;
}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::connectionSetup() {
|
||||
mscclppConnectionSetup(pimpl->comm);
|
||||
MSCCLPP_API_CPP void Communicator::addSetup(std::shared_ptr<Setuppable> setuppable)
|
||||
{
|
||||
pimpl->toSetup_.push_back(setuppable);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::startProxying() {
|
||||
pimpl->proxy.start();
|
||||
MSCCLPP_API_CPP void Communicator::setup()
|
||||
{
|
||||
for (auto& setuppable : pimpl->toSetup_) {
|
||||
setuppable->beginSetup(pimpl->bootstrap_);
|
||||
}
|
||||
for (auto& setuppable : pimpl->toSetup_) {
|
||||
setuppable->endSetup(pimpl->bootstrap_);
|
||||
}
|
||||
pimpl->toSetup_.clear();
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::stopProxying() {
|
||||
pimpl->proxy.stop();
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP int Communicator::rank() {
|
||||
int result;
|
||||
mscclppCommRank(pimpl->comm, &result);
|
||||
return result;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP int Communicator::size() {
|
||||
int result;
|
||||
mscclppCommSize(pimpl->comm, &result);
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
} // namespace mscclpp
|
||||
|
||||
165
src/connection.cc
Normal file
165
src/connection.cc
Normal file
@@ -0,0 +1,165 @@
|
||||
#include <algorithm>
|
||||
#include "connection.hpp"
|
||||
#include "checks.hpp"
|
||||
#include "infiniband/verbs.h"
|
||||
#include "npkit/npkit.h"
|
||||
#include "registered_memory.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
void validateTransport(RegisteredMemory mem, Transport transport)
|
||||
{
|
||||
if (!mem.transports().has(transport)) {
|
||||
throw std::runtime_error("mem does not support transport");
|
||||
}
|
||||
}
|
||||
|
||||
// Connection
|
||||
|
||||
std::shared_ptr<RegisteredMemory::Impl> Connection::getRegisteredMemoryImpl(RegisteredMemory& mem)
|
||||
{
|
||||
return mem.pimpl;
|
||||
}
|
||||
|
||||
// ConnectionBase
|
||||
|
||||
ConnectionBase::ConnectionBase(int remoteRank, int tag) : remoteRank_(remoteRank), tag_(tag) {}
|
||||
|
||||
int ConnectionBase::remoteRank() { return remoteRank_; }
|
||||
|
||||
int ConnectionBase::tag() { return tag_; }
|
||||
|
||||
// CudaIpcConnection
|
||||
|
||||
CudaIpcConnection::CudaIpcConnection(int remoteRank, int tag) : ConnectionBase(remoteRank, tag)
|
||||
{
|
||||
CUDATHROW(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
|
||||
}
|
||||
|
||||
CudaIpcConnection::~CudaIpcConnection()
|
||||
{
|
||||
cudaStreamDestroy(stream);
|
||||
}
|
||||
|
||||
Transport CudaIpcConnection::transport()
|
||||
{
|
||||
return Transport::CudaIpc;
|
||||
}
|
||||
|
||||
Transport CudaIpcConnection::remoteTransport()
|
||||
{
|
||||
return Transport::CudaIpc;
|
||||
}
|
||||
|
||||
void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
|
||||
uint64_t size)
|
||||
{
|
||||
validateTransport(dst, remoteTransport());
|
||||
validateTransport(src, transport());
|
||||
|
||||
char* dstPtr = (char*)dst.data();
|
||||
char* srcPtr = (char*)src.data();
|
||||
|
||||
CUDATHROW(cudaMemcpyAsync(dstPtr + dstOffset, srcPtr + srcOffset, size, cudaMemcpyDeviceToDevice, stream));
|
||||
INFO(MSCCLPP_P2P, "CudaIpcConnection write: from %p to %p, size %lu", srcPtr + srcOffset, dstPtr + dstOffset, size);
|
||||
|
||||
// npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)size);
|
||||
}
|
||||
|
||||
void CudaIpcConnection::flush()
|
||||
{
|
||||
CUDATHROW(cudaStreamSynchronize(stream));
|
||||
// npkitCollectExitEvents(conn, NPKIT_EVENT_DMA_SEND_EXIT);
|
||||
}
|
||||
|
||||
// IBConnection
|
||||
|
||||
IBConnection::IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl)
|
||||
: ConnectionBase(remoteRank, tag), transport_(transport), remoteTransport_(Transport::Unknown), numSignaledSends(0)
|
||||
{
|
||||
qp = commImpl.getIbContext(transport)->createQp();
|
||||
}
|
||||
|
||||
Transport IBConnection::transport()
|
||||
{
|
||||
return transport_;
|
||||
}
|
||||
|
||||
Transport IBConnection::remoteTransport()
|
||||
{
|
||||
return remoteTransport_;
|
||||
}
|
||||
|
||||
void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
|
||||
uint64_t size)
|
||||
{
|
||||
validateTransport(dst, remoteTransport());
|
||||
validateTransport(src, transport());
|
||||
|
||||
auto dstTransportInfo = getRegisteredMemoryImpl(dst)->getTransportInfo(remoteTransport());
|
||||
if (dstTransportInfo.ibLocal) {
|
||||
throw std::runtime_error("dst is local, which is not supported");
|
||||
}
|
||||
auto srcTransportInfo = getRegisteredMemoryImpl(src)->getTransportInfo(transport());
|
||||
if (!srcTransportInfo.ibLocal) {
|
||||
throw std::runtime_error("src is remote, which is not supported");
|
||||
}
|
||||
|
||||
auto dstMrInfo = dstTransportInfo.ibMrInfo;
|
||||
auto srcMr = srcTransportInfo.ibMr;
|
||||
|
||||
qp->stageSend(srcMr, dstMrInfo, (uint32_t)size, /*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset,
|
||||
/*signaled=*/true);
|
||||
numSignaledSends++;
|
||||
qp->postSend();
|
||||
INFO(MSCCLPP_NET, "IBConnection write: from %p to %p, size %lu", (uint8_t*)srcMr->getBuff() + srcOffset, (uint8_t*)dstMrInfo.addr + dstOffset, size);
|
||||
// npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)size);
|
||||
}
|
||||
|
||||
void IBConnection::flush()
|
||||
{
|
||||
while (numSignaledSends) {
|
||||
int wcNum = qp->pollCq();
|
||||
if (wcNum < 0) {
|
||||
throw std::runtime_error("pollCq failed: error no " + std::to_string(errno));
|
||||
}
|
||||
for (int i = 0; i < wcNum; ++i) {
|
||||
const struct ibv_wc* wc = reinterpret_cast<const struct ibv_wc*>(qp->getWc(i));
|
||||
if (wc->status != IBV_WC_SUCCESS) {
|
||||
throw std::runtime_error("pollCq failed: status " + std::to_string(wc->status));
|
||||
}
|
||||
if (wc->opcode == IBV_WC_RDMA_WRITE) {
|
||||
numSignaledSends--;
|
||||
}
|
||||
}
|
||||
}
|
||||
// npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT);
|
||||
}
|
||||
|
||||
void IBConnection::beginSetup(std::shared_ptr<BaseBootstrap> bootstrap)
|
||||
{
|
||||
std::vector<char> ibQpTransport;
|
||||
std::copy_n(reinterpret_cast<char*>(&qp->getInfo()), sizeof(qp->getInfo()), std::back_inserter(ibQpTransport));
|
||||
std::copy_n(reinterpret_cast<char*>(&transport_), sizeof(transport_), std::back_inserter(ibQpTransport));
|
||||
|
||||
bootstrap->send(ibQpTransport.data(), ibQpTransport.size(), remoteRank(), tag());
|
||||
}
|
||||
|
||||
void IBConnection::endSetup(std::shared_ptr<BaseBootstrap> bootstrap)
|
||||
{
|
||||
std::vector<char> ibQpTransport(sizeof(IbQpInfo) + sizeof(Transport));
|
||||
bootstrap->recv(ibQpTransport.data(), ibQpTransport.size(), remoteRank(), tag());
|
||||
|
||||
IbQpInfo qpInfo;
|
||||
auto it = ibQpTransport.begin();
|
||||
std::copy_n(it, sizeof(qpInfo), reinterpret_cast<char*>(&qpInfo));
|
||||
it += sizeof(qpInfo);
|
||||
std::copy_n(it, sizeof(remoteTransport_), reinterpret_cast<char*>(&remoteTransport_));
|
||||
it += sizeof(qpInfo);
|
||||
|
||||
qp->rtr(qpInfo);
|
||||
qp->rts();
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
26
src/epoch.cc
Normal file
26
src/epoch.cc
Normal file
@@ -0,0 +1,26 @@
|
||||
#include "epoch.hpp"
|
||||
#include "checks.hpp"
|
||||
#include "alloc.h"
|
||||
#include "api.h"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
// Allocates the device-side epoch state (one element each for epochIds_ and
// expectedInboundEpochId_), registers the local epochIds_ allocation with the
// communicator for the connection's transport, and queues the exchange of
// registered-memory handles with the peer for the next setup round.
MSCCLPP_API_CPP Epoch::Epoch(Communicator& communicator, std::shared_ptr<Connection> connection) : connection_(connection) {
  MSCCLPPTHROW(mscclppCudaCalloc(&device_.epochIds_, 1));
  MSCCLPPTHROW(mscclppCudaCalloc(&device_.expectedInboundEpochId_, 1));

  // Register the whole pointed-to object. sizeof(device_.epochIds_) would be
  // the size of the pointer itself, not of the element allocated above.
  localEpochIdsRegMem_ = communicator.registerMemory(device_.epochIds_, sizeof(*device_.epochIds_), connection->transport());
  communicator.sendMemoryOnSetup(localEpochIdsRegMem_, connection->remoteRank(), connection->tag());
  remoteEpochIdsRegMem_ = communicator.recvMemoryOnSetup(connection->remoteRank(), connection->tag());
}
|
||||
|
||||
MSCCLPP_API_CPP Epoch::~Epoch() {
|
||||
mscclppCudaFree(device_.epochIds_);
|
||||
mscclppCudaFree(device_.expectedInboundEpochId_);
|
||||
}
|
||||
|
||||
// Copies the local `outbound_` field of EpochIds into the peer's
// `inboundReplica_` field through the underlying connection.
MSCCLPP_API_CPP void Epoch::signal() {
  // Transfer exactly one `outbound_` field. The previous size expression,
  // sizeof(device_.epochIds_), was the size of a pointer and matched the field
  // size only by coincidence.
  connection_->write(remoteEpochIdsRegMem_.get(), offsetof(EpochIds, inboundReplica_), localEpochIdsRegMem_,
                     offsetof(EpochIds, outbound_), sizeof(EpochIds::outbound_));
}
|
||||
|
||||
} // namespace mscclpp
|
||||
38
src/fifo.cc
38
src/fifo.cc
@@ -1,13 +1,15 @@
|
||||
#include "mscclppfifo.hpp"
|
||||
#include "alloc.h"
|
||||
#include "checks.hpp"
|
||||
#include "mscclppfifo.hpp"
|
||||
#include "api.h"
|
||||
#include <cuda_runtime.h>
|
||||
#include <stdexcept>
|
||||
#include <emmintrin.h>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
struct HostProxyFifo::Impl {
|
||||
struct HostProxyFifo::Impl
|
||||
{
|
||||
DeviceProxyFifo deviceFifo;
|
||||
|
||||
// allocated on the host. Only accessed by the host. This is a copy of the
|
||||
@@ -23,7 +25,8 @@ struct HostProxyFifo::Impl {
|
||||
cudaStream_t stream;
|
||||
};
|
||||
|
||||
HostProxyFifo::HostProxyFifo() {
|
||||
MSCCLPP_API_CPP HostProxyFifo::HostProxyFifo()
|
||||
{
|
||||
pimpl = std::make_unique<Impl>();
|
||||
MSCCLPPTHROW(mscclppCudaCalloc(&pimpl->deviceFifo.head, 1));
|
||||
MSCCLPPTHROW(mscclppCudaHostCalloc(&pimpl->deviceFifo.triggers, MSCCLPP_PROXY_FIFO_SIZE));
|
||||
@@ -32,35 +35,40 @@ HostProxyFifo::HostProxyFifo() {
|
||||
pimpl->hostTail = 0;
|
||||
}
|
||||
|
||||
HostProxyFifo::~HostProxyFifo() {
|
||||
MSCCLPPTHROW(mscclppCudaFree(pimpl->deviceFifo.head));
|
||||
MSCCLPPTHROW(mscclppCudaHostFree(pimpl->deviceFifo.triggers));
|
||||
MSCCLPPTHROW(mscclppCudaFree(pimpl->deviceFifo.tailReplica));
|
||||
CUDATHROW(cudaStreamDestroy(pimpl->stream));
|
||||
MSCCLPP_API_CPP HostProxyFifo::~HostProxyFifo()
|
||||
{
|
||||
mscclppCudaFree(pimpl->deviceFifo.head);
|
||||
mscclppCudaHostFree(pimpl->deviceFifo.triggers);
|
||||
mscclppCudaFree(pimpl->deviceFifo.tailReplica);
|
||||
cudaStreamDestroy(pimpl->stream);
|
||||
}
|
||||
|
||||
void HostProxyFifo::poll(ProxyTrigger *trigger) {
|
||||
MSCCLPP_API_CPP void HostProxyFifo::poll(ProxyTrigger* trigger)
|
||||
{
|
||||
__m128i xmm0 = _mm_load_si128((__m128i*)&pimpl->deviceFifo.triggers[pimpl->hostTail % MSCCLPP_PROXY_FIFO_SIZE]);
|
||||
_mm_store_si128((__m128i*)trigger, xmm0);
|
||||
}
|
||||
|
||||
void HostProxyFifo::pop() {
|
||||
MSCCLPP_API_CPP void HostProxyFifo::pop()
|
||||
{
|
||||
*(volatile uint64_t*)(&pimpl->deviceFifo.triggers[pimpl->hostTail % MSCCLPP_PROXY_FIFO_SIZE]) = 0;
|
||||
(pimpl->hostTail)++;
|
||||
}
|
||||
|
||||
void HostProxyFifo::flushTail(bool sync) {
|
||||
MSCCLPP_API_CPP void HostProxyFifo::flushTail(bool sync)
|
||||
{
|
||||
// Flush the tail to device memory. This is either triggered every MSCCLPP_PROXY_FIFO_FLUSH_COUNTER to make sure
|
||||
// that the fifo can make progress even if there is no request mscclppSync. However, mscclppSync type is for flush
|
||||
// request.
|
||||
CUDATHROW(
|
||||
cudaMemcpyAsync(pimpl->deviceFifo.tailReplica, &pimpl->hostTail, sizeof(uint64_t), cudaMemcpyHostToDevice, pimpl->stream));
|
||||
CUDATHROW(cudaMemcpyAsync(pimpl->deviceFifo.tailReplica, &pimpl->hostTail, sizeof(uint64_t), cudaMemcpyHostToDevice,
|
||||
pimpl->stream));
|
||||
if (sync) {
|
||||
CUDATHROW(cudaStreamSynchronize(pimpl->stream));
|
||||
}
|
||||
}
|
||||
|
||||
DeviceProxyFifo HostProxyFifo::toDevice() {
|
||||
MSCCLPP_API_CPP DeviceProxyFifo HostProxyFifo::deviceFifo()
|
||||
{
|
||||
return pimpl->deviceFifo;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
#include "host_connection.hpp"
|
||||
#include "communicator.hpp"
|
||||
#include "comm.h"
|
||||
#include "mscclpp.h"
|
||||
#include "mscclppfifo.h"
|
||||
#include "api.h"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
HostConnection::Impl::Impl(Communicator* comm, mscclppConn* conn) : comm(comm), conn(conn) {
|
||||
this->hostConn = conn->hostConn;
|
||||
}
|
||||
|
||||
HostConnection::Impl::~Impl() {
|
||||
// TODO: figure out memory ownership. Does this deallocate the mscclppHostConn? Likely not.
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP HostConnection::~HostConnection() = default;
|
||||
|
||||
MSCCLPP_API_CPP HostConnection::HostConnection(std::unique_ptr<Impl> p) : pimpl(std::move(p)) {}
|
||||
|
||||
MSCCLPP_API_CPP int HostConnection::getId() {
|
||||
return pimpl->conn->connId;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP BufferHandle HostConnection::registerBuffer(void* data, uint64_t size) {
|
||||
BufferHandle result;
|
||||
static_assert(sizeof(BufferHandle) == sizeof(mscclppBufferHandle_t));
|
||||
mscclppRegisterBufferForConnection(pimpl->comm->pimpl->comm, pimpl->conn->connId, data, size, reinterpret_cast<mscclppBufferHandle_t*>(&result));
|
||||
return result;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP int HostConnection::numLocalBuffers() {
|
||||
return pimpl->conn->bufferRegistrations.size() - 1;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP BufferHandle HostConnection::getLocalBuffer(int index) {
|
||||
return index + 1;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP int HostConnection::numRemoteBuffers() {
|
||||
return pimpl->conn->remoteBufferRegistrations.size() - 1;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP BufferHandle HostConnection::getRemoteBuffer(int index) {
|
||||
return index + 1;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP ConnectionEpoch HostConnection::getEpoch() {
|
||||
ConnectionEpoch epoch;
|
||||
static_assert(sizeof(SignalEpochId) == sizeof(mscclppDevConnSignalEpochId));
|
||||
epoch.localSignalEpochId = reinterpret_cast<SignalEpochId*>(pimpl->conn->devConn->localSignalEpochId);
|
||||
epoch.remoteSignalEpochId = reinterpret_cast<SignalEpochId*>(pimpl->conn->devConn->remoteSignalEpochId);
|
||||
epoch.waitEpochId = pimpl->conn->devConn->waitEpochId;
|
||||
return epoch;
|
||||
}
|
||||
|
||||
|
||||
MSCCLPP_API_CPP DeviceProxyFifo HostConnection::getDeviceFifo() {
|
||||
return pimpl->comm->pimpl->proxy.fifo().toDevice();
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void HostConnection::put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) {
|
||||
pimpl->hostConn->put(dst, dstOffset, src, srcOffset, size);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void HostConnection::signal() {
|
||||
pimpl->hostConn->signal();
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void HostConnection::flush() {
|
||||
pimpl->hostConn->flush();
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void HostConnection::wait() {
|
||||
pimpl->hostConn->wait();
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
665
src/ib.cc
665
src/ib.cc
@@ -2,323 +2,182 @@
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <malloc.h>
|
||||
#include <sstream>
|
||||
#include <unistd.h>
|
||||
#include <vector>
|
||||
|
||||
#include "alloc.h"
|
||||
#include "checks.hpp"
|
||||
#include "comm.h"
|
||||
#include "debug.h"
|
||||
#include "ib.h"
|
||||
#include "ib.hpp"
|
||||
#include "mscclpp.hpp"
|
||||
#include "api.h"
|
||||
#include <infiniband/verbs.h>
|
||||
#include <string>
|
||||
|
||||
static int getIbDevNumaNode(const char* ibDevPath)
|
||||
{
|
||||
if (ibDevPath == NULL) {
|
||||
WARN("ibDevPath is NULL");
|
||||
return -1;
|
||||
}
|
||||
const char* postfix = "/device/numa_node";
|
||||
FILE* fp = NULL;
|
||||
char* filePath = NULL;
|
||||
int node = -1;
|
||||
int res;
|
||||
if (mscclppCalloc(&filePath, strlen(ibDevPath) + strlen(postfix) + 1) != mscclppSuccess) {
|
||||
WARN("mscclppCalloc failed");
|
||||
goto exit;
|
||||
}
|
||||
memcpy(filePath, ibDevPath, strlen(ibDevPath) * sizeof(char));
|
||||
filePath[strlen(ibDevPath)] = '\0';
|
||||
if (strncat(filePath, postfix, strlen(postfix)) == NULL) {
|
||||
WARN("strncat failed");
|
||||
goto exit;
|
||||
}
|
||||
fp = fopen(filePath, "r");
|
||||
if (fp == NULL) {
|
||||
WARN("fopen failed (errno %d, path %s)", errno, filePath);
|
||||
goto exit;
|
||||
}
|
||||
res = fscanf(fp, "%d", &node);
|
||||
if (res != 1) {
|
||||
WARN("fscanf failed (errno %d, path %s)", errno, filePath);
|
||||
node = -1;
|
||||
goto exit;
|
||||
}
|
||||
exit:
|
||||
if (filePath != NULL) {
|
||||
free(filePath);
|
||||
}
|
||||
if (fp != NULL) {
|
||||
fclose(fp);
|
||||
}
|
||||
return node;
|
||||
}
|
||||
namespace mscclpp {
|
||||
|
||||
// Opens the IB device named `ibDevName`, records its usable ports and allocates a protection domain.
//
// ctx:       receives the new context on success, NULL on failure.
// ibDevName: device name as reported by ibv_get_device_list().
// Returns mscclppSuccess, the allocation error code if a buffer cannot be allocated, or
// mscclppInternalError for any verbs failure. Every failure path releases whatever was acquired
// so far (opened device, port array, context struct).
mscclppResult_t mscclppIbContextCreate(struct mscclppIbContext** ctx, const char* ibDevName)
{
  struct mscclppIbContext* _ctx;
  MSCCLPPCHECK(mscclppCalloc(&_ctx, 1));

  mscclppResult_t result = mscclppInternalError;
  std::vector<int> ports;
  struct ibv_device_attr devAttr;

  int num = 0;
  struct ibv_device** devices = ibv_get_device_list(&num);
  if (devices != NULL) {
    for (int i = 0; i < num; ++i) {
      if (strncmp(devices[i]->name, ibDevName, IBV_SYSFS_NAME_MAX) == 0) {
        _ctx->ctx = ibv_open_device(devices[i]);
        break;
      }
    }
    ibv_free_device_list(devices);
  }
  // _ctx was zero-allocated, so ctx is still null when no device matched or the open failed.
  if (_ctx->ctx == nullptr) {
    WARN("ibv_open_device failed (errno %d, device name %s)", errno, ibDevName);
    goto fail;
  }

  // Check available ports
  if (ibv_query_device(_ctx->ctx, &devAttr) != 0) {
    WARN("ibv_query_device failed (errno %d, device name %s)", errno, ibDevName);
    goto fail;
  }

  // int counter: a uint8_t counter would wrap and never terminate when phys_port_cnt is 255.
  for (int i = 1; i <= devAttr.phys_port_cnt; ++i) {
    struct ibv_port_attr portAttr;
    if (ibv_query_port(_ctx->ctx, (uint8_t)i, &portAttr) != 0) {
      WARN("ibv_query_port failed (errno %d, port %d)", errno, i);
      goto fail;
    }
    if (portAttr.state != IBV_PORT_ACTIVE) {
      continue;
    }
    if (portAttr.link_layer != IBV_LINK_LAYER_INFINIBAND && portAttr.link_layer != IBV_LINK_LAYER_ETHERNET) {
      continue;
    }
    ports.push_back(i);
  }
  if (ports.size() == 0) {
    WARN("no active IB port found");
    goto fail;
  }

  // Checked by hand instead of MSCCLPPCHECK so that the failure goes through the cleanup path.
  result = mscclppCalloc(&_ctx->ports, ports.size());
  if (result != mscclppSuccess) {
    WARN("mscclppCalloc failed");
    goto fail;
  }
  result = mscclppInternalError;
  _ctx->nPorts = (int)ports.size();
  for (int i = 0; i < _ctx->nPorts; ++i) {
    _ctx->ports[i] = ports[i];
  }

  _ctx->pd = ibv_alloc_pd(_ctx->ctx);
  if (_ctx->pd == NULL) {
    WARN("ibv_alloc_pd failed (errno %d)", errno);
    goto fail;
  }

  *ctx = _ctx;
  return mscclppSuccess;

fail:
  *ctx = NULL;
  if (_ctx->ctx != NULL) {
    ibv_close_device(_ctx->ctx);
  }
  if (_ctx->ports != NULL) {
    free(_ctx->ports);
  }
  free(_ctx);
  return result;
}
|
||||
|
||||
// Releases every verbs object and buffer recorded in `ctx`, then frees `ctx` itself.
// Order: memory regions, queue pairs (with their completion queues and work buffers), protection
// domain, device, and finally the bookkeeping arrays. Always returns mscclppSuccess.
mscclppResult_t mscclppIbContextDestroy(struct mscclppIbContext* ctx)
{
  for (int i = 0; i < ctx->nMrs; ++i) {
    if (ctx->mrs[i].mr) {
      ibv_dereg_mr(ctx->mrs[i].mr);
    }
  }
  for (int i = 0; i < ctx->nQps; ++i) {
    if (ctx->qps[i].qp) {
      ibv_destroy_qp(ctx->qps[i].qp);
    }
    // A slot counted in nQps may be only partially initialised (QP creation bumps the counter
    // before the slot is filled), so the CQ needs the same null guard as the QP.
    if (ctx->qps[i].cq) {
      ibv_destroy_cq(ctx->qps[i].cq);
    }
    free(ctx->qps[i].wcs);
    free(ctx->qps[i].sges);
    free(ctx->qps[i].wrs);
  }
  if (ctx->pd != NULL) {
    ibv_dealloc_pd(ctx->pd);
  }
  if (ctx->ctx != NULL) {
    ibv_close_device(ctx->ctx);
  }
  free(ctx->mrs);
  free(ctx->qps);
  free(ctx->ports);
  free(ctx);
  return mscclppSuccess;
}
|
||||
|
||||
// Creates an RC queue pair (plus its completion queue and work-request buffers) on `port`, moves
// it to the INIT state and records it in ctx->qps.
//
// ctx:  context created by mscclppIbContextCreate.
// ibQp: receives a pointer to the new slot inside ctx->qps on success.
// port: port number to use; a negative value selects ctx->ports[0].
// Returns mscclppSuccess, an allocation error code, or mscclppInternalError.
//
// The slot is reserved up front but ctx->nQps is only incremented once everything succeeded; on
// any failure the verbs objects and buffers created so far are released again.
mscclppResult_t mscclppIbContextCreateQp(struct mscclppIbContext* ctx, struct mscclppIbQp** ibQp, int port /*=-1*/)
{
  if (port < 0) {
    port = ctx->ports[0];
  } else {
    bool found = false;
    for (int i = 0; i < ctx->nPorts; ++i) {
      if (ctx->ports[i] == port) {
        found = true;
        break;
      }
    }
    if (!found) {
      WARN("invalid IB port: %d", port);
      return mscclppInternalError;
    }
  }
  if (ctx->ctx == NULL) {
    WARN("IB context is NULL");
    return mscclppInternalError;
  }

  // Make sure there is a free bookkeeping slot before creating any verbs object.
  if (ctx->qps == NULL) {
    MSCCLPPCHECK(mscclppCalloc(&ctx->qps, MAXCONNECTIONS));
    ctx->maxQps = MAXCONNECTIONS;
  }
  if (ctx->maxQps <= ctx->nQps) {
    WARN("too many QPs");
    return mscclppInternalError;
  }
  struct mscclppIbQp* _ibQp = &ctx->qps[ctx->nQps];

  struct ibv_cq* cq = ibv_create_cq(ctx->ctx, MSCCLPP_IB_CQ_SIZE, NULL, NULL, 0);
  if (cq == NULL) {
    WARN("ibv_create_cq failed (errno %d)", errno);
    return mscclppInternalError;
  }
  struct ibv_qp* qp = NULL;

  // Undoes everything acquired so far and leaves the slot empty for a later attempt.
  auto abortWith = [&](mscclppResult_t res) {
    free(_ibQp->wcs);
    free(_ibQp->sges);
    free(_ibQp->wrs);
    _ibQp->wcs = NULL;
    _ibQp->sges = NULL;
    _ibQp->wrs = NULL;
    _ibQp->qp = NULL;
    _ibQp->cq = NULL;
    if (qp != NULL) {
      ibv_destroy_qp(qp);
    }
    ibv_destroy_cq(cq);
    return res;
  };

  struct ibv_qp_init_attr qp_init_attr;
  std::memset(&qp_init_attr, 0, sizeof(struct ibv_qp_init_attr));
  qp_init_attr.sq_sig_all = 0;
  qp_init_attr.send_cq = cq;
  qp_init_attr.recv_cq = cq;
  qp_init_attr.qp_type = IBV_QPT_RC;
  qp_init_attr.cap.max_send_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE;
  qp_init_attr.cap.max_recv_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE;
  qp_init_attr.cap.max_send_sge = 1;
  qp_init_attr.cap.max_recv_sge = 1;
  qp_init_attr.cap.max_inline_data = 0;
  qp = ibv_create_qp(ctx->pd, &qp_init_attr);
  if (qp == nullptr) {
    WARN("ibv_create_qp failed (errno %d)", errno);
    return abortWith(mscclppInternalError);
  }
  struct ibv_port_attr port_attr;
  if (ibv_query_port(ctx->ctx, port, &port_attr) != 0) {
    WARN("ibv_query_port failed (errno %d, port %d)", errno, port);
    return abortWith(mscclppInternalError);
  }

  // Register QP to this ctx
  qp->context = ctx->ctx;
  _ibQp->info.lid = port_attr.lid;
  _ibQp->info.port = port;
  _ibQp->info.linkLayer = port_attr.link_layer;
  _ibQp->info.qpn = qp->qp_num;
  _ibQp->info.mtu = port_attr.active_mtu;
  if (port_attr.link_layer != IBV_LINK_LAYER_INFINIBAND) {
    union ibv_gid gid;
    if (ibv_query_gid(ctx->ctx, port, 0, &gid) != 0) {
      WARN("ibv_query_gid failed (errno %d)", errno);
      return abortWith(mscclppInternalError);
    }
    _ibQp->info.spn = gid.global.subnet_prefix;
  }

  struct ibv_qp_attr qp_attr;
  std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr));
  qp_attr.qp_state = IBV_QPS_INIT;
  qp_attr.pkey_index = 0;
  qp_attr.port_num = port;
  qp_attr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ;
  if (ibv_modify_qp(qp, &qp_attr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) {
    WARN("ibv_modify_qp failed (errno %d)", errno);
    return abortWith(mscclppInternalError);
  }

  // Checked by hand instead of MSCCLPPCHECK so that a failed allocation is cleaned up as well.
  mscclppResult_t res = mscclppCalloc(&_ibQp->wrs, MSCCLPP_IB_MAX_SENDS);
  if (res == mscclppSuccess) {
    res = mscclppCalloc(&_ibQp->sges, MSCCLPP_IB_MAX_SENDS);
  }
  if (res == mscclppSuccess) {
    res = mscclppCalloc(&_ibQp->wcs, MSCCLPP_IB_CQ_POLL_NUM);
  }
  if (res != mscclppSuccess) {
    WARN("mscclppCalloc failed");
    return abortWith(res);
  }

  _ibQp->qp = qp;
  _ibQp->cq = cq;
  ctx->nQps++;

  *ibQp = _ibQp;

  return mscclppSuccess;
}
|
||||
|
||||
mscclppResult_t mscclppIbContextRegisterMr(struct mscclppIbContext* ctx, void* buff, size_t size,
|
||||
struct mscclppIbMr** ibMr)
|
||||
IbMr::IbMr(void* pd, void* buff, std::size_t size) : buff(buff)
|
||||
{
|
||||
if (size == 0) {
|
||||
WARN("invalid size: %zu", size);
|
||||
return mscclppInvalidArgument;
|
||||
throw std::runtime_error("invalid size: " + std::to_string(size));
|
||||
}
|
||||
static __thread uintptr_t pageSize = 0;
|
||||
if (pageSize == 0) {
|
||||
pageSize = sysconf(_SC_PAGESIZE);
|
||||
}
|
||||
uintptr_t addr = reinterpret_cast<uintptr_t>(buff) & -pageSize;
|
||||
size_t pages = (size + (reinterpret_cast<uintptr_t>(buff) - addr) + pageSize - 1) / pageSize;
|
||||
struct ibv_mr* mr =
|
||||
ibv_reg_mr(ctx->pd, reinterpret_cast<void*>(addr), pages * pageSize,
|
||||
std::size_t pages = (size + (reinterpret_cast<uintptr_t>(buff) - addr) + pageSize - 1) / pageSize;
|
||||
struct ibv_pd* _pd = reinterpret_cast<struct ibv_pd*>(pd);
|
||||
struct ibv_mr* _mr =
|
||||
ibv_reg_mr(_pd, reinterpret_cast<void*>(addr), pages * pageSize,
|
||||
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_RELAXED_ORDERING);
|
||||
if (mr == nullptr) {
|
||||
WARN("ibv_reg_mr failed (errno %d)", errno);
|
||||
return mscclppInternalError;
|
||||
if (_mr == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_reg_mr failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
ctx->nMrs++;
|
||||
if (ctx->mrs == NULL) {
|
||||
MSCCLPPCHECK(mscclppCalloc(&ctx->mrs, MAXCONNECTIONS));
|
||||
ctx->maxMrs = MAXCONNECTIONS;
|
||||
}
|
||||
if (ctx->maxMrs < ctx->nMrs) {
|
||||
WARN("too many MRs");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
struct mscclppIbMr* _ibMr = &ctx->mrs[ctx->nMrs - 1];
|
||||
_ibMr->mr = mr;
|
||||
_ibMr->buff = buff;
|
||||
_ibMr->info.addr = (uint64_t)buff;
|
||||
_ibMr->info.rkey = mr->rkey;
|
||||
*ibMr = _ibMr;
|
||||
return mscclppSuccess;
|
||||
this->mr = _mr;
|
||||
this->size = pages * pageSize;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Deregisters the memory region; the result of ibv_dereg_mr is not checked.
IbMr::~IbMr()
{
  struct ibv_mr* region = reinterpret_cast<struct ibv_mr*>(this->mr);
  ibv_dereg_mr(region);
}
|
||||
|
||||
int mscclppIbQp::rtr(const mscclppIbQpInfo* info)
|
||||
// Returns the buffer's base address and the region's rkey, packed into an IbMrInfo.
IbMrInfo IbMr::getInfo() const
{
  struct ibv_mr* region = reinterpret_cast<struct ibv_mr*>(this->mr);
  IbMrInfo result;
  result.addr = reinterpret_cast<uint64_t>(this->buff);
  result.rkey = region->rkey;
  return result;
}
|
||||
|
||||
const void* IbMr::getBuff() const
|
||||
{
|
||||
return this->buff;
|
||||
}
|
||||
|
||||
uint32_t IbMr::getLkey() const
|
||||
{
|
||||
return reinterpret_cast<struct ibv_mr*>(this->mr)->lkey;
|
||||
}
|
||||
|
||||
// Creates a completion queue and an RC queue pair on `port`, fills in the local connection info,
// moves the QP to the INIT state and allocates the work-request, SGE and work-completion buffers.
//
// ctx:  device context (a struct ibv_context* passed as void*).
// pd:   protection domain (a struct ibv_pd* passed as void*).
// port: port number of the device to bind the QP to.
// Throws std::runtime_error when a verbs call fails; allocation failures propagate from
// MSCCLPPTHROW. The destructor does not run for a constructor that throws, so everything acquired
// before the failure is released here before the exception is rethrown.
IbQp::IbQp(void* ctx, void* pd, int port)
{
  struct ibv_context* _ctx = reinterpret_cast<struct ibv_context*>(ctx);
  struct ibv_pd* _pd = reinterpret_cast<struct ibv_pd*>(pd);

  // Formats "<call> failed (errno N)" using the current errno.
  auto verbsError = [](const char* call) {
    std::stringstream err;
    err << call << " failed (errno " << errno << ")";
    return std::runtime_error(err.str());
  };

  struct ibv_cq* _cq = ibv_create_cq(_ctx, MSCCLPP_IB_CQ_SIZE, nullptr, nullptr, 0);
  if (_cq == nullptr) {
    throw verbsError("ibv_create_cq");
  }

  struct ibv_qp* _qp = nullptr;
  struct ibv_send_wr* _wrs = nullptr;
  struct ibv_sge* _sges = nullptr;
  struct ibv_wc* _wcs = nullptr;
  try {
    struct ibv_qp_init_attr qpInitAttr;
    std::memset(&qpInitAttr, 0, sizeof(qpInitAttr));
    qpInitAttr.sq_sig_all = 0;
    qpInitAttr.send_cq = _cq;
    qpInitAttr.recv_cq = _cq;
    qpInitAttr.qp_type = IBV_QPT_RC;
    qpInitAttr.cap.max_send_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE;
    qpInitAttr.cap.max_recv_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE;
    qpInitAttr.cap.max_send_sge = 1;
    qpInitAttr.cap.max_recv_sge = 1;
    qpInitAttr.cap.max_inline_data = 0;

    _qp = ibv_create_qp(_pd, &qpInitAttr);
    if (_qp == nullptr) {
      throw verbsError("ibv_create_qp");
    }

    struct ibv_port_attr portAttr;
    if (ibv_query_port(_ctx, port, &portAttr) != 0) {
      throw verbsError("ibv_query_port");
    }
    this->info.lid = portAttr.lid;
    this->info.port = port;
    this->info.linkLayer = portAttr.link_layer;
    this->info.qpn = _qp->qp_num;
    this->info.mtu = portAttr.active_mtu;
    if (portAttr.link_layer != IBV_LINK_LAYER_INFINIBAND) {
      // Non-InfiniBand link layers additionally publish the subnet prefix of GID index 0.
      union ibv_gid gid;
      if (ibv_query_gid(_ctx, port, 0, &gid) != 0) {
        throw verbsError("ibv_query_gid");
      }
      this->info.spn = gid.global.subnet_prefix;
    }

    struct ibv_qp_attr qpAttr;
    std::memset(&qpAttr, 0, sizeof(qpAttr));
    qpAttr.qp_state = IBV_QPS_INIT;
    qpAttr.pkey_index = 0;
    qpAttr.port_num = port;
    qpAttr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ;
    if (ibv_modify_qp(_qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) {
      throw verbsError("ibv_modify_qp");
    }

    MSCCLPPTHROW(mscclppCalloc(&_wrs, MSCCLPP_IB_MAX_SENDS));
    MSCCLPPTHROW(mscclppCalloc(&_sges, MSCCLPP_IB_MAX_SENDS));
    MSCCLPPTHROW(mscclppCalloc(&_wcs, MSCCLPP_IB_CQ_POLL_NUM));
  } catch (...) {
    std::free(_wcs);
    std::free(_sges);
    std::free(_wrs);
    if (_qp != nullptr) {
      ibv_destroy_qp(_qp);
    }
    ibv_destroy_cq(_cq);
    throw;
  }

  // Publish the handles only once nothing can fail any more.
  this->cq = _cq;
  this->qp = _qp;
  this->wrs = _wrs;
  this->sges = _sges;
  this->wcs = _wcs;
}
|
||||
|
||||
// Destroys the queue pair, then its completion queue, then frees the three work buffers.
// Return codes of the verbs calls are not checked.
IbQp::~IbQp()
{
  struct ibv_qp* queuePair = reinterpret_cast<struct ibv_qp*>(this->qp);
  struct ibv_cq* completionQueue = reinterpret_cast<struct ibv_cq*>(this->cq);
  ibv_destroy_qp(queuePair);
  ibv_destroy_cq(completionQueue);
  std::free(this->wrs);
  std::free(this->sges);
  std::free(this->wcs);
}
|
||||
|
||||
void IbQp::rtr(const IbQpInfo& info)
|
||||
{
|
||||
struct ibv_qp_attr qp_attr;
|
||||
std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr));
|
||||
qp_attr.qp_state = IBV_QPS_RTR;
|
||||
qp_attr.path_mtu = info->mtu;
|
||||
qp_attr.dest_qp_num = info->qpn;
|
||||
qp_attr.path_mtu = static_cast<ibv_mtu>(info.mtu);
|
||||
qp_attr.dest_qp_num = info.qpn;
|
||||
qp_attr.rq_psn = 0;
|
||||
qp_attr.max_dest_rd_atomic = 1;
|
||||
qp_attr.min_rnr_timer = 0x12;
|
||||
if (info->linkLayer == IBV_LINK_LAYER_ETHERNET) {
|
||||
if (info.linkLayer == IBV_LINK_LAYER_ETHERNET) {
|
||||
qp_attr.ah_attr.is_global = 1;
|
||||
qp_attr.ah_attr.grh.dgid.global.subnet_prefix = info->spn;
|
||||
qp_attr.ah_attr.grh.dgid.global.interface_id = info->lid;
|
||||
qp_attr.ah_attr.grh.dgid.global.subnet_prefix = info.spn;
|
||||
qp_attr.ah_attr.grh.dgid.global.interface_id = info.lid;
|
||||
qp_attr.ah_attr.grh.flow_label = 0;
|
||||
qp_attr.ah_attr.grh.sgid_index = 0;
|
||||
qp_attr.ah_attr.grh.hop_limit = 255;
|
||||
qp_attr.ah_attr.grh.traffic_class = 0;
|
||||
} else {
|
||||
qp_attr.ah_attr.is_global = 0;
|
||||
qp_attr.ah_attr.dlid = info->lid;
|
||||
qp_attr.ah_attr.dlid = info.lid;
|
||||
}
|
||||
qp_attr.ah_attr.sl = 0;
|
||||
qp_attr.ah_attr.src_path_bits = 0;
|
||||
qp_attr.ah_attr.port_num = info->port;
|
||||
return ibv_modify_qp(this->qp, &qp_attr,
|
||||
IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
|
||||
IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER);
|
||||
qp_attr.ah_attr.port_num = info.port;
|
||||
int ret = ibv_modify_qp(reinterpret_cast<struct ibv_qp*>(this->qp), &qp_attr,
|
||||
IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
|
||||
IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER);
|
||||
if (ret != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_modify_qp failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
|
||||
int mscclppIbQp::rts()
|
||||
void IbQp::rts()
|
||||
{
|
||||
struct ibv_qp_attr qp_attr;
|
||||
std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr));
|
||||
@@ -328,75 +187,267 @@ int mscclppIbQp::rts()
|
||||
qp_attr.rnr_retry = 7;
|
||||
qp_attr.sq_psn = 0;
|
||||
qp_attr.max_rd_atomic = 1;
|
||||
return ibv_modify_qp(this->qp, &qp_attr,
|
||||
IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
|
||||
IBV_QP_MAX_QP_RD_ATOMIC);
|
||||
int ret = ibv_modify_qp(reinterpret_cast<struct ibv_qp*>(this->qp), &qp_attr,
|
||||
IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
|
||||
IBV_QP_MAX_QP_RD_ATOMIC);
|
||||
if (ret != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_modify_qp failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
|
||||
int mscclppIbQp::stageSend(struct mscclppIbMr* ibMr, const mscclppIbMrInfo* info, uint32_t size, uint64_t wrId,
|
||||
uint64_t srcOffset, uint64_t dstOffset, bool signaled)
|
||||
int IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
|
||||
uint64_t dstOffset, bool signaled)
|
||||
{
|
||||
if (this->wrn >= MSCCLPP_IB_MAX_SENDS) {
|
||||
return -1;
|
||||
}
|
||||
int wrn = this->wrn;
|
||||
struct ibv_send_wr* wr_ = &this->wrs[wrn];
|
||||
struct ibv_sge* sge_ = &this->sges[wrn];
|
||||
// std::memset(wr_, 0, sizeof(struct ibv_send_wr));
|
||||
// std::memset(sge_, 0, sizeof(struct ibv_sge));
|
||||
struct ibv_send_wr* wrs_ = reinterpret_cast<struct ibv_send_wr*>(this->wrs);
|
||||
struct ibv_sge* sges_ = reinterpret_cast<struct ibv_sge*>(this->sges);
|
||||
|
||||
struct ibv_send_wr* wr_ = &wrs_[wrn];
|
||||
struct ibv_sge* sge_ = &sges_[wrn];
|
||||
wr_->wr_id = wrId;
|
||||
wr_->sg_list = sge_;
|
||||
wr_->num_sge = 1;
|
||||
wr_->opcode = IBV_WR_RDMA_WRITE;
|
||||
wr_->send_flags = signaled ? IBV_SEND_SIGNALED : 0;
|
||||
wr_->wr.rdma.remote_addr = (uint64_t)(info->addr) + dstOffset;
|
||||
wr_->wr.rdma.rkey = info->rkey;
|
||||
wr_->wr.rdma.remote_addr = (uint64_t)(info.addr) + dstOffset;
|
||||
wr_->wr.rdma.rkey = info.rkey;
|
||||
wr_->next = nullptr;
|
||||
sge_->addr = (uint64_t)(ibMr->buff) + srcOffset;
|
||||
sge_->addr = (uint64_t)(mr->getBuff()) + srcOffset;
|
||||
sge_->length = size;
|
||||
sge_->lkey = ibMr->mr->lkey;
|
||||
sge_->lkey = mr->getLkey();
|
||||
if (wrn > 0) {
|
||||
this->wrs[wrn - 1].next = wr_;
|
||||
wrs_[wrn - 1].next = wr_;
|
||||
}
|
||||
this->wrn++;
|
||||
return this->wrn;
|
||||
}
|
||||
|
||||
int mscclppIbQp::stageSendWithImm(struct mscclppIbMr* ibMr, const mscclppIbMrInfo* info, uint32_t size, uint64_t wrId,
|
||||
uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData)
|
||||
int IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
|
||||
uint64_t dstOffset, bool signaled, unsigned int immData)
|
||||
{
|
||||
int wrn = this->stageSend(ibMr, info, size, wrId, srcOffset, dstOffset, signaled);
|
||||
this->wrs[wrn - 1].opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
|
||||
this->wrs[wrn - 1].imm_data = immData;
|
||||
int wrn = this->stageSend(mr, info, size, wrId, srcOffset, dstOffset, signaled);
|
||||
struct ibv_send_wr* wrs_ = reinterpret_cast<struct ibv_send_wr*>(this->wrs);
|
||||
wrs_[wrn - 1].opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
|
||||
wrs_[wrn - 1].imm_data = immData;
|
||||
return wrn;
|
||||
}
|
||||
|
||||
int mscclppIbQp::postSend()
|
||||
void IbQp::postSend()
|
||||
{
|
||||
if (this->wrn == 0) {
|
||||
return 0;
|
||||
return;
|
||||
}
|
||||
|
||||
struct ibv_send_wr* bad_wr;
|
||||
int ret = ibv_post_send(this->qp, this->wrs, &bad_wr);
|
||||
int ret = ibv_post_send(reinterpret_cast<struct ibv_qp*>(this->qp), reinterpret_cast<struct ibv_send_wr*>(this->wrs),
|
||||
&bad_wr);
|
||||
if (ret != 0) {
|
||||
return ret;
|
||||
std::stringstream err;
|
||||
err << "ibv_post_send failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
this->wrn = 0;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int mscclppIbQp::postRecv(uint64_t wrId)
|
||||
void IbQp::postRecv(uint64_t wrId)
|
||||
{
|
||||
struct ibv_recv_wr wr, *bad_wr;
|
||||
wr.wr_id = wrId;
|
||||
wr.sg_list = nullptr;
|
||||
wr.num_sge = 0;
|
||||
wr.next = nullptr;
|
||||
return ibv_post_recv(this->qp, &wr, &bad_wr);
|
||||
int ret = ibv_post_recv(reinterpret_cast<struct ibv_qp*>(this->qp), &wr, &bad_wr);
|
||||
if (ret != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_post_recv failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
|
||||
int mscclppIbQp::pollCq()
|
||||
int IbQp::pollCq()
|
||||
{
|
||||
return ibv_poll_cq(this->cq, MSCCLPP_IB_CQ_POLL_NUM, this->wcs);
|
||||
return ibv_poll_cq(reinterpret_cast<struct ibv_cq*>(this->cq), MSCCLPP_IB_CQ_POLL_NUM,
|
||||
reinterpret_cast<struct ibv_wc*>(this->wcs));
|
||||
}
|
||||
|
||||
// Returns a mutable reference to this queue pair's connection info.
IbQpInfo& IbQp::getInfo()
{
  IbQpInfo& localInfo = this->info;
  return localInfo;
}
|
||||
|
||||
const void* IbQp::getWc(int idx) const
|
||||
{
|
||||
return &reinterpret_cast<struct ibv_wc*>(this->wcs)[idx];
|
||||
}
|
||||
|
||||
// Opens the IB device called `devName` and allocates a protection domain on it.
// Throws std::runtime_error when no device with that name can be opened or when the protection
// domain cannot be allocated; in the latter case the device is closed again before throwing.
IbCtx::IbCtx(const std::string& devName) : devName(devName)
{
  struct ibv_context* _ctx = nullptr;
  int num = 0;
  struct ibv_device** devices = ibv_get_device_list(&num);
  // ibv_get_device_list returns NULL on failure; there is nothing to search or free then.
  if (devices != nullptr) {
    for (int i = 0; i < num; ++i) {
      if (std::string(devices[i]->name) == devName) {
        _ctx = ibv_open_device(devices[i]);
        break;
      }
    }
    ibv_free_device_list(devices);
  }
  if (_ctx == nullptr) {
    std::stringstream err;
    err << "ibv_open_device failed (errno " << errno << ", device name " << devName << ")";
    throw std::runtime_error(err.str());
  }
  struct ibv_pd* _pd = ibv_alloc_pd(_ctx);
  if (_pd == nullptr) {
    std::stringstream err;
    err << "ibv_alloc_pd failed (errno " << errno << ")";
    // The destructor does not run for a constructor that throws, so close the device here.
    ibv_close_device(_ctx);
    throw std::runtime_error(err.str());
  }
  this->ctx = _ctx;
  this->pd = _pd;
}
|
||||
|
||||
// Drops all memory regions and queue pairs first, then releases the protection domain and closes
// the device (each only if it was set).
IbCtx::~IbCtx()
{
  this->mrs.clear();
  this->qps.clear();

  struct ibv_pd* domain = reinterpret_cast<struct ibv_pd*>(this->pd);
  if (this->pd != nullptr) {
    ibv_dealloc_pd(domain);
  }

  struct ibv_context* device = reinterpret_cast<struct ibv_context*>(this->ctx);
  if (this->ctx != nullptr) {
    ibv_close_device(device);
  }
}
|
||||
|
||||
bool IbCtx::isPortUsable(int port) const
|
||||
{
|
||||
struct ibv_port_attr portAttr;
|
||||
if (ibv_query_port(reinterpret_cast<struct ibv_context*>(this->ctx), port, &portAttr) != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_query_port failed (errno " << errno << ", port << " << port << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
return portAttr.state == IBV_PORT_ACTIVE &&
|
||||
(portAttr.link_layer == IBV_LINK_LAYER_ETHERNET || portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND);
|
||||
}
|
||||
|
||||
int IbCtx::getAnyActivePort() const
|
||||
{
|
||||
struct ibv_device_attr devAttr;
|
||||
if (ibv_query_device(reinterpret_cast<struct ibv_context*>(this->ctx), &devAttr) != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_query_device failed (errno " << errno << ")";
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
for (uint8_t port = 1; port <= devAttr.phys_port_cnt; ++port) {
|
||||
if (this->isPortUsable(port)) {
|
||||
return port;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Creates a queue pair owned by this context and returns a non-owning pointer to it.
//
// port: port to bind to; -1 picks the first usable port.
// Throws std::runtime_error when no usable port exists or the requested port is not usable.
IbQp* IbCtx::createQp(int port /*=-1*/)
{
  const bool pickAutomatically = (port == -1);
  if (pickAutomatically) {
    port = this->getAnyActivePort();
    if (port == -1) {
      throw std::runtime_error("No active port found");
    }
  } else {
    if (!this->isPortUsable(port)) {
      throw std::runtime_error("invalid IB port: " + std::to_string(port));
    }
  }
  this->qps.emplace_back(new IbQp(this->ctx, this->pd, port));
  return this->qps.back().get();
}
|
||||
|
||||
// Registers [buff, buff + size) as a memory region owned by this context and returns a non-owning
// pointer to it. Errors raised by the IbMr constructor propagate to the caller.
const IbMr* IbCtx::registerMr(void* buff, std::size_t size)
{
  this->mrs.emplace_back(new IbMr(this->pd, buff, size));
  const IbMr* registered = this->mrs.back().get();
  return registered;
}
|
||||
|
||||
const std::string& IbCtx::getDevName() const
|
||||
{
|
||||
return this->devName;
|
||||
}
|
||||
|
||||
// Returns the number of IB devices reported by ibv_get_device_list(), or 0 when the list cannot be
// obtained. The list is only needed for its length, so it is freed again right away.
MSCCLPP_API_CPP int getIBDeviceCount()
{
  int num = 0;
  struct ibv_device** devices = ibv_get_device_list(&num);
  if (devices == nullptr) {
    return 0;
  }
  ibv_free_device_list(devices);
  return num;
}
|
||||
|
||||
// Maps an IB transport (Transport::IB0 .. Transport::IB7) to the name of the device at the same
// index in the list returned by ibv_get_device_list().
// Throws std::runtime_error when `ibTransport` is not an IB transport or when its index is not
// below the number of devices.
MSCCLPP_API_CPP std::string getIBDeviceName(Transport ibTransport)
{
  int ibTransportIndex;
  switch (ibTransport) { // TODO: get rid of this ugly switch
  case Transport::IB0:
    ibTransportIndex = 0;
    break;
  case Transport::IB1:
    ibTransportIndex = 1;
    break;
  case Transport::IB2:
    ibTransportIndex = 2;
    break;
  case Transport::IB3:
    ibTransportIndex = 3;
    break;
  case Transport::IB4:
    ibTransportIndex = 4;
    break;
  case Transport::IB5:
    ibTransportIndex = 5;
    break;
  case Transport::IB6:
    ibTransportIndex = 6;
    break;
  case Transport::IB7:
    ibTransportIndex = 7;
    break;
  default:
    throw std::runtime_error("Not an IB transport");
  }

  int num = 0;
  struct ibv_device** devices = ibv_get_device_list(&num);
  if (devices == nullptr) {
    // No list could be obtained, so no index can be in range.
    throw std::runtime_error("IB transport out of range");
  }
  if (ibTransportIndex >= num) {
    ibv_free_device_list(devices);
    throw std::runtime_error("IB transport out of range");
  }
  // Copy the name out before the list (which owns the string) is released.
  std::string name = devices[ibTransportIndex]->name;
  ibv_free_device_list(devices);
  return name;
}
|
||||
|
||||
// Maps a device name to the IB transport whose number equals the device's index in the list
// returned by ibv_get_device_list().
// Throws std::runtime_error when the device is not found or when its index is above 7.
MSCCLPP_API_CPP Transport getIBTransportByDeviceName(const std::string& ibDeviceName)
{
  // Look the index up first so the device list can be released before any return or throw.
  int index = -1;
  int num = 0;
  struct ibv_device** devices = ibv_get_device_list(&num);
  if (devices != nullptr) {
    for (int i = 0; i < num; ++i) {
      if (ibDeviceName == devices[i]->name) {
        index = i;
        break;
      }
    }
    ibv_free_device_list(devices);
  }
  if (index < 0) {
    throw std::runtime_error("IB device not found");
  }
  switch (index) { // TODO: get rid of this ugly switch
  case 0:
    return Transport::IB0;
  case 1:
    return Transport::IB1;
  case 2:
    return Transport::IB2;
  case 3:
    return Transport::IB3;
  case 4:
    return Transport::IB4;
  case 5:
    return Transport::IB5;
  case 6:
    return Transport::IB6;
  case 7:
    return Transport::IB7;
  default:
    throw std::runtime_error("IB device index out of range");
  }
}
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
#ifndef MSCCLPP_BASIC_PROXY_SERVICE_HPP_
|
||||
#define MSCCLPP_BASIC_PROXY_SERVICE_HPP_
|
||||
|
||||
#include "mscclpp.hpp"
|
||||
#include "communicator.hpp"
|
||||
#include "mscclpp.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
ProxyHandler makeBasicProxyHandler(Communicator::Impl &comm);
|
||||
ProxyHandler makeBasicProxyHandler(Communicator::Impl& comm);
|
||||
|
||||
}
|
||||
|
||||
|
||||
305
src/include/channel.hpp
Normal file
305
src/include/channel.hpp
Normal file
@@ -0,0 +1,305 @@
|
||||
#ifndef MSCCLPP_CHANNEL_HPP_
|
||||
#define MSCCLPP_CHANNEL_HPP_
|
||||
|
||||
#include "epoch.hpp"
|
||||
#include "mscclpp.hpp"
|
||||
#include "proxy.hpp"
|
||||
#include "mscclppfifo.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
namespace channel {
|
||||
|
||||
// A Channel pairs a Connection with an Epoch
|
||||
class Channel
|
||||
{
|
||||
public:
|
||||
Channel(Communicator& communicator, std::shared_ptr<Connection> connection)
|
||||
: connection_(connection), epoch_(std::make_shared<Epoch>(communicator, connection)) {};
|
||||
|
||||
Connection& connection() { return *connection_; }
|
||||
Epoch& epoch() { return *epoch_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Connection> connection_;
|
||||
std::shared_ptr<Epoch> epoch_;
|
||||
};
|
||||
|
||||
using ChannelId = uint32_t;
|
||||
|
||||
using TriggerType = uint64_t;
|
||||
const TriggerType TriggerData = 0x1;
|
||||
const TriggerType TriggerFlag = 0x2;
|
||||
const TriggerType TriggerSync = 0x4;
|
||||
|
||||
// This is just a numeric ID. Each HostConnection will have an internal array indexed by these handles,
|
||||
// mapping each ID to the actual registered memory.
|
||||
using MemoryId = uint32_t;
|
||||
|
||||
#define MSCCLPP_BITS_SIZE 32
|
||||
#define MSCCLPP_BITS_OFFSET 32
|
||||
#define MSCCLPP_BITS_REGMEM_HANDLE 8
|
||||
#define MSCCLPP_BITS_TYPE 3
|
||||
#define MSCCLPP_BITS_CONNID 10
|
||||
|
||||
// this is the basic structure of each work element in the fifo
|
||||
// the summation of number of bits must be 128 or less
|
||||
union ChannelTrigger {
|
||||
ProxyTrigger value;
|
||||
struct
|
||||
{
|
||||
// first 64 bits: value[0]
|
||||
uint64_t size : MSCCLPP_BITS_SIZE;
|
||||
uint64_t srcOffset : MSCCLPP_BITS_OFFSET;
|
||||
uint64_t : (64 - MSCCLPP_BITS_SIZE - MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment
|
||||
// second 64 bits: value[1]
|
||||
uint64_t dstOffset : MSCCLPP_BITS_OFFSET;
|
||||
uint64_t srcMemoryId : MSCCLPP_BITS_REGMEM_HANDLE;
|
||||
uint64_t dstMemoryId : MSCCLPP_BITS_REGMEM_HANDLE;
|
||||
uint64_t type : MSCCLPP_BITS_TYPE;
|
||||
uint64_t chanId : MSCCLPP_BITS_CONNID;
|
||||
uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_REGMEM_HANDLE - MSCCLPP_BITS_REGMEM_HANDLE -
|
||||
MSCCLPP_BITS_TYPE); // ensure 64-bit alignment
|
||||
} fields;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
__device__ ChannelTrigger()
|
||||
{
|
||||
}
|
||||
__device__ ChannelTrigger(ProxyTrigger value) : value(value)
|
||||
{
|
||||
}
|
||||
__device__ ChannelTrigger(TriggerType type, MemoryId dst, uint64_t dstOffset, MemoryId src,
|
||||
uint64_t srcOffset, uint64_t size, int connectionId)
|
||||
{
|
||||
value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + size);
|
||||
value.snd = ((((((((connectionId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_REGMEM_HANDLE) + dst)
|
||||
<< MSCCLPP_BITS_REGMEM_HANDLE) +
|
||||
src)
|
||||
<< MSCCLPP_BITS_OFFSET) +
|
||||
dstOffset);
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
};
|
||||
|
||||
struct DeviceChannel
|
||||
{
|
||||
DeviceChannel() = default;
|
||||
|
||||
DeviceChannel(ChannelId channelId, DeviceEpoch epoch, DeviceProxyFifo fifo) : channelId_(channelId), epoch_(epoch), fifo_(fifo) {}
|
||||
|
||||
DeviceChannel(const DeviceChannel& other) = default;
|
||||
|
||||
DeviceChannel& operator=(DeviceChannel& other) = default;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
__forceinline__ __device__ void put(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
|
||||
uint64_t size)
|
||||
{
|
||||
fifo_.push(ChannelTrigger(TriggerData, dst, dstOffset, src, srcOffset, size, channelId_).value);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void put(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size)
|
||||
{
|
||||
put(dst, offset, src, offset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void signal()
|
||||
{
|
||||
epochIncrement();
|
||||
fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, channelId_).value);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(MemoryId dst, uint64_t dstOffset, MemoryId src,
|
||||
uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
epochIncrement();
|
||||
fifo_.push(
|
||||
ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, size, channelId_)
|
||||
.value);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size)
|
||||
{
|
||||
putWithSignal(dst, offset, src, offset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, uint64_t dstOffset, MemoryId src,
|
||||
uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
epochIncrement();
|
||||
uint64_t curFifoHead = fifo_.push(ChannelTrigger(TriggerData | TriggerFlag | TriggerSync, dst,
|
||||
dstOffset, src, srcOffset, size, channelId_)
|
||||
.value);
|
||||
while (*(volatile uint64_t*)&fifo_.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
|
||||
*(volatile uint64_t*)fifo_.tailReplica <= curFifoHead)
|
||||
;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, MemoryId src, uint64_t offset,
|
||||
uint64_t size)
|
||||
{
|
||||
putWithSignalAndFlush(dst, offset, src, offset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void flush()
|
||||
{
|
||||
uint64_t curFifoHead = fifo_.push(ChannelTrigger(mscclppSync, 0, 0, 0, 0, 1, channelId_).value);
|
||||
// we need to wait for two conditions to be met to ensure the CPU is done flushing. (1) wait for the tail
|
||||
// to go pass by curFifoHead (this is safety net) and (2) wait for the work element value to change to 0.
|
||||
while (*(volatile uint64_t*)&fifo_.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
|
||||
*(volatile uint64_t*)fifo_.tailReplica <= curFifoHead)
|
||||
;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void wait()
|
||||
{
|
||||
epoch_.wait();
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void epochIncrement()
|
||||
{
|
||||
epoch_.epochIncrement();
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
|
||||
ChannelId channelId_;
|
||||
|
||||
DeviceEpoch epoch_;
|
||||
|
||||
// this is a concurrent fifo which is multiple threads from the device
|
||||
// can produce for and the sole proxy thread consumes it.
|
||||
DeviceProxyFifo fifo_;
|
||||
};
|
||||
|
||||
class DeviceChannelService;
|
||||
|
||||
inline ProxyHandler makeChannelProxyHandler(DeviceChannelService& channelService);
|
||||
|
||||
class DeviceChannelService {
|
||||
public:
|
||||
DeviceChannelService(Communicator& communicator);
|
||||
|
||||
ChannelId addChannel(std::shared_ptr<Connection> connection) {
|
||||
channels_.push_back(Channel(communicator_, connection));
|
||||
return channels_.size() - 1;
|
||||
}
|
||||
|
||||
MemoryId addMemory(RegisteredMemory memory) {
|
||||
memories_.push_back(memory);
|
||||
return memories_.size() - 1;
|
||||
}
|
||||
|
||||
Channel channel(ChannelId id) { return channels_[id]; }
|
||||
DeviceChannel deviceChannel(ChannelId id) { return DeviceChannel(id, channels_[id].epoch().deviceEpoch(), proxy_.fifo().deviceFifo()); }
|
||||
|
||||
void startProxy() { proxy_.start(); }
|
||||
void stopProxy() { proxy_.stop(); }
|
||||
|
||||
private:
|
||||
Communicator& communicator_;
|
||||
std::vector<Channel> channels_;
|
||||
std::vector<RegisteredMemory> memories_;
|
||||
Proxy proxy_;
|
||||
int deviceNumaNode;
|
||||
|
||||
void bindThread();
|
||||
|
||||
ProxyHandlerResult handleTrigger(ProxyTrigger triggerRaw) {
|
||||
ChannelTrigger* trigger = reinterpret_cast<ChannelTrigger*>(&triggerRaw);
|
||||
Channel& channel = channels_[trigger->fields.chanId];
|
||||
|
||||
auto result = ProxyHandlerResult::Continue;
|
||||
|
||||
if (trigger->fields.type & TriggerData) {
|
||||
RegisteredMemory& dst = memories_[trigger->fields.dstMemoryId];
|
||||
RegisteredMemory& src = memories_[trigger->fields.srcMemoryId];
|
||||
channel.connection().write(dst, trigger->fields.dstOffset, src, trigger->fields.srcOffset, trigger->fields.size);
|
||||
}
|
||||
|
||||
if (trigger->fields.type & TriggerFlag) {
|
||||
channel.epoch().signal();
|
||||
}
|
||||
|
||||
if (trigger->fields.type & TriggerSync) {
|
||||
channel.connection().flush();
|
||||
result = ProxyHandlerResult::FlushFifoTailAndContinue;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
struct SimpleDeviceChannel
|
||||
{
|
||||
SimpleDeviceChannel() = default;
|
||||
|
||||
SimpleDeviceChannel(DeviceChannel devChan, MemoryId dst, MemoryId src) : devChan_(devChan), dst_(dst), src_(src) {}
|
||||
|
||||
SimpleDeviceChannel(const SimpleDeviceChannel& other) = default;
|
||||
|
||||
SimpleDeviceChannel& operator=(SimpleDeviceChannel& other) = default;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
|
||||
__forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
devChan_.put(dst_, dstOffset, src_, srcOffset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void put(uint64_t offset, uint64_t size)
|
||||
{
|
||||
put(offset, offset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void signal()
|
||||
{
|
||||
devChan_.signal();
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
devChan_.putWithSignal(dst_, dstOffset, src_, srcOffset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(uint64_t offset, uint64_t size)
|
||||
{
|
||||
putWithSignal(offset, offset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
devChan_.putWithSignalAndFlush(dst_, dstOffset, src_, srcOffset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t offset, uint64_t size)
|
||||
{
|
||||
putWithSignalAndFlush(offset, offset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void flush()
|
||||
{
|
||||
devChan_.flush();
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void wait()
|
||||
{
|
||||
devChan_.wait();
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void epochIncrement()
|
||||
{
|
||||
devChan_.epochIncrement();
|
||||
}
|
||||
|
||||
#endif // __CUDACC__
|
||||
|
||||
DeviceChannel devChan_;
|
||||
MemoryId dst_;
|
||||
MemoryId src_;
|
||||
};
|
||||
|
||||
} // namespace channel
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_CHANNEL_HPP_
|
||||
@@ -8,6 +8,7 @@
|
||||
#define MSCCLPP_CHECKS_HPP_
|
||||
|
||||
#include "debug.h"
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#define MSCCLPPTHROW(call) \
|
||||
@@ -16,7 +17,7 @@
|
||||
if (res != mscclppSuccess && res != mscclppInProgress) { \
|
||||
throw std::runtime_error(std::string("Call to " #call " failed with error code ") + mscclppGetErrorString(res)); \
|
||||
} \
|
||||
} while (0);
|
||||
} while (false)
|
||||
|
||||
#define CUDATHROW(cmd) \
|
||||
do { \
|
||||
@@ -26,4 +27,14 @@
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
// Evaluates a CUDA driver API call and throws std::runtime_error if it does not return CUDA_SUCCESS.
// cuGetErrorString() sets the string pointer to NULL for error codes it does not recognize, and
// constructing a std::string from NULL is undefined behavior, so a placeholder text is used then.
#define CUTHROW(cmd)                                                                                                   \
  do {                                                                                                                 \
    CUresult err = cmd;                                                                                                \
    if (err != CUDA_SUCCESS) {                                                                                         \
      const char* errStr = nullptr;                                                                                    \
      if (cuGetErrorString(err, &errStr) != CUDA_SUCCESS || errStr == nullptr) {                                       \
        errStr = "unknown error";                                                                                      \
      }                                                                                                                \
      throw std::runtime_error(std::string("Cu failure '") + std::string(errStr) + "'");                               \
    }                                                                                                                  \
  } while (false)
|
||||
|
||||
#endif
|
||||
|
||||
@@ -7,15 +7,16 @@
|
||||
#ifndef MSCCLPP_COMM_H_
|
||||
#define MSCCLPP_COMM_H_
|
||||
|
||||
#include "ib.h"
|
||||
#include "ib.hpp"
|
||||
#include "proxy.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#define MAXCONNECTIONS 64
|
||||
|
||||
struct mscclppBufferRegistration
|
||||
{
|
||||
void *data;
|
||||
void* data;
|
||||
uint64_t size;
|
||||
};
|
||||
|
||||
@@ -31,7 +32,7 @@ struct mscclppConn
|
||||
std::vector<mscclppBufferRegistration> bufferRegistrations;
|
||||
std::vector<mscclppBufferRegistration> remoteBufferRegistrations;
|
||||
|
||||
struct mscclppIbContext* ibCtx;
|
||||
mscclpp::IbCtx* ibCtx;
|
||||
#if defined(ENABLE_NPKIT)
|
||||
std::vector<uint64_t> npkitUsedReqIds;
|
||||
std::vector<uint64_t> npkitFreeReqIds;
|
||||
@@ -57,7 +58,7 @@ struct mscclppComm
|
||||
// Flag to ask MSCCLPP kernels to abort
|
||||
volatile uint32_t* abortFlag;
|
||||
|
||||
struct mscclppIbContext* ibContext[MSCCLPP_IB_MAX_DEVS];
|
||||
std::unique_ptr<mscclpp::IbCtx> ibContext[MSCCLPP_IB_MAX_DEVS];
|
||||
struct mscclppProxyState* proxyState[MSCCLPP_PROXY_MAX_NUM];
|
||||
};
|
||||
|
||||
|
||||
@@ -1,23 +1,32 @@
|
||||
#ifndef MSCCL_COMMUNICATOR_HPP_
|
||||
#define MSCCL_COMMUNICATOR_HPP_
|
||||
|
||||
#include "mscclpp.hpp"
|
||||
#include "ib.hpp"
|
||||
#include "mscclpp.h"
|
||||
#include "mscclpp.hpp"
|
||||
#include "proxy.hpp"
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
struct Communicator::Impl {
|
||||
mscclppComm_t comm;
|
||||
std::vector<std::shared_ptr<HostConnection>> connections;
|
||||
Proxy proxy;
|
||||
class ConnectionBase;
|
||||
|
||||
Impl();
|
||||
struct Communicator::Impl
|
||||
{
|
||||
std::vector<std::shared_ptr<ConnectionBase>> connections_;
|
||||
std::vector<std::shared_ptr<Setuppable>> toSetup_;
|
||||
std::unordered_map<Transport, std::unique_ptr<IbCtx>> ibContexts_;
|
||||
std::shared_ptr<BaseBootstrap> bootstrap_;
|
||||
std::vector<uint64_t> rankToHash_;
|
||||
|
||||
~Impl();
|
||||
Impl(std::shared_ptr<BaseBootstrap> bootstrap);
|
||||
|
||||
friend class HostConnection;
|
||||
~Impl();
|
||||
|
||||
IbCtx* getIbContext(Transport ibTransport);
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif
|
||||
#endif // MSCCL_COMMUNICATOR_HPP_
|
||||
|
||||
69
src/include/connection.hpp
Normal file
69
src/include/connection.hpp
Normal file
@@ -0,0 +1,69 @@
|
||||
#ifndef MSCCLPP_CONNECTION_HPP_
|
||||
#define MSCCLPP_CONNECTION_HPP_
|
||||
|
||||
#include "communicator.hpp"
|
||||
#include "ib.hpp"
|
||||
#include "mscclpp.hpp"
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
// TODO: Add functionality to these classes for Communicator to do connectionSetup
|
||||
|
||||
class ConnectionBase : public Connection, public Setuppable
|
||||
{
|
||||
int remoteRank_;
|
||||
int tag_;
|
||||
public:
|
||||
ConnectionBase(int remoteRank, int tag);
|
||||
|
||||
int remoteRank() override;
|
||||
int tag() override;
|
||||
};
|
||||
|
||||
class CudaIpcConnection : public ConnectionBase
|
||||
{
|
||||
cudaStream_t stream;
|
||||
|
||||
public:
|
||||
CudaIpcConnection(int remoteRank, int tag);
|
||||
|
||||
~CudaIpcConnection();
|
||||
|
||||
Transport transport() override;
|
||||
|
||||
Transport remoteTransport() override;
|
||||
|
||||
void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
|
||||
uint64_t size) override;
|
||||
|
||||
void flush() override;
|
||||
};
|
||||
|
||||
class IBConnection : public ConnectionBase
|
||||
{
|
||||
Transport transport_;
|
||||
Transport remoteTransport_;
|
||||
IbQp* qp;
|
||||
int numSignaledSends;
|
||||
|
||||
public:
|
||||
IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl);
|
||||
|
||||
Transport transport() override;
|
||||
|
||||
Transport remoteTransport() override;
|
||||
|
||||
void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
|
||||
uint64_t size) override;
|
||||
|
||||
void flush() override;
|
||||
|
||||
void beginSetup(std::shared_ptr<BaseBootstrap> bootstrap) override;
|
||||
|
||||
void endSetup(std::shared_ptr<BaseBootstrap> bootstrap) override;
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_CONNECTION_HPP_
|
||||
52
src/include/epoch.hpp
Normal file
52
src/include/epoch.hpp
Normal file
@@ -0,0 +1,52 @@
|
||||
#ifndef MSCCLPP_EPOCH_HPP_
|
||||
#define MSCCLPP_EPOCH_HPP_
|
||||
|
||||
#include "mscclpp.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
struct alignas(16) EpochIds
|
||||
{
|
||||
uint64_t outbound_;
|
||||
uint64_t inboundReplica_;
|
||||
};
|
||||
|
||||
struct DeviceEpoch
|
||||
{
|
||||
#ifdef __CUDACC__
|
||||
__forceinline__ __device__ void wait()
|
||||
{
|
||||
(*expectedInboundEpochId_) += 1;
|
||||
while (*(volatile uint64_t*)&(epochIds_->inboundReplica_) < (*expectedInboundEpochId_));
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void epochIncrement()
|
||||
{
|
||||
*(volatile uint64_t*)&(epochIds_->outbound_) += 1;
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
|
||||
EpochIds* epochIds_;
|
||||
uint64_t* expectedInboundEpochId_;
|
||||
};
|
||||
|
||||
class Epoch
|
||||
{
|
||||
std::shared_ptr<Connection> connection_;
|
||||
DeviceEpoch device_;
|
||||
RegisteredMemory localEpochIdsRegMem_;
|
||||
NonblockingFuture<RegisteredMemory> remoteEpochIdsRegMem_;
|
||||
|
||||
public:
|
||||
Epoch(Communicator& communicator, std::shared_ptr<Connection> connection);
|
||||
Epoch(const Epoch&) = delete;
|
||||
~Epoch();
|
||||
|
||||
void signal();
|
||||
|
||||
DeviceEpoch deviceEpoch() { return device_; }
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_EPOCH_HPP_
|
||||
@@ -1,22 +0,0 @@
|
||||
#ifndef MSCCLPP_HOST_CONNECTION_HPP_
|
||||
#define MSCCLPP_HOST_CONNECTION_HPP_
|
||||
|
||||
#include "mscclpp.hpp"
|
||||
#include "mscclpp.h"
|
||||
#include "comm.h"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
struct HostConnection::Impl {
|
||||
Communicator* comm;
|
||||
mscclppConn* conn;
|
||||
mscclppHostConn_t* hostConn;
|
||||
|
||||
Impl(Communicator* comm, mscclppConn* conn);
|
||||
|
||||
~Impl();
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_HOST_CONNECTION_HPP_
|
||||
@@ -1,69 +0,0 @@
|
||||
#ifndef MSCCLPP_IB_H_
|
||||
#define MSCCLPP_IB_H_
|
||||
|
||||
#include "mscclpp.h"
|
||||
#include <infiniband/verbs.h>
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#define MSCCLPP_IB_CQ_SIZE 1024
|
||||
#define MSCCLPP_IB_CQ_POLL_NUM 4
|
||||
#define MSCCLPP_IB_MAX_SENDS 64
|
||||
#define MSCCLPP_IB_MAX_DEVS 8
|
||||
|
||||
// QP info to be shared with the remote peer
|
||||
struct mscclppIbQpInfo
|
||||
{
|
||||
uint16_t lid;
|
||||
uint8_t port;
|
||||
uint8_t linkLayer;
|
||||
uint32_t qpn;
|
||||
uint64_t spn;
|
||||
ibv_mtu mtu;
|
||||
};
|
||||
|
||||
// IB queue pair
|
||||
struct mscclppIbQp
|
||||
{
|
||||
struct ibv_qp* qp;
|
||||
struct mscclppIbQpInfo info;
|
||||
struct ibv_send_wr* wrs;
|
||||
struct ibv_sge* sges;
|
||||
struct ibv_cq* cq;
|
||||
struct ibv_wc* wcs;
|
||||
int wrn;
|
||||
|
||||
int rtr(const mscclppIbQpInfo* info);
|
||||
int rts();
|
||||
int stageSend(struct mscclppIbMr* ibMr, const mscclppIbMrInfo* info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
|
||||
uint64_t dstOffset, bool signaled);
|
||||
int stageSendWithImm(struct mscclppIbMr* ibMr, const mscclppIbMrInfo* info, uint32_t size, uint64_t wrId,
|
||||
uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData);
|
||||
int postSend();
|
||||
int postRecv(uint64_t wrId);
|
||||
int pollCq();
|
||||
};
|
||||
|
||||
// Holds resources of a single IB device.
|
||||
struct mscclppIbContext
|
||||
{
|
||||
struct ibv_context* ctx;
|
||||
struct ibv_pd* pd;
|
||||
int* ports;
|
||||
int nPorts;
|
||||
struct mscclppIbQp* qps;
|
||||
int nQps;
|
||||
int maxQps;
|
||||
struct mscclppIbMr* mrs;
|
||||
int nMrs;
|
||||
int maxMrs;
|
||||
};
|
||||
|
||||
mscclppResult_t mscclppIbContextCreate(struct mscclppIbContext** ctx, const char* ibDevName);
|
||||
mscclppResult_t mscclppIbContextDestroy(struct mscclppIbContext* ctx);
|
||||
mscclppResult_t mscclppIbContextCreateQp(struct mscclppIbContext* ctx, struct mscclppIbQp** ibQp, int port = -1);
|
||||
mscclppResult_t mscclppIbContextRegisterMr(struct mscclppIbContext* ctx, void* buff, size_t size,
|
||||
struct mscclppIbMr** ibMr);
|
||||
|
||||
#endif
|
||||
108
src/include/ib.hpp
Normal file
108
src/include/ib.hpp
Normal file
@@ -0,0 +1,108 @@
|
||||
#ifndef MSCCLPP_IB_HPP_
|
||||
#define MSCCLPP_IB_HPP_
|
||||
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#define MSCCLPP_IB_CQ_SIZE 1024
|
||||
#define MSCCLPP_IB_CQ_POLL_NUM 1
|
||||
#define MSCCLPP_IB_MAX_SENDS 64
|
||||
#define MSCCLPP_IB_MAX_DEVS 8
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
struct IbMrInfo
|
||||
{
|
||||
uint64_t addr;
|
||||
uint32_t rkey;
|
||||
};
|
||||
|
||||
class IbMr
|
||||
{
|
||||
public:
|
||||
~IbMr();
|
||||
|
||||
IbMrInfo getInfo() const;
|
||||
const void* getBuff() const;
|
||||
uint32_t getLkey() const;
|
||||
|
||||
private:
|
||||
IbMr(void* pd, void* buff, std::size_t size);
|
||||
|
||||
void* mr;
|
||||
void* buff;
|
||||
std::size_t size;
|
||||
|
||||
friend class IbCtx;
|
||||
};
|
||||
|
||||
// QP info to be shared with the remote peer
|
||||
struct IbQpInfo
|
||||
{
|
||||
uint16_t lid;
|
||||
uint8_t port;
|
||||
uint8_t linkLayer;
|
||||
uint32_t qpn;
|
||||
uint64_t spn;
|
||||
int mtu;
|
||||
};
|
||||
|
||||
class IbQp
|
||||
{
|
||||
public:
|
||||
~IbQp();
|
||||
|
||||
void rtr(const IbQpInfo& info);
|
||||
void rts();
|
||||
int stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
|
||||
uint64_t dstOffset, bool signaled);
|
||||
int stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
|
||||
uint64_t dstOffset, bool signaled, unsigned int immData);
|
||||
void postSend();
|
||||
void postRecv(uint64_t wrId);
|
||||
int pollCq();
|
||||
|
||||
IbQpInfo& getInfo();
|
||||
const void* getWc(int idx) const;
|
||||
|
||||
private:
|
||||
IbQp(void* ctx, void* pd, int port);
|
||||
|
||||
IbQpInfo info;
|
||||
|
||||
void* qp;
|
||||
void* cq;
|
||||
void* wcs;
|
||||
void* wrs;
|
||||
void* sges;
|
||||
int wrn;
|
||||
|
||||
friend class IbCtx;
|
||||
};
|
||||
|
||||
class IbCtx
|
||||
{
|
||||
public:
|
||||
IbCtx(const std::string& devName);
|
||||
~IbCtx();
|
||||
|
||||
IbQp* createQp(int port = -1);
|
||||
const IbMr* registerMr(void* buff, std::size_t size);
|
||||
|
||||
const std::string& getDevName() const;
|
||||
|
||||
private:
|
||||
bool isPortUsable(int port) const;
|
||||
int getAnyActivePort() const;
|
||||
|
||||
const std::string devName;
|
||||
void* ctx;
|
||||
void* pd;
|
||||
std::list<std::unique_ptr<IbQp>> qps;
|
||||
std::list<std::unique_ptr<IbMr>> mrs;
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_IB_HPP_
|
||||
@@ -191,7 +191,8 @@ struct mscclppHostConn
|
||||
{
|
||||
virtual ~mscclppHostConn() = default;
|
||||
virtual void put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize) = 0;
|
||||
virtual void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset, uint64_t dataSize) = 0;
|
||||
virtual void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset,
|
||||
uint64_t dataSize) = 0;
|
||||
virtual void signal() = 0;
|
||||
virtual void wait() = 0;
|
||||
virtual void flush() = 0;
|
||||
@@ -207,25 +208,10 @@ typedef struct
|
||||
char internal[MSCCLPP_UNIQUE_ID_BYTES];
|
||||
} mscclppUniqueId;
|
||||
|
||||
// MR info to be shared with the remote peer
|
||||
struct mscclppIbMrInfo
|
||||
{
|
||||
uint64_t addr;
|
||||
uint32_t rkey;
|
||||
};
|
||||
|
||||
// IB memory region
|
||||
struct mscclppIbMr
|
||||
{
|
||||
struct ibv_mr* mr;
|
||||
void* buff;
|
||||
struct mscclppIbMrInfo info;
|
||||
};
|
||||
|
||||
struct mscclppRegisteredMemoryP2P
|
||||
{
|
||||
void* remoteBuff;
|
||||
mscclppIbMr* IbMr;
|
||||
const void* IbMr;
|
||||
};
|
||||
|
||||
struct mscclppRegisteredMemory
|
||||
@@ -247,7 +233,6 @@ typedef enum
|
||||
mscclppNumResults = 8
|
||||
} mscclppResult_t;
|
||||
|
||||
|
||||
/* Create a unique ID for communication. Only needs to be called by one process.
|
||||
* Use with mscclppCommInitRankFromId().
|
||||
* All processes need to provide the same ID to mscclppCommInitRankFromId().
|
||||
@@ -358,7 +343,8 @@ mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, int tag, void
|
||||
* transportType: the type of transport to be used (mscclppTransportP2P or mscclppTransportIB)
|
||||
* ibDev: the name of the IB device to be used. Expects a null for mscclppTransportP2P.
|
||||
*/
|
||||
mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int remoteRank, int tag, mscclppTransport_t transportType, const char* ibDev = 0);
|
||||
mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int remoteRank, int tag,
|
||||
mscclppTransport_t transportType, const char* ibDev = 0);
|
||||
|
||||
/* Register a buffer for use with a connection.
|
||||
*
|
||||
@@ -371,7 +357,8 @@ mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int remoteRank,
|
||||
* Outputs:
|
||||
* handle: a handle to the buffer registration
|
||||
*/
|
||||
mscclppResult_t mscclppRegisterBufferForConnection(mscclppComm_t comm, int connIdx, void* localBuff, uint64_t buffSize, mscclppBufferHandle_t *handle);
|
||||
mscclppResult_t mscclppRegisterBufferForConnection(mscclppComm_t comm, int connIdx, void* localBuff, uint64_t buffSize,
|
||||
mscclppBufferHandle_t* handle);
|
||||
|
||||
/* Establish all connections declared by mscclppConnect(). This function must be called after all mscclppConnect()
|
||||
* calls are made. This function ensures that all remote ranks are ready to communicate when it returns.
|
||||
|
||||
@@ -6,22 +6,17 @@
|
||||
#define MSCCLPP_PATCH 0
|
||||
#define MSCCLPP_VERSION (MSCCLPP_MAJOR * 10000 + MSCCLPP_MINOR * 100 + MSCCLPP_PATCH)
|
||||
|
||||
// For every MSCCLPP_PROXY_FIFO_FLUSH_COUNTER, a flush of the tail to device memory is triggered.
|
||||
// As long as MSCCLPP_PROXY_FIFO_SIZE is large enough, having a stale tail is not a problem.
|
||||
#define MSCCLPP_PROXY_FIFO_SIZE 128
|
||||
#define MSCCLPP_PROXY_FIFO_FLUSH_COUNTER 4
|
||||
|
||||
#include <vector>
|
||||
#include <bitset>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <functional>
|
||||
|
||||
#include <mscclppfifo.hpp>
|
||||
#include <vector>
|
||||
#include <future>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
#define MSCCLPP_UNIQUE_ID_BYTES 128
|
||||
struct UniqueId {
|
||||
struct UniqueId
|
||||
{
|
||||
char internal[MSCCLPP_UNIQUE_ID_BYTES];
|
||||
};
|
||||
|
||||
@@ -36,6 +31,21 @@ public:
|
||||
virtual void recv(void* data, int size, int peer, int tag) = 0;
|
||||
virtual void allGather(void* allData, int size) = 0;
|
||||
virtual void barrier() = 0;
|
||||
|
||||
// TODO: move implementations of these helpers out of this header
|
||||
void send(const std::vector<char>& data, int peer, int tag)
|
||||
{
|
||||
size_t size = data.size();
|
||||
send((void*)&size, sizeof(size_t), peer, tag);
|
||||
send((void*)data.data(), data.size(), peer, tag+1);
|
||||
}
|
||||
void recv(std::vector<char>& data, int peer, int tag)
|
||||
{
|
||||
size_t size;
|
||||
recv((void*)&size, sizeof(size_t), peer, tag);
|
||||
data.resize(size);
|
||||
recv((void*)data.data(), data.size(), peer, tag+1);
|
||||
}
|
||||
};
|
||||
|
||||
class Bootstrap : public BaseBootstrap
|
||||
@@ -61,369 +71,6 @@ private:
|
||||
std::unique_ptr<Impl> pimpl_;
|
||||
};
|
||||
|
||||
|
||||
struct alignas(16) SignalEpochId {
|
||||
// every signal(), increaments this and either:
|
||||
// 1) proxy thread pushes it to the remote peer's localSignalEpochId->proxy
|
||||
// 2) gpu thread directly writes it to remoteSignalEpochId->device
|
||||
uint64_t device;
|
||||
// signal() function triggers the cpu proxy thread to write to it
|
||||
uint64_t proxy;
|
||||
};
|
||||
|
||||
using ChannelTriggerType = uint64_t;
|
||||
const ChannelTriggerType channelTriggerData = 0x1;
|
||||
const ChannelTriggerType channelTriggerFlag = 0x2;
|
||||
const ChannelTriggerType channelTriggerSync = 0x4;
|
||||
|
||||
// This is just a numeric ID. Each HostConnection will have an internal array indexed by these handles
|
||||
// mapping to the actual
|
||||
using BufferHandle = uint32_t;
|
||||
|
||||
#define MSCCLPP_BITS_SIZE 32
|
||||
#define MSCCLPP_BITS_OFFSET 32
|
||||
#define MSCCLPP_BITS_BUFFER_HANDLE 8
|
||||
#define MSCCLPP_BITS_TYPE 3
|
||||
#define MSCCLPP_BITS_CONNID 10
|
||||
|
||||
// this is the basic structure of each work element in the fifo
|
||||
// the summation of number of bits must be 128 or less
|
||||
union ChannelTrigger {
|
||||
ProxyTrigger value;
|
||||
struct
|
||||
{
|
||||
// first 64 bits: value[0]
|
||||
uint64_t size : MSCCLPP_BITS_SIZE;
|
||||
uint64_t srcOffset : MSCCLPP_BITS_OFFSET;
|
||||
uint64_t : (64 - MSCCLPP_BITS_SIZE - MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment
|
||||
// second 64 bits: value[1]
|
||||
uint64_t dstOffset : MSCCLPP_BITS_OFFSET;
|
||||
uint64_t srcBufferHandle : MSCCLPP_BITS_BUFFER_HANDLE;
|
||||
uint64_t dstBufferHandle : MSCCLPP_BITS_BUFFER_HANDLE;
|
||||
uint64_t type : MSCCLPP_BITS_TYPE;
|
||||
uint64_t connId : MSCCLPP_BITS_CONNID;
|
||||
uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_BUFFER_HANDLE - MSCCLPP_BITS_BUFFER_HANDLE - MSCCLPP_BITS_TYPE); // ensure 64-bit alignment
|
||||
} fields;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
__device__ ChannelTrigger() {}
|
||||
__device__ ChannelTrigger(ProxyTrigger value) : value(value) {}
|
||||
__device__ ChannelTrigger(ChannelTriggerType type, BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size, int connectionId) {
|
||||
value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + size);
|
||||
value.snd = ((((((((connectionId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_BUFFER_HANDLE) + dst) << MSCCLPP_BITS_BUFFER_HANDLE) + src) << MSCCLPP_BITS_OFFSET) + dstOffset);
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
};
|
||||
|
||||
struct ConnectionEpoch {
|
||||
#ifdef __CUDACC__
|
||||
__forceinline__ __device__ void wait()
|
||||
{
|
||||
(*waitEpochId) += 1;
|
||||
while (*(volatile uint64_t*)&(localSignalEpochId->proxy) < (*waitEpochId))
|
||||
;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void epochIncrement()
|
||||
{
|
||||
*(volatile uint64_t*)&(localSignalEpochId->device) += 1;
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
|
||||
SignalEpochId* localSignalEpochId;
|
||||
// used by the signal() function directly from gpu
|
||||
SignalEpochId* remoteSignalEpochId;
|
||||
|
||||
// every wait(), increments this and then the gpu waits for either:
|
||||
// 1) localSignalEpochId->proxy to be >= this in case of a proxy thread
|
||||
// 2) remoteSignalEpochId->device to be >= this in case of a gpu thread
|
||||
uint64_t* waitEpochId;
|
||||
};
|
||||
|
||||
class HostConnection {
|
||||
struct Impl;
|
||||
public:
|
||||
/* HostConnection can not be constructed from user code and must instead be created through Communicator::connect */
|
||||
HostConnection(std::unique_ptr<Impl>);
|
||||
|
||||
~HostConnection();
|
||||
|
||||
int getId();
|
||||
|
||||
/* Register a region of GPU memory for use with this connection. Must be called before connectionSetup()
|
||||
* in the communicator.
|
||||
*
|
||||
* Inputs:
|
||||
* data: base pointer to the memory
|
||||
* size: size of the memory region in bytes
|
||||
*
|
||||
* Returns: a handle to the buffer
|
||||
*/
|
||||
BufferHandle registerBuffer(void* data, uint64_t size);
|
||||
|
||||
/* Get the number of times registerBuffer(...) was called.
|
||||
*
|
||||
* Returns: the number of buffers registered
|
||||
*/
|
||||
int numLocalBuffers();
|
||||
|
||||
/* Get the BufferHandle returned by a call to registerBuffer(...) as identified by the index
|
||||
*
|
||||
* Inputs:
|
||||
* index: the index of the handle to get
|
||||
*
|
||||
* Returns: a handle to the buffer
|
||||
*/
|
||||
BufferHandle getLocalBuffer(int index);
|
||||
|
||||
/* Get the number of times registerBuffer(...) was called on the remote peer.
|
||||
*
|
||||
* Returns: the number of buffers registered on the remote peer
|
||||
*/
|
||||
int numRemoteBuffers();
|
||||
|
||||
/* Get the BufferHandle returned by a call to registerBuffer(...) on the remote peer as identified by the index
|
||||
*
|
||||
* Inputs:
|
||||
* index: the index of the handle to get
|
||||
*
|
||||
* Returns: a handle to the buffer on the remote peer
|
||||
*/
|
||||
BufferHandle getRemoteBuffer(int index);
|
||||
|
||||
ConnectionEpoch getEpoch();
|
||||
|
||||
DeviceProxyFifo getDeviceFifo();
|
||||
|
||||
void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size);
|
||||
|
||||
void signal();
|
||||
|
||||
void flush();
|
||||
|
||||
void wait();
|
||||
|
||||
private:
|
||||
std::unique_ptr<Impl> pimpl;
|
||||
friend class Communicator;
|
||||
};
|
||||
|
||||
/***************************************************************************************************************
|
||||
* A mscclppDevConn provides a zero-copy connection between two GPUs connected via P2P NVLink or InfiniBand.
|
||||
* The communication API is one-sided meaning that for every single data transfer, only one side
|
||||
* needs to execute unlike a two-sided communication stack such as NCCL where both sides
|
||||
* need to execute a send and a receive instruction, respectively, for every transfer.
|
||||
*
|
||||
* A connection is uniquely identified by the (remoteRank, tag) pair at an endpoint.
|
||||
* The two endpoints register buffers of the same size with the connection.
|
||||
*
|
||||
* The endpoints provide the remoteRank, tag, and the buffer when registering a connection with mscclppConnect().
|
||||
*
|
||||
* mscclppConnectionSetup() sets up all the registered connections.
|
||||
*
|
||||
***************************************************************************************************************
|
||||
* A proxy thread running on the CPU is necessary to perform transfers using InfiniBand or the DMA engine.
|
||||
* The current implementation uses a single proxy thread per context - one IB connection or DMA engine per node.
|
||||
* Thus multiple threadblocks using different connections might use the same CPU proxy thread.
|
||||
*
|
||||
* Before using any functionality of connections, mscclppProxyLaunch needs to be called to spawn the
|
||||
* proxy threads. There are currently two types of connections:
|
||||
*
|
||||
* P2P via NVLink: the DMA engine can perform the copy between the buffers. DMA engine has higher latency
|
||||
* but has a higher bandwidth and costs no compute cycles on the GPU.
|
||||
*
|
||||
* InfiniBand: the RDMA engine copies the data over MLX devices.
|
||||
*
|
||||
***************************************************************************************************************
|
||||
* At the runtime, a GPU kernel has access to a mscclppDevConn object that provides the following functions:
|
||||
*
|
||||
* put(): [non-blocking] the sender initiates a data transfer to the receiver.
|
||||
*
|
||||
* signal(): [non-blocking] the sender signals the receiver that data is ready to be consumed.
|
||||
*
|
||||
* flush(): [blocking] the sender waits for all the data transfers to complete
|
||||
*
|
||||
* wait(): [blocking] the receiver waits on the signal() to start reading the data.
|
||||
*
|
||||
* The sender should not reuse the buffer till the flush() returns.
|
||||
* The receiver should only access the data after the wait() returns.
|
||||
*
|
||||
* putWithSignal(): the sender initiates a data transfer and signals the receiver that data is ready to be consumed.
|
||||
* This is an optimized version of a put() followed by a signal().
|
||||
*
|
||||
* These functions hide the complexity of synchronization between the two GPUs and the CPU proxy thread.
|
||||
* Example:
|
||||
*
|
||||
* // sender GPU
|
||||
* devConn.put(data1)
|
||||
* // not OK to write to data1
|
||||
* devConn.put(data2)
|
||||
* // not OK to write to data1, data2
|
||||
* devConn.put(data3) // receiver GPU
|
||||
* // not OK to write to data1, data2, data3 // not OK to read data1, data2, data3
|
||||
* devConn.signal() -------------------------------> devConn.wait()
|
||||
* // not OK to write to data1, data2, data3 // OK to read data1, data2, data3
|
||||
* devConn.flush()
|
||||
* // OK to write to data1, data2, data3
|
||||
*
|
||||
*
|
||||
* The two endpoints can concurrently use the same connection provided they are writing (puts) on different
|
||||
* indices in the registered buffer.
|
||||
**************************************************************************************************************/
|
||||
struct DeviceConnection {
|
||||
DeviceConnection() = default;
|
||||
|
||||
DeviceConnection(HostConnection& hostConn)
|
||||
: connectionId(hostConn.getId()), epoch(hostConn.getEpoch()),
|
||||
fifo(hostConn.getDeviceFifo()) {}
|
||||
|
||||
DeviceConnection(const DeviceConnection& other) = default;
|
||||
|
||||
DeviceConnection& operator=(DeviceConnection& other) = default;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
__forceinline__ __device__ void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
fifo.push(ChannelTrigger(channelTriggerData, dst, dstOffset, src, srcOffset, size, connectionId).value);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void put(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size)
|
||||
{
|
||||
put(dst, offset, src, offset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void signal()
|
||||
{
|
||||
epochIncrement();
|
||||
fifo.push(ChannelTrigger(channelTriggerFlag, 0, 0, 0, 0, 1, connectionId).value);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
epochIncrement();
|
||||
fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag, dst, dstOffset, src, srcOffset, size, connectionId).value);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size)
|
||||
{
|
||||
putWithSignal(dst, offset, src, offset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
epochIncrement();
|
||||
uint64_t curFifoHead = fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag | channelTriggerSync, dst, dstOffset, src, srcOffset, size, connectionId).value);
|
||||
while (*(volatile uint64_t*)&fifo.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
|
||||
*(volatile uint64_t*)fifo.tailReplica <= curFifoHead)
|
||||
;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size)
|
||||
{
|
||||
putWithSignalAndFlush(dst, offset, src, offset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void flush()
|
||||
{
|
||||
uint64_t curFifoHead = fifo.push(ChannelTrigger(mscclppSync, 0, 0, 0, 0, 1, connectionId).value);
|
||||
// we need to wait for two conditions to be met to ensure the CPU is done flushing. (1) wait for the tail
|
||||
// to go pass by curFifoHead (this is safety net) and (2) wait for the work element value to change to 0.
|
||||
while (*(volatile uint64_t*)&fifo.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 &&
|
||||
*(volatile uint64_t*)fifo.tailReplica <= curFifoHead)
|
||||
;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void wait()
|
||||
{
|
||||
epoch.wait();
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void epochIncrement()
|
||||
{
|
||||
epoch.epochIncrement();
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
|
||||
int connectionId;
|
||||
|
||||
ConnectionEpoch epoch;
|
||||
|
||||
// this is a concurrent fifo which is multiple threads from the device
|
||||
// can produce for and the sole proxy thread consumes it.
|
||||
DeviceProxyFifo fifo;
|
||||
};
|
||||
|
||||
struct SimpleDeviceConnection {
|
||||
SimpleDeviceConnection() = default;
|
||||
|
||||
SimpleDeviceConnection(HostConnection& hostConn) : devConn(hostConn) {
|
||||
dst = hostConn.getRemoteBuffer(0);
|
||||
src = hostConn.getLocalBuffer(0);
|
||||
}
|
||||
|
||||
SimpleDeviceConnection(const SimpleDeviceConnection& other) = default;
|
||||
|
||||
SimpleDeviceConnection& operator=(SimpleDeviceConnection& other) = default;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
|
||||
__forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
devConn.put(dst, dstOffset, src, srcOffset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void put(uint64_t offset, uint64_t size)
|
||||
{
|
||||
put(offset, offset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void signal()
|
||||
{
|
||||
devConn.signal();
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
devConn.putWithSignal(dst, dstOffset, src, srcOffset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(uint64_t offset, uint64_t size)
|
||||
{
|
||||
putWithSignal(offset, offset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstOffset, uint64_t srcOffset, uint64_t size)
|
||||
{
|
||||
devConn.putWithSignalAndFlush(dst, dstOffset, src, srcOffset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t offset, uint64_t size)
|
||||
{
|
||||
putWithSignalAndFlush(offset, offset, size);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void flush()
|
||||
{
|
||||
devConn.flush();
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void wait()
|
||||
{
|
||||
devConn.wait();
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void epochIncrement()
|
||||
{
|
||||
devConn.epochIncrement();
|
||||
}
|
||||
|
||||
#endif // __CUDACC__
|
||||
|
||||
DeviceConnection devConn;
|
||||
BufferHandle dst;
|
||||
BufferHandle src;
|
||||
};
|
||||
|
||||
/* Create a unique ID for communication. Only needs to be called by one process.
|
||||
* Use with mscclppCommInitRankFromId().
|
||||
* All processes need to provide the same ID to mscclppCommInitRankFromId().
|
||||
@@ -433,120 +80,300 @@ struct SimpleDeviceConnection {
|
||||
*/
|
||||
std::unique_ptr<UniqueId> getUniqueId();
|
||||
|
||||
/* Transport Types */
|
||||
enum class TransportType : uint8_t {
|
||||
P2P = 0,
|
||||
IB = 1,
|
||||
enum class Transport
|
||||
{
|
||||
Unknown,
|
||||
CudaIpc,
|
||||
IB0,
|
||||
IB1,
|
||||
IB2,
|
||||
IB3,
|
||||
IB4,
|
||||
IB5,
|
||||
IB6,
|
||||
IB7,
|
||||
NumTransports
|
||||
};
|
||||
|
||||
class Communicator {
|
||||
public:
|
||||
namespace detail {
|
||||
const size_t TransportFlagsSize = 10;
|
||||
static_assert(TransportFlagsSize == static_cast<size_t>(Transport::NumTransports),
|
||||
"TransportFlagsSize must match the number of transports");
|
||||
using TransportFlagsBase = std::bitset<TransportFlagsSize>;
|
||||
} // namespace detail
|
||||
|
||||
/* Initialize the communicator. nranks processes with rank 0 to nranks-1 need to call this function.
|
||||
*
|
||||
* Inputs:
|
||||
* nranks: number of ranks in the communicator
|
||||
* ipPortPair: a string of the form "ip:port" that represents the address of the root process
|
||||
* rank: rank of the calling process
|
||||
*/
|
||||
Communicator(int nranks, const char* ipPortPair, int rank);
|
||||
|
||||
/* Initialize the communicator from a given UniqueId. Same as mscclppCommInitRank() except that
|
||||
* id is provided by the user by calling getUniqueId()
|
||||
*
|
||||
* Inputs:
|
||||
* nranks: number of ranks in the communicator
|
||||
* id: the unique ID to be used for communication
|
||||
* rank: rank of the calling process
|
||||
*/
|
||||
Communicator(int nranks, UniqueId id, int rank);
|
||||
class TransportFlags : private detail::TransportFlagsBase
|
||||
{
|
||||
public:
|
||||
TransportFlags() = default;
|
||||
TransportFlags(Transport transport) : detail::TransportFlagsBase(1 << static_cast<size_t>(transport))
|
||||
{
|
||||
}
|
||||
|
||||
bool has(Transport transport) const
|
||||
{
|
||||
return detail::TransportFlagsBase::test(static_cast<size_t>(transport));
|
||||
}
|
||||
|
||||
bool none() const
|
||||
{
|
||||
return detail::TransportFlagsBase::none();
|
||||
}
|
||||
|
||||
bool any() const
|
||||
{
|
||||
return detail::TransportFlagsBase::any();
|
||||
}
|
||||
|
||||
bool all() const
|
||||
{
|
||||
return detail::TransportFlagsBase::all();
|
||||
}
|
||||
|
||||
size_t count() const
|
||||
{
|
||||
return detail::TransportFlagsBase::count();
|
||||
}
|
||||
|
||||
TransportFlags& operator|=(TransportFlags other)
|
||||
{
|
||||
detail::TransportFlagsBase::operator|=(other);
|
||||
return *this;
|
||||
}
|
||||
|
||||
TransportFlags operator|(TransportFlags other) const
|
||||
{
|
||||
return TransportFlags(*this) |= other;
|
||||
}
|
||||
|
||||
TransportFlags operator|(Transport transport) const
|
||||
{
|
||||
return *this | TransportFlags(transport);
|
||||
}
|
||||
|
||||
TransportFlags& operator&=(TransportFlags other)
|
||||
{
|
||||
detail::TransportFlagsBase::operator&=(other);
|
||||
return *this;
|
||||
}
|
||||
|
||||
TransportFlags operator&(TransportFlags other) const
|
||||
{
|
||||
return TransportFlags(*this) &= other;
|
||||
}
|
||||
|
||||
TransportFlags operator&(Transport transport) const
|
||||
{
|
||||
return *this & TransportFlags(transport);
|
||||
}
|
||||
|
||||
TransportFlags& operator^=(TransportFlags other)
|
||||
{
|
||||
detail::TransportFlagsBase::operator^=(other);
|
||||
return *this;
|
||||
}
|
||||
|
||||
TransportFlags operator^(TransportFlags other) const
|
||||
{
|
||||
return TransportFlags(*this) ^= other;
|
||||
}
|
||||
|
||||
TransportFlags operator^(Transport transport) const
|
||||
{
|
||||
return *this ^ TransportFlags(transport);
|
||||
}
|
||||
|
||||
TransportFlags operator~() const
|
||||
{
|
||||
return TransportFlags(*this).flip();
|
||||
}
|
||||
|
||||
bool operator==(TransportFlags other) const
|
||||
{
|
||||
return detail::TransportFlagsBase::operator==(other);
|
||||
}
|
||||
|
||||
bool operator!=(TransportFlags other) const
|
||||
{
|
||||
return detail::TransportFlagsBase::operator!=(other);
|
||||
}
|
||||
|
||||
detail::TransportFlagsBase toBitset() const
|
||||
{
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
TransportFlags(detail::TransportFlagsBase bitset) : detail::TransportFlagsBase(bitset)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
inline TransportFlags operator|(Transport transport1, Transport transport2)
|
||||
{
|
||||
return TransportFlags(transport1) | transport2;
|
||||
}
|
||||
|
||||
inline TransportFlags operator&(Transport transport1, Transport transport2)
|
||||
{
|
||||
return TransportFlags(transport1) & transport2;
|
||||
}
|
||||
|
||||
inline TransportFlags operator^(Transport transport1, Transport transport2)
|
||||
{
|
||||
return TransportFlags(transport1) ^ transport2;
|
||||
}
|
||||
|
||||
const TransportFlags NoTransports = TransportFlags();
|
||||
const TransportFlags AllIBTransports = Transport::IB0 | Transport::IB1 | Transport::IB2 | Transport::IB3 |
|
||||
Transport::IB4 | Transport::IB5 | Transport::IB6 | Transport::IB7;
|
||||
const TransportFlags AllTransports = AllIBTransports | Transport::CudaIpc;
|
||||
|
||||
int getIBDeviceCount();
|
||||
std::string getIBDeviceName(Transport ibTransport);
|
||||
Transport getIBTransportByDeviceName(const std::string& ibDeviceName);
|
||||
|
||||
class Communicator;
|
||||
class Connection;
|
||||
|
||||
class RegisteredMemory
|
||||
{
|
||||
struct Impl;
|
||||
// A shared_ptr is used since RegisteredMemory is functionally immutable, although internally some state is populated lazily.
|
||||
std::shared_ptr<Impl> pimpl;
|
||||
|
||||
public:
|
||||
RegisteredMemory() = default;
|
||||
RegisteredMemory(std::shared_ptr<Impl> pimpl);
|
||||
~RegisteredMemory();
|
||||
|
||||
void* data();
|
||||
size_t size();
|
||||
int rank();
|
||||
TransportFlags transports();
|
||||
|
||||
std::vector<char> serialize();
|
||||
static RegisteredMemory deserialize(const std::vector<char>& data);
|
||||
|
||||
friend class Connection;
|
||||
friend class Communicator;
|
||||
};
|
||||
|
||||
class Connection
|
||||
{
|
||||
public:
|
||||
virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
|
||||
uint64_t size) = 0;
|
||||
|
||||
virtual void flush() = 0;
|
||||
|
||||
virtual int remoteRank() = 0;
|
||||
|
||||
virtual int tag() = 0;
|
||||
|
||||
virtual Transport transport() = 0;
|
||||
|
||||
virtual Transport remoteTransport() = 0;
|
||||
|
||||
protected:
|
||||
static std::shared_ptr<RegisteredMemory::Impl> getRegisteredMemoryImpl(RegisteredMemory&);
|
||||
};
|
||||
|
||||
struct Setuppable
|
||||
{
|
||||
virtual void beginSetup(std::shared_ptr<BaseBootstrap>) {}
|
||||
virtual void endSetup(std::shared_ptr<BaseBootstrap>) {}
|
||||
};
|
||||
|
||||
// A thin wrapper around std::shared_future<T> whose get() never blocks: it throws instead of waiting when
// the result is not available yet.
template <typename T>
class NonblockingFuture
{
  std::shared_future<T> future;

public:
  NonblockingFuture() = default;
  NonblockingFuture(std::shared_future<T>&& future) : future(std::move(future)) {}
  NonblockingFuture(const NonblockingFuture&) = default;

  // Returns true iff the result can be retrieved without blocking.
  // A default-constructed NonblockingFuture has no shared state; calling wait_for() on it would be undefined
  // behaviour, so it is reported as not ready instead.
  bool ready() const
  {
    return future.valid() && future.wait_for(std::chrono::seconds(0)) == std::future_status::ready;
  }

  // Returns the stored result.
  // Throws std::runtime_error if the result is not ready (including when there is no shared state).
  T get() const
  {
    if (!ready())
      throw std::runtime_error("NonblockingFuture::get() called before ready");
    return future.get();
  }
};
|
||||
|
||||
class Communicator
|
||||
{
|
||||
public:
|
||||
/* Initialize the communicator.
|
||||
*
|
||||
* Inputs:
|
||||
* bootstrap: an implementation of BaseBootstrap that the communicator will use
|
||||
*/
|
||||
Communicator(std::shared_ptr<BaseBootstrap> bootstrap);
|
||||
|
||||
~Communicator();
|
||||
|
||||
/* Ring-based AllGather through the bootstrap socket.
|
||||
*
|
||||
* Inputs:
|
||||
* data: data array to be gathered where `[r*size, (r+1)*size)` is the data for rank `r`
|
||||
* size: data size per rank
|
||||
*/
|
||||
void bootstrapAllGather(void* data, int size);
|
||||
|
||||
/* A no-op function that is used to synchronize all processes via a bootstrap allgather*/
|
||||
void bootstrapBarrier();
|
||||
/* Return the bootstrapper held by this communicator. */
|
||||
std::shared_ptr<BaseBootstrap> bootstrapper();
|
||||
|
||||
/* Register a region of GPU memory for use in this communicator.
|
||||
*
|
||||
* Inputs:
|
||||
* ptr: base pointer to the memory
|
||||
* size: size of the memory region in bytes
|
||||
*
|
||||
* Returns: a handle to the buffer
|
||||
*/
|
||||
RegisteredMemory registerMemory(void* ptr, size_t size, TransportFlags transports);
|
||||
|
||||
void sendMemoryOnSetup(RegisteredMemory memory, int remoteRank, int tag);
|
||||
|
||||
NonblockingFuture<RegisteredMemory> recvMemoryOnSetup(int remoteRank, int tag);
|
||||
|
||||
/* Connect to a remote rank. This function only prepares metadata for connection. The actual connection
|
||||
* is made by a following call of mscclppConnectionSetup(). Note that this function is two-way and a connection
|
||||
* from rank i to remote rank j needs to have a counterpart from rank j to rank i.
|
||||
* Note that with IB, buffers are registered at a page level and if a buffer is spread through multiple pages
|
||||
* and do not fully utilize all of them, IB's QP has to register for all involved pages. This potentially has
|
||||
* security risks if the devConn's accesses are given to a malicious process.
|
||||
*
|
||||
* Inputs:
|
||||
* remoteRank: the rank of the remote process
|
||||
* tag: the tag of the connection. tag is copied into the corresponding mscclppDevConn_t, which can be
|
||||
* used to identify the connection inside a GPU kernel.
|
||||
* transportType: the type of transport to be used (mscclppTransportP2P or mscclppTransportIB)
|
||||
* ibDev: the name of the IB device to be used. Expects a null for mscclppTransportP2P.
|
||||
*/
|
||||
std::shared_ptr<HostConnection> connect(int remoteRank, int tag, TransportType transportType, const char* ibDev = 0);
|
||||
* is made by a following call of mscclppConnectionSetup(). Note that this function is two-way and a connection
|
||||
* from rank i to remote rank j needs to have a counterpart from rank j to rank i.
|
||||
* Note that with IB, buffers are registered at a page level and if a buffer is spread through multiple pages
|
||||
* and do not fully utilize all of them, IB's QP has to register for all involved pages. This potentially has
|
||||
* security risks if the devConn's accesses are given to a malicious process.
|
||||
*
|
||||
* Inputs:
|
||||
* remoteRank: the rank of the remote process
|
||||
* tag: the tag of the connection. tag is copied into the corresponding mscclppDevConn_t, which can be
|
||||
* used to identify the connection inside a GPU kernel.
|
||||
* transport: the transport to be used (e.g. Transport::CudaIpc or one of Transport::IB0 ... Transport::IB7)
|
||||
* ibDev: the name of the IB device to be used. Expects a null for mscclppTransportP2P.
|
||||
*/
|
||||
std::shared_ptr<Connection> connectOnSetup(int remoteRank, int tag, Transport transport);
|
||||
|
||||
/* Establish all connections created by mscclppConnect(). This function must be called after all mscclppConnect()
|
||||
* calls are made. This function ensures that all remote ranks are ready to communicate when it returns.
|
||||
*/
|
||||
void connectionSetup();
|
||||
|
||||
/* Launch proxy thread(s). This function is supposed to be called before starting a kernel that uses DeviceConnection. */
|
||||
void startProxying();
|
||||
/* Add a custom Setuppable object to a list of objects to be set up later, when setup() is called. */
|
||||
void addSetup(std::shared_ptr<Setuppable> setuppable);
|
||||
|
||||
/* Stop proxy thread(s). */
|
||||
void stopProxying();
|
||||
|
||||
/* Return the rank of the calling process.
|
||||
*
|
||||
* Outputs:
|
||||
* rank: the rank of the calling process
|
||||
*/
|
||||
int rank();
|
||||
|
||||
/* Return the number of ranks of the communicator.
|
||||
*
|
||||
* Outputs:
|
||||
* size: the number of ranks of the communicator
|
||||
*/
|
||||
int size();
|
||||
/* Set up all objects that have registered for setup. This includes any connections created by connect(). */
|
||||
void setup();
|
||||
|
||||
struct Impl;
|
||||
private:
|
||||
std::unique_ptr<Impl> pimpl;
|
||||
friend class HostConnection;
|
||||
};
|
||||
|
||||
enum class ProxyHandlerResult {
|
||||
Continue,
|
||||
FlushFifoTailAndContinue,
|
||||
Stop,
|
||||
};
|
||||
|
||||
class Proxy;
|
||||
using ProxyHandler = std::function<ProxyHandlerResult(ProxyTrigger)>;
|
||||
|
||||
class Proxy {
|
||||
public:
|
||||
Proxy(ProxyHandler handler);
|
||||
|
||||
~Proxy();
|
||||
|
||||
void start();
|
||||
|
||||
void stop();
|
||||
|
||||
HostProxyFifo& fifo();
|
||||
|
||||
private:
|
||||
struct Impl;
|
||||
std::unique_ptr<Impl> pimpl;
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
namespace std {
|
||||
template <> struct hash<mscclpp::TransportFlags>
|
||||
{
|
||||
size_t operator()(const mscclpp::TransportFlags& flags) const
|
||||
{
|
||||
return hash<mscclpp::detail::TransportFlagsBase>()(flags.toBitset());
|
||||
}
|
||||
};
|
||||
} // namespace std
|
||||
|
||||
#endif // MSCCLPP_H_
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
#ifndef MSCCLPPFIFO_HPP_
|
||||
#define MSCCLPPFIFO_HPP_
|
||||
|
||||
#include <stdint.h>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <stdint.h>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
struct alignas(16) ProxyTrigger {
|
||||
// For every MSCCLPP_PROXY_FIFO_FLUSH_COUNTER, a flush of the tail to device memory is triggered.
|
||||
// As long as MSCCLPP_PROXY_FIFO_SIZE is large enough, having a stale tail is not a problem.
|
||||
#define MSCCLPP_PROXY_FIFO_SIZE 128
|
||||
#define MSCCLPP_PROXY_FIFO_FLUSH_COUNTER 4
|
||||
|
||||
struct alignas(16) ProxyTrigger
|
||||
{
|
||||
uint64_t fst, snd;
|
||||
};
|
||||
|
||||
@@ -24,7 +30,8 @@ struct alignas(16) ProxyTrigger {
|
||||
* Why is duplicating the tail a good idea? The fifo is large enough and we do not need frequent updates
|
||||
* for the tail as there is usually enough space for device threads to push their work into.
|
||||
*/
|
||||
struct DeviceProxyFifo {
|
||||
struct DeviceProxyFifo
|
||||
{
|
||||
#ifdef __CUDACC__
|
||||
__forceinline__ __device__ uint64_t push(ProxyTrigger trigger)
|
||||
{
|
||||
@@ -34,32 +41,31 @@ struct DeviceProxyFifo {
|
||||
while (*(volatile uint64_t*)&this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0)
|
||||
;
|
||||
ProxyTrigger* triggerPtr = (ProxyTrigger*)&(this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE]);
|
||||
asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr),
|
||||
"l"(trigger.fst), "l"(trigger.snd));
|
||||
asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr), "l"(trigger.fst), "l"(trigger.snd));
|
||||
return curFifoHead;
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
|
||||
ProxyTrigger* triggers; // Allocate on host via cudaHostAlloc. This space is used for pushing the work elements
|
||||
uint64_t* tailReplica; // Allocated on device. proxyState->fifoTailHost is the true tail on host and pushed
|
||||
// occasionally to device
|
||||
uint64_t* head; // Allocated on device. Only accessed by device
|
||||
uint64_t* tailReplica; // Allocated on device. proxyState->fifoTailHost is the true tail on host and pushed
|
||||
// occasionally to device
|
||||
uint64_t* head; // Allocated on device. Only accessed by device
|
||||
};
|
||||
|
||||
class HostProxyFifo
|
||||
{
|
||||
public:
|
||||
HostProxyFifo();
|
||||
|
||||
|
||||
~HostProxyFifo();
|
||||
|
||||
void poll(ProxyTrigger *trigger);
|
||||
|
||||
void poll(ProxyTrigger* trigger);
|
||||
|
||||
void pop();
|
||||
|
||||
|
||||
void flushTail(bool sync = false);
|
||||
|
||||
DeviceProxyFifo toDevice();
|
||||
DeviceProxyFifo deviceFifo();
|
||||
|
||||
private:
|
||||
struct Impl;
|
||||
|
||||
@@ -59,8 +59,8 @@ struct mscclppProxyState
|
||||
mscclppProxyRunState_t run;
|
||||
|
||||
int numaNodeToBind;
|
||||
struct mscclppIbContext* ibContext; // For IB connection only
|
||||
cudaStream_t p2pStream; // for P2P DMA engine only
|
||||
mscclpp::IbCtx* ibContext; // For IB connection only
|
||||
cudaStream_t p2pStream; // for P2P DMA engine only
|
||||
|
||||
struct mscclppProxyFifo fifo;
|
||||
};
|
||||
|
||||
40
src/include/proxy.hpp
Normal file
40
src/include/proxy.hpp
Normal file
@@ -0,0 +1,40 @@
|
||||
#ifndef MSCCLPP_PROXY_HPP_
|
||||
#define MSCCLPP_PROXY_HPP_
|
||||
|
||||
#include "mscclppfifo.hpp"
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
enum class ProxyHandlerResult
|
||||
{
|
||||
Continue,
|
||||
FlushFifoTailAndContinue,
|
||||
Stop,
|
||||
};
|
||||
|
||||
class Proxy;
|
||||
using ProxyHandler = std::function<ProxyHandlerResult(ProxyTrigger)>;
|
||||
|
||||
class Proxy
|
||||
{
|
||||
public:
|
||||
Proxy(ProxyHandler handler, std::function<void()> threadInit);
|
||||
Proxy(ProxyHandler handler);
|
||||
~Proxy();
|
||||
|
||||
void start();
|
||||
void stop();
|
||||
|
||||
HostProxyFifo& fifo();
|
||||
|
||||
private:
|
||||
struct Impl;
|
||||
std::unique_ptr<Impl> pimpl;
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_PROXY_HPP_
|
||||
55
src/include/registered_memory.hpp
Normal file
55
src/include/registered_memory.hpp
Normal file
@@ -0,0 +1,55 @@
|
||||
#ifndef MSCCLPP_REGISTERED_MEMORY_HPP_
|
||||
#define MSCCLPP_REGISTERED_MEMORY_HPP_
|
||||
|
||||
#include "communicator.hpp"
|
||||
#include "ib.hpp"
|
||||
#include "mscclpp.h"
|
||||
#include "mscclpp.hpp"
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
struct TransportInfo
|
||||
{
|
||||
Transport transport;
|
||||
|
||||
// TODO: rewrite this using std::variant or something
|
||||
bool ibLocal;
|
||||
union {
|
||||
struct {
|
||||
cudaIpcMemHandle_t cudaIpcBaseHandle;
|
||||
size_t cudaIpcOffsetFromBase;
|
||||
};
|
||||
struct {
|
||||
const IbMr* ibMr;
|
||||
IbMrInfo ibMrInfo;
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
struct RegisteredMemory::Impl
|
||||
{
|
||||
void* data;
|
||||
size_t size;
|
||||
int rank;
|
||||
uint64_t hostHash;
|
||||
TransportFlags transports;
|
||||
std::vector<TransportInfo> transportInfos;
|
||||
|
||||
Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl);
|
||||
Impl(const std::vector<char>& data);
|
||||
|
||||
TransportInfo& getTransportInfo(Transport transport)
|
||||
{
|
||||
for (auto& entry : transportInfos) {
|
||||
if (entry.transport == transport) {
|
||||
return entry;
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("Transport data not found");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_REGISTERED_MEMORY_HPP_
|
||||
52
src/include/registered_ptr.hpp
Normal file
52
src/include/registered_ptr.hpp
Normal file
@@ -0,0 +1,52 @@
|
||||
#ifndef MSCCLPP_REGISTERED_PTR_HPP_
|
||||
#define MSCCLPP_REGISTERED_PTR_HPP_
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
template <typename T> class RegisteredPtr
|
||||
{
|
||||
RegisteredMemory memory;
|
||||
size_t offset;
|
||||
|
||||
public:
|
||||
RegisteredPtr(RegisteredMemory memory, size_t offset) : memory(memory), offset(offset)
|
||||
{
|
||||
}
|
||||
RegisteredPtr(RegisteredMemory memory) : RegisteredPtr(memory, 0)
|
||||
{
|
||||
}
|
||||
~RegisteredPtr()
|
||||
{
|
||||
}
|
||||
|
||||
RegisteredMemory memory()
|
||||
{
|
||||
return memory;
|
||||
}
|
||||
|
||||
T* data()
|
||||
{
|
||||
return reinterpret_cast<T*>(memory.data());
|
||||
}
|
||||
|
||||
size_t size()
|
||||
{
|
||||
return memory.size() / sizeof(T);
|
||||
}
|
||||
|
||||
size_t offset()
|
||||
{
|
||||
return offset;
|
||||
}
|
||||
|
||||
RegisteredPtr<T> operator+(size_t offset)
|
||||
{
|
||||
return RegisteredPtr<T>(memory, this->offset + offset);
|
||||
}
|
||||
|
||||
// TODO: all other relevant overloads
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_REGISTERED_PTR_HPP_
|
||||
54
src/include/utils.hpp
Normal file
54
src/include/utils.hpp
Normal file
@@ -0,0 +1,54 @@
|
||||
#ifndef MSCCLPP_UTILS_HPP_
|
||||
#define MSCCLPP_UTILS_HPP_
|
||||
|
||||
#include <chrono>
|
||||
#include <stdio.h>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
// A simple stopwatch based on std::chrono::steady_clock.
struct Timer
{
  // Time point the measurement is relative to; set on construction and by reset().
  std::chrono::steady_clock::time_point start;

  Timer() : start(std::chrono::steady_clock::now())
  {
  }

  // Returns the number of microseconds elapsed since construction or the last reset().
  int64_t elapsed() const
  {
    auto end = std::chrono::steady_clock::now();
    return std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
  }

  // Restarts the measurement from now.
  void reset()
  {
    start = std::chrono::steady_clock::now();
  }

  // Prints "<name>: <elapsed> us" to stdout.
  void print(const char* name) const
  {
    // The value is cast to long long and printed with %lld: %ld is only correct where long is 64 bits wide.
    printf("%s: %lld us\n", name, static_cast<long long>(elapsed()));
  }
};
|
||||
|
||||
struct ScopedTimer
|
||||
{
|
||||
Timer timer;
|
||||
const char* name;
|
||||
|
||||
ScopedTimer(const char* name) : name(name)
|
||||
{
|
||||
}
|
||||
|
||||
~ScopedTimer()
|
||||
{
|
||||
timer.print(name);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_UTILS_HPP_
|
||||
124
src/init.cc
124
src/init.cc
@@ -6,6 +6,7 @@
|
||||
#if defined(MSCCLPP_USE_GDRCOPY)
|
||||
#include "gdr.h"
|
||||
#endif
|
||||
#include "infiniband/verbs.h"
|
||||
#include "mscclpp.h"
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
@@ -191,7 +192,7 @@ MSCCLPP_API mscclppResult_t mscclppCommDestroy(mscclppComm_t comm)
|
||||
|
||||
for (int i = 0; i < MSCCLPP_IB_MAX_DEVS; ++i) {
|
||||
if (comm->ibContext[i]) {
|
||||
MSCCLPPCHECK(mscclppIbContextDestroy(comm->ibContext[i]));
|
||||
comm->ibContext[i].reset(nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -326,7 +327,8 @@ struct mscclppHostP2PConn : mscclppHostConn
|
||||
{
|
||||
put(1, dstDataOffset, 1, srcDataOffset, dataSize);
|
||||
}
|
||||
void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset, uint64_t dataSize)
|
||||
void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset,
|
||||
uint64_t dataSize)
|
||||
{
|
||||
void* srcBuff = (void*)((char*)conn->bufferRegistrations[src].data + srcDataOffset);
|
||||
void* dstBuff = (void*)((char*)conn->remoteBufferRegistrations[dst].data + dstDataOffset);
|
||||
@@ -364,26 +366,20 @@ struct mscclppHostIBConn : mscclppHostConn
|
||||
{
|
||||
put(1, dstDataOffset, 1, srcDataOffset, dataSize);
|
||||
}
|
||||
void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset, uint64_t dataSize)
|
||||
void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset,
|
||||
uint64_t dataSize)
|
||||
{
|
||||
this->ibQp->stageSend(this->ibMrs[src], &this->remoteIbMrInfos[dst], (uint32_t)dataSize,
|
||||
this->ibQp->stageSend(this->ibMrs[src], this->remoteIbMrInfos[dst], (uint32_t)dataSize,
|
||||
/*wrId=*/0, /*srcOffset=*/srcDataOffset, /*dstOffset=*/dstDataOffset, /*signaled=*/false);
|
||||
int ret = this->ibQp->postSend();
|
||||
if (ret != 0) {
|
||||
// Return value is errno.
|
||||
WARN("data postSend failed: errno %d", ret);
|
||||
}
|
||||
this->ibQp->postSend();
|
||||
npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)dataSize);
|
||||
}
|
||||
void signal()
|
||||
{
|
||||
// My local device flag is copied to the remote's proxy flag
|
||||
this->ibQp->stageSend(this->ibMrs[0], &this->remoteIbMrInfos[0], sizeof(uint64_t),
|
||||
this->ibQp->stageSend(this->ibMrs[0], this->remoteIbMrInfos[0], sizeof(uint64_t),
|
||||
/*wrId=*/0, /*srcOffset=*/0, /*dstOffset=*/sizeof(uint64_t), /*signaled=*/true);
|
||||
int ret = this->ibQp->postSend();
|
||||
if (ret != 0) {
|
||||
WARN("flag postSend failed: errno %d", ret);
|
||||
}
|
||||
this->ibQp->postSend();
|
||||
npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_FLAG_ENTRY, (uint32_t)sizeof(uint64_t));
|
||||
}
|
||||
void wait()
|
||||
@@ -399,15 +395,11 @@ struct mscclppHostIBConn : mscclppHostConn
|
||||
continue;
|
||||
}
|
||||
for (int i = 0; i < wcNum; ++i) {
|
||||
struct ibv_wc* wc = &this->ibQp->wcs[i];
|
||||
struct ibv_wc* wc = (struct ibv_wc*)this->ibQp->getWc(i);
|
||||
if (wc->status != IBV_WC_SUCCESS) {
|
||||
WARN("wc status %d", wc->status);
|
||||
continue;
|
||||
}
|
||||
if (wc->qp_num != this->ibQp->qp->qp_num) {
|
||||
WARN("got wc of unknown qp_num %d", wc->qp_num);
|
||||
continue;
|
||||
}
|
||||
if (wc->opcode == IBV_WC_RDMA_WRITE) {
|
||||
isWaiting = false;
|
||||
break;
|
||||
@@ -418,12 +410,13 @@ struct mscclppHostIBConn : mscclppHostConn
|
||||
}
|
||||
|
||||
mscclppConn* conn;
|
||||
struct mscclppIbQp* ibQp;
|
||||
std::vector<mscclppIbMr*> ibMrs;
|
||||
std::vector<mscclppIbMrInfo> remoteIbMrInfos;
|
||||
mscclpp::IbQp* ibQp;
|
||||
std::vector<const mscclpp::IbMr*> ibMrs;
|
||||
std::vector<mscclpp::IbMrInfo> remoteIbMrInfos;
|
||||
};
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int remoteRank, int tag, mscclppTransport_t transportType, const char* ibDev)
|
||||
MSCCLPP_API mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int remoteRank, int tag,
|
||||
mscclppTransport_t transportType, const char* ibDev)
|
||||
{
|
||||
// save this processes numa binding and set it to the one closest to the device
|
||||
// so that all the allocation are close to the device
|
||||
@@ -458,7 +451,7 @@ MSCCLPP_API mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int
|
||||
if (firstNullIdx == -1) {
|
||||
firstNullIdx = i;
|
||||
}
|
||||
} else if (strncmp(comm->ibContext[i]->ctx->device->name, ibDev, IBV_SYSFS_NAME_MAX) == 0) {
|
||||
} else if (strncmp(comm->ibContext[i]->getDevName().c_str(), ibDev, IBV_SYSFS_NAME_MAX) == 0) {
|
||||
ibDevIdx = i;
|
||||
break;
|
||||
}
|
||||
@@ -468,13 +461,10 @@ MSCCLPP_API mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int
|
||||
if (ibDevIdx == -1) {
|
||||
// Create a new context.
|
||||
ibDevIdx = firstNullIdx;
|
||||
if (mscclppIbContextCreate(&comm->ibContext[ibDevIdx], ibDev) != mscclppSuccess) {
|
||||
WARN("Failed to create IB context");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
comm->ibContext[ibDevIdx].reset(new mscclpp::IbCtx(std::string(ibDev)));
|
||||
}
|
||||
// Set the ib context for this conn
|
||||
conn->ibCtx = comm->ibContext[ibDevIdx];
|
||||
conn->ibCtx = comm->ibContext[ibDevIdx].get();
|
||||
|
||||
} else if (transportType == mscclppTransportP2P) {
|
||||
// do the rest of the initialization later
|
||||
@@ -563,7 +553,8 @@ MSCCLPP_API mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int
|
||||
MSCCLPPCHECK(setNumaState(curProcessState));
|
||||
|
||||
mscclppBufferHandle_t signalHandle = -1;
|
||||
MSCCLPPCHECK(mscclppRegisterBufferForConnection(comm, connId, conn->devConn->localSignalEpochId, sizeof(mscclppDevConnSignalEpochId), &signalHandle));
|
||||
MSCCLPPCHECK(mscclppRegisterBufferForConnection(comm, connId, conn->devConn->localSignalEpochId,
|
||||
sizeof(mscclppDevConnSignalEpochId), &signalHandle));
|
||||
if (signalHandle != 0) {
|
||||
WARN("signal handle should be 0");
|
||||
return mscclppInternalError;
|
||||
@@ -592,7 +583,9 @@ MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, i
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppRegisterBufferForConnection(mscclppComm_t comm, int connIdx, void* localBuff, uint64_t buffSize, mscclppBufferHandle_t *handle) {
|
||||
MSCCLPP_API mscclppResult_t mscclppRegisterBufferForConnection(mscclppComm_t comm, int connIdx, void* localBuff,
|
||||
uint64_t buffSize, mscclppBufferHandle_t* handle)
|
||||
{
|
||||
if (connIdx >= comm->nConns) {
|
||||
WARN("connIdx out of range");
|
||||
return mscclppInvalidArgument;
|
||||
@@ -609,35 +602,40 @@ MSCCLPP_API mscclppResult_t mscclppRegisterBufferForConnection(mscclppComm_t com
|
||||
struct mscclppBufferRegistrationInfo
|
||||
{
|
||||
cudaIpcMemHandle_t cudaHandle;
|
||||
mscclppIbMrInfo ibMrInfo;
|
||||
mscclpp::IbMrInfo ibMrInfo;
|
||||
uint64_t size;
|
||||
};
|
||||
|
||||
struct connInfo
|
||||
{
|
||||
mscclppIbQpInfo infoQp;
|
||||
mscclpp::IbQpInfo infoQp;
|
||||
std::vector<mscclppBufferRegistrationInfo> bufferInfos;
|
||||
|
||||
struct header {
|
||||
mscclppIbQpInfo infoQp;
|
||||
struct header
|
||||
{
|
||||
mscclpp::IbQpInfo infoQp;
|
||||
int numBufferInfos;
|
||||
};
|
||||
|
||||
mscclppResult_t sendOverBootstrap(void* bootstrap, int remoteRank, int tag) {
|
||||
mscclppResult_t sendOverBootstrap(void* bootstrap, int remoteRank, int tag)
|
||||
{
|
||||
header h;
|
||||
h.infoQp = infoQp;
|
||||
h.numBufferInfos = bufferInfos.size();
|
||||
MSCCLPPCHECK(bootstrapSend(bootstrap, remoteRank, tag, &h, sizeof(header)));
|
||||
MSCCLPPCHECK(bootstrapSend(bootstrap, remoteRank, tag, bufferInfos.data(), bufferInfos.size() * sizeof(mscclppBufferRegistrationInfo)));
|
||||
MSCCLPPCHECK(bootstrapSend(bootstrap, remoteRank, tag, bufferInfos.data(),
|
||||
bufferInfos.size() * sizeof(mscclppBufferRegistrationInfo)));
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t recvOverBootstrap(void* bootstrap, int remoteRank, int tag) {
|
||||
mscclppResult_t recvOverBootstrap(void* bootstrap, int remoteRank, int tag)
|
||||
{
|
||||
header h;
|
||||
MSCCLPPCHECK(bootstrapRecv(bootstrap, remoteRank, tag, &h, sizeof(header)));
|
||||
infoQp = h.infoQp;
|
||||
bufferInfos.resize(h.numBufferInfos);
|
||||
MSCCLPPCHECK(bootstrapRecv(bootstrap, remoteRank, tag, bufferInfos.data(), bufferInfos.size() * sizeof(mscclppBufferRegistrationInfo)));
|
||||
MSCCLPPCHECK(bootstrapRecv(bootstrap, remoteRank, tag, bufferInfos.data(),
|
||||
bufferInfos.size() * sizeof(mscclppBufferRegistrationInfo)));
|
||||
return mscclppSuccess;
|
||||
}
|
||||
};
|
||||
@@ -650,7 +648,7 @@ mscclppResult_t mscclppP2pConnectionSetupStart(struct connInfo* connInfo /*input
|
||||
}
|
||||
|
||||
// Add all registered buffers
|
||||
for (const auto &bufReg : conn->bufferRegistrations) {
|
||||
for (const auto& bufReg : conn->bufferRegistrations) {
|
||||
connInfo->bufferInfos.emplace_back();
|
||||
CUDACHECK(cudaIpcGetMemHandle(&connInfo->bufferInfos.back().cudaHandle, bufReg.data));
|
||||
connInfo->bufferInfos.back().size = bufReg.size;
|
||||
@@ -672,7 +670,8 @@ mscclppResult_t mscclppP2pConnectionSetupEnd(struct connInfo* connInfo /*input*/
|
||||
// Open all remote registered buffers
|
||||
for (size_t i = 0; i < connInfo->bufferInfos.size(); i++) {
|
||||
mscclppBufferRegistration newBufReg;
|
||||
CUDACHECK(cudaIpcOpenMemHandle(&newBufReg.data, connInfo->bufferInfos[i].cudaHandle, cudaIpcMemLazyEnablePeerAccess));
|
||||
CUDACHECK(
|
||||
cudaIpcOpenMemHandle(&newBufReg.data, connInfo->bufferInfos[i].cudaHandle, cudaIpcMemLazyEnablePeerAccess));
|
||||
newBufReg.size = connInfo->bufferInfos[i].size;
|
||||
conn->remoteBufferRegistrations.push_back(newBufReg);
|
||||
}
|
||||
@@ -683,8 +682,8 @@ mscclppResult_t mscclppP2pConnectionSetupEnd(struct connInfo* connInfo /*input*/
|
||||
}
|
||||
conn->devConn->remoteSignalEpochId = (mscclppDevConnSignalEpochId*)conn->remoteBufferRegistrations[0].data;
|
||||
|
||||
// For backwards compatibility with the previous API that assumed one data buffer per connection, set the remote buffer
|
||||
// to the first remote data buffer
|
||||
// For backwards compatibility with the previous API that assumed one data buffer per connection, set the remote
|
||||
// buffer to the first remote data buffer
|
||||
if (conn->remoteBufferRegistrations.size() > 1) {
|
||||
conn->devConn->remoteBuff = conn->remoteBufferRegistrations[1].data;
|
||||
}
|
||||
@@ -702,22 +701,20 @@ mscclppResult_t mscclppIbConnectionSetupStart(struct connInfo* connInfo /*output
|
||||
devConn->remoteBuff = NULL;
|
||||
devConn->remoteSignalEpochId = NULL;
|
||||
|
||||
struct mscclppIbContext* ibCtx = conn->ibCtx;
|
||||
mscclpp::IbCtx* ibCtx = conn->ibCtx;
|
||||
if (hostConn->ibQp == NULL) {
|
||||
MSCCLPPCHECK(mscclppIbContextCreateQp(ibCtx, &hostConn->ibQp));
|
||||
hostConn->ibQp = ibCtx->createQp();
|
||||
}
|
||||
|
||||
// Add all registered buffers
|
||||
for (const auto &bufReg : conn->bufferRegistrations) {
|
||||
hostConn->ibMrs.emplace_back();
|
||||
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, bufReg.data,
|
||||
sizeof(struct mscclppDevConnSignalEpochId), &hostConn->ibMrs.back()));
|
||||
for (const auto& bufReg : conn->bufferRegistrations) {
|
||||
hostConn->ibMrs.emplace_back(ibCtx->registerMr(bufReg.data, sizeof(struct mscclppDevConnSignalEpochId)));
|
||||
connInfo->bufferInfos.emplace_back();
|
||||
connInfo->bufferInfos.back().ibMrInfo = hostConn->ibMrs.back()->info;
|
||||
connInfo->bufferInfos.back().ibMrInfo = hostConn->ibMrs.back()->getInfo();
|
||||
connInfo->bufferInfos.back().size = bufReg.size;
|
||||
}
|
||||
|
||||
connInfo->infoQp = hostConn->ibQp->info;
|
||||
connInfo->infoQp = hostConn->ibQp->getInfo();
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
@@ -728,14 +725,8 @@ mscclppResult_t mscclppIbConnectionSetupEnd(struct connInfo* connInfo /*input*/,
|
||||
return mscclppInternalError;
|
||||
}
|
||||
struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn;
|
||||
if (hostConn->ibQp->rtr(&connInfo->infoQp) != 0) {
|
||||
WARN("Failed to transition QP to RTR");
|
||||
return mscclppInvalidUsage;
|
||||
}
|
||||
if (hostConn->ibQp->rts() != 0) {
|
||||
WARN("Failed to transition QP to RTS");
|
||||
return mscclppInvalidUsage;
|
||||
}
|
||||
hostConn->ibQp->rtr(connInfo->infoQp);
|
||||
hostConn->ibQp->rts();
|
||||
|
||||
// No remote pointers to set with IB, so we just set the Mrs
|
||||
|
||||
@@ -764,7 +755,8 @@ MSCCLPP_API mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm)
|
||||
MSCCLPPCHECK(mscclppIbConnectionSetupStart(&cInfo, conn));
|
||||
}
|
||||
// TODO: from saemal: do we possibly deadlock if there are too many outstanding sends?
|
||||
// MSCCLPPCHECK(bootstrapSend(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag, &cInfo, sizeof(cInfo)));
|
||||
// MSCCLPPCHECK(bootstrapSend(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag, &cInfo,
|
||||
// sizeof(cInfo)));
|
||||
MSCCLPPCHECK(cInfo.sendOverBootstrap(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag));
|
||||
}
|
||||
|
||||
@@ -788,25 +780,25 @@ MSCCLPP_API mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm)
|
||||
struct bufferInfo
|
||||
{
|
||||
cudaIpcMemHandle_t handleBuff;
|
||||
mscclppIbMrInfo infoBuffMr;
|
||||
mscclpp::IbMrInfo infoBuffMr;
|
||||
};
|
||||
|
||||
MSCCLPP_API mscclppResult_t mscclppRegisterBuffer(mscclppComm_t comm, void* local_memory, size_t size,
|
||||
mscclppRegisteredMemory* regMem)
|
||||
{
|
||||
std::vector<struct mscclppIbMr*> ibMrs;
|
||||
std::vector<const mscclpp::IbMr*> ibMrs;
|
||||
for (int i = 0; i < comm->nConns; ++i) {
|
||||
struct mscclppConn* conn = &comm->conns[i];
|
||||
struct bufferInfo bInfo;
|
||||
struct mscclppIbMr* ibBuffMr;
|
||||
const mscclpp::IbMr* ibBuffMr;
|
||||
|
||||
// TODO: (conn->transport & mscclppTransportP2P) to support both P2P and IB
|
||||
if (conn->transport == mscclppTransportP2P) {
|
||||
CUDACHECK(cudaIpcGetMemHandle(&bInfo.handleBuff, local_memory));
|
||||
} else if (conn->transport == mscclppTransportIB) {
|
||||
MSCCLPPCHECK(mscclppIbContextRegisterMr(conn->ibCtx, local_memory, size, &ibBuffMr));
|
||||
bInfo.infoBuffMr = ibBuffMr->info;
|
||||
ibMrs.push_back(ibBuffMr);
|
||||
ibBuffMr = conn->ibCtx->registerMr(local_memory, size);
|
||||
bInfo.infoBuffMr = ibBuffMr->getInfo();
|
||||
ibMrs.emplace_back(ibBuffMr);
|
||||
}
|
||||
|
||||
MSCCLPPCHECK(bootstrapSend(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag, &bInfo, sizeof(bInfo)));
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
#include "checks.h"
|
||||
#include "comm.h"
|
||||
#include "debug.h"
|
||||
#include "ib.h"
|
||||
#include "ib.hpp"
|
||||
#include "socket.h"
|
||||
|
||||
#include <emmintrin.h>
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
#include "proxy.hpp"
|
||||
#include "api.h"
|
||||
#include "mscclpp.hpp"
|
||||
#include "utils.h"
|
||||
#include "api.h"
|
||||
#include <thread>
|
||||
#include "utils.hpp"
|
||||
#include <atomic>
|
||||
#include <thread>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
@@ -10,30 +12,41 @@ const int ProxyStopCheckPeriod = 1000;
|
||||
|
||||
const int ProxyFlushPeriod = 4;
|
||||
|
||||
struct Proxy::Impl {
|
||||
struct Proxy::Impl
|
||||
{
|
||||
ProxyHandler handler;
|
||||
std::function<void()> threadInit;
|
||||
HostProxyFifo fifo;
|
||||
std::thread service;
|
||||
std::atomic_bool running;
|
||||
|
||||
Impl(ProxyHandler handler) : handler(handler), running(false) {}
|
||||
Impl(ProxyHandler handler, std::function<void()> threadInit) : handler(handler), threadInit(threadInit), running(false)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler) {
|
||||
pimpl = std::make_unique<Impl>(handler);
|
||||
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler, std::function<void()> threadInit)
|
||||
{
|
||||
pimpl = std::make_unique<Impl>(handler, threadInit);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Proxy::~Proxy() {
|
||||
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler) : Proxy(handler, [] {})
|
||||
{
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Proxy::~Proxy()
|
||||
{
|
||||
if (pimpl) {
|
||||
stop();
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Proxy::start() {
|
||||
MSCCLPP_API_CPP void Proxy::start()
|
||||
{
|
||||
pimpl->running = true;
|
||||
pimpl->service = std::thread([this] {
|
||||
// from this point on, proxy thread will stay close to the device
|
||||
// PROXYMSCCLPPCHECK(numaBind(pimpl->comm->devNumaNode)); // TODO: reenable this
|
||||
|
||||
pimpl->threadInit();
|
||||
|
||||
ProxyHandler handler = this->pimpl->handler;
|
||||
HostProxyFifo& fifo = this->pimpl->fifo;
|
||||
@@ -52,7 +65,7 @@ MSCCLPP_API_CPP void Proxy::start() {
|
||||
// Poll to see if we are ready to send anything
|
||||
fifo.poll(&trigger);
|
||||
if (trigger.fst == 0) { // TODO: this check is a potential pitfall for custom triggers
|
||||
continue; // there is one in progress
|
||||
continue; // there is one in progress
|
||||
}
|
||||
|
||||
ProxyHandlerResult result = handler(trigger);
|
||||
@@ -83,14 +96,16 @@ MSCCLPP_API_CPP void Proxy::start() {
|
||||
});
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Proxy::stop() {
|
||||
MSCCLPP_API_CPP void Proxy::stop()
|
||||
{
|
||||
pimpl->running = false;
|
||||
if (pimpl->service.joinable()) {
|
||||
pimpl->service.join();
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP HostProxyFifo& Proxy::fifo() {
|
||||
MSCCLPP_API_CPP HostProxyFifo& Proxy::fifo()
|
||||
{
|
||||
return pimpl->fifo;
|
||||
}
|
||||
|
||||
|
||||
164
src/registered_memory.cc
Normal file
164
src/registered_memory.cc
Normal file
@@ -0,0 +1,164 @@
|
||||
#include "registered_memory.hpp"
|
||||
#include "api.h"
|
||||
#include "checks.hpp"
|
||||
#include "utils.h"
|
||||
#include <algorithm>
|
||||
#include <cuda.h>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
// Builds the registration record for a local buffer.
//
// data/size  - the buffer being registered.
// rank       - rank recorded for this memory; also used to look up the host hash in commImpl.
// transports - set of transports the buffer should be reachable through.
// commImpl   - communicator implementation supplying rankToHash_ and the IB contexts.
//
// One TransportInfo entry is appended per requested transport: a CUDA IPC handle (plus the
// offset of `data` from the base of its allocation) and/or one IB memory region per IB transport.
RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl)
  : data(data), size(size), rank(rank), hostHash(commImpl.rankToHash_.at(rank)), transports(transports)
{
  if (transports.has(Transport::CudaIpc)) {
    TransportInfo transportInfo;
    transportInfo.transport = Transport::CudaIpc;
    cudaIpcMemHandle_t handle;

    // The IPC handle is taken on the base address reported by cuMemGetAddressRange, so the
    // offset of `data` from that base is stored alongside the handle.
    void* baseDataPtr;
    size_t baseDataSize; // dummy
    CUTHROW(cuMemGetAddressRange((CUdeviceptr*)&baseDataPtr, &baseDataSize, (CUdeviceptr)data));
    CUDATHROW(cudaIpcGetMemHandle(&handle, baseDataPtr));
    // TODO: bug with offset of base?
    transportInfo.cudaIpcBaseHandle = handle;
    transportInfo.cudaIpcOffsetFromBase = (char*)data - (char*)baseDataPtr;
    this->transportInfos.push_back(transportInfo);
  }
  if ((transports & AllIBTransports).any()) {
    // Checked in ascending order so the entries are appended in the order IB0..IB7.
    static const Transport ibTransports[] = {Transport::IB0, Transport::IB1, Transport::IB2, Transport::IB3,
                                             Transport::IB4, Transport::IB5, Transport::IB6, Transport::IB7};
    for (Transport ibTransport : ibTransports) {
      if (!transports.has(ibTransport)) {
        continue;
      }
      TransportInfo transportInfo;
      transportInfo.transport = ibTransport;
      const IbMr* mr = commImpl.getIbContext(ibTransport)->registerMr(data, size);
      transportInfo.ibMr = mr;
      transportInfo.ibLocal = true;
      transportInfo.ibMrInfo = mr->getInfo();
      this->transportInfos.push_back(transportInfo);
      // %zu is the conversion specifier that matches size_t.
      INFO(MSCCLPP_NET, "IB mr for address %p with size %zu is registered", data, size);
    }
  }
}
|
||||
|
||||
// Wraps an implementation object; the shared_ptr is copied into the handle.
MSCCLPP_API_CPP RegisteredMemory::RegisteredMemory(std::shared_ptr<Impl> pimpl) : pimpl(pimpl) {}

MSCCLPP_API_CPP RegisteredMemory::~RegisteredMemory() = default;

// Returns the data pointer stored in the implementation object.
MSCCLPP_API_CPP void* RegisteredMemory::data() { return this->pimpl->data; }

// Returns the size stored in the implementation object.
MSCCLPP_API_CPP size_t RegisteredMemory::size() { return this->pimpl->size; }

// Returns the rank stored in the implementation object.
MSCCLPP_API_CPP int RegisteredMemory::rank() { return this->pimpl->rank; }

// Returns the transport flags stored in the implementation object.
MSCCLPP_API_CPP TransportFlags RegisteredMemory::transports() { return this->pimpl->transports; }
|
||||
|
||||
// Flattens this registration into a byte vector.
//
// Layout, in order: size, rank, hostHash, transports, an int8_t entry count, then for each
// entry its transport tag followed by either (CUDA IPC base handle, offset from base) or the
// IB memory-region info. Fields are written as their raw in-memory bytes.
//
// Throws std::runtime_error if there are more entries than an int8_t can count, or if an
// entry's transport is neither CudaIpc nor one of the IB transports.
MSCCLPP_API_CPP std::vector<char> RegisteredMemory::serialize()
{
  std::vector<char> bytes;

  // Appends the raw object representation of `field` to the output.
  auto appendRaw = [&bytes](const auto& field) {
    const char* first = reinterpret_cast<const char*>(&field);
    bytes.insert(bytes.end(), first, first + sizeof(field));
  };

  appendRaw(pimpl->size);
  appendRaw(pimpl->rank);
  appendRaw(pimpl->hostHash);
  appendRaw(pimpl->transports);

  if (pimpl->transportInfos.size() > std::numeric_limits<int8_t>::max()) {
    throw std::runtime_error("Too many transport info entries");
  }
  int8_t transportCount = pimpl->transportInfos.size();
  appendRaw(transportCount);

  for (auto& info : pimpl->transportInfos) {
    appendRaw(info.transport);
    if (info.transport == Transport::CudaIpc) {
      appendRaw(info.cudaIpcBaseHandle);
      appendRaw(info.cudaIpcOffsetFromBase);
    } else if (AllIBTransports.has(info.transport)) {
      appendRaw(info.ibMrInfo);
    } else {
      throw std::runtime_error("Unknown transport");
    }
  }
  return bytes;
}
|
||||
|
||||
MSCCLPP_API_CPP RegisteredMemory RegisteredMemory::deserialize(const std::vector<char>& data)
|
||||
{
|
||||
return RegisteredMemory(std::make_shared<Impl>(data));
|
||||
}
|
||||
|
||||
RegisteredMemory::Impl::Impl(const std::vector<char>& serialization)
|
||||
{
|
||||
auto it = serialization.begin();
|
||||
std::copy_n(it, sizeof(this->size), reinterpret_cast<char*>(&this->size));
|
||||
it += sizeof(this->size);
|
||||
std::copy_n(it, sizeof(this->rank), reinterpret_cast<char*>(&this->rank));
|
||||
it += sizeof(this->rank);
|
||||
std::copy_n(it, sizeof(this->hostHash), reinterpret_cast<char*>(&this->hostHash));
|
||||
it += sizeof(this->hostHash);
|
||||
std::copy_n(it, sizeof(this->transports), reinterpret_cast<char*>(&this->transports));
|
||||
it += sizeof(this->transports);
|
||||
int8_t transportCount;
|
||||
std::copy_n(it, sizeof(transportCount), reinterpret_cast<char*>(&transportCount));
|
||||
it += sizeof(transportCount);
|
||||
for (int i = 0; i < transportCount; ++i) {
|
||||
TransportInfo transportInfo;
|
||||
std::copy_n(it, sizeof(transportInfo.transport), reinterpret_cast<char*>(&transportInfo.transport));
|
||||
it += sizeof(transportInfo.transport);
|
||||
if (transportInfo.transport == Transport::CudaIpc) {
|
||||
std::copy_n(it, sizeof(transportInfo.cudaIpcBaseHandle), reinterpret_cast<char*>(&transportInfo.cudaIpcBaseHandle));
|
||||
it += sizeof(transportInfo.cudaIpcBaseHandle);
|
||||
std::copy_n(it, sizeof(transportInfo.cudaIpcOffsetFromBase), reinterpret_cast<char*>(&transportInfo.cudaIpcOffsetFromBase));
|
||||
it += sizeof(transportInfo.cudaIpcOffsetFromBase);
|
||||
} else if (AllIBTransports.has(transportInfo.transport)) {
|
||||
std::copy_n(it, sizeof(transportInfo.ibMrInfo), reinterpret_cast<char*>(&transportInfo.ibMrInfo));
|
||||
it += sizeof(transportInfo.ibMrInfo);
|
||||
transportInfo.ibLocal = false;
|
||||
} else {
|
||||
throw std::runtime_error("Unknown transport");
|
||||
}
|
||||
this->transportInfos.push_back(transportInfo);
|
||||
}
|
||||
if (it != serialization.end()) {
|
||||
throw std::runtime_error("Deserialization failed");
|
||||
}
|
||||
|
||||
if (transports.has(Transport::CudaIpc)) {
|
||||
uint64_t localHostHash = getHostHash();
|
||||
if (localHostHash == this->hostHash) {
|
||||
auto entry = getTransportInfo(Transport::CudaIpc);
|
||||
void* base;
|
||||
CUDATHROW(cudaIpcOpenMemHandle(&base, entry.cudaIpcBaseHandle, cudaIpcMemLazyEnablePeerAccess));
|
||||
data = static_cast<char*>(base) + entry.cudaIpcOffsetFromBase;
|
||||
INFO(MSCCLPP_P2P, "Opened CUDA IPC handle at pointer %p", data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <numa.h>
|
||||
#include <stdlib.h>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
// Get current Compute Capability
|
||||
// int mscclppCudaCompCap() {
|
||||
@@ -112,7 +113,7 @@ uint64_t getHash(const char* string, int n)
|
||||
* This string can be overridden by using the MSCCLPP_HOSTID env var.
|
||||
*/
|
||||
#define HOSTID_FILE "/proc/sys/kernel/random/boot_id"
|
||||
uint64_t getHostHash(void)
|
||||
uint64_t computeHostHash(void)
|
||||
{
|
||||
char hostHash[1024];
|
||||
char* hostId;
|
||||
@@ -144,6 +145,12 @@ uint64_t getHostHash(void)
|
||||
return getHash(hostHash, strlen(hostHash));
|
||||
}
|
||||
|
||||
uint64_t getHostHash(void)
|
||||
{
|
||||
thread_local std::unique_ptr<uint64_t> hostHash = std::make_unique<uint64_t>(computeHostHash());
|
||||
return *hostHash;
|
||||
}
|
||||
|
||||
/* Generate a hash of the unique identifying string for this process
|
||||
* that will be unique for both bare-metal and container instances
|
||||
* Equivalent of a hash of;
|
||||
|
||||
10
tests/CMakeLists.txt
Normal file
10
tests/CMakeLists.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
# Registers one C++ API test executable built from a single source file.
#
#   mscclpp_add_cpp_test(<target-name> <source-file>)
#
# The test links privately against the mscclpp library. MPI is linked only when the
# MPI::MPI_CXX imported target exists, so configuring without MPI does not fail on an
# unknown target.
function(mscclpp_add_cpp_test name source)
  add_executable(${name} ${source})
  target_link_libraries(${name} PRIVATE mscclpp)
  if(TARGET MPI::MPI_CXX)
    target_link_libraries(${name} PRIVATE MPI::MPI_CXX)
  endif()
endfunction()

mscclpp_add_cpp_test(bootstrap_test_cpp bootstrap_test_cpp.cc)
mscclpp_add_cpp_test(communicator_test_cpp communicator_test_cpp.cu)
mscclpp_add_cpp_test(allgather_test_cpp allgather_test_cpp.cu)

add_subdirectory(unittests)
|
||||
@@ -1,17 +1,18 @@
|
||||
#include "mscclpp.h"
|
||||
#include "mscclpp.hpp"
|
||||
#include "channel.hpp"
|
||||
|
||||
#ifdef MSCCLPP_USE_MPI_FOR_TESTS
|
||||
#include "mpi.h"
|
||||
#endif // MSCCLPP_USE_MPI_FOR_TESTS
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string>
|
||||
#include <unistd.h>
|
||||
#include <unordered_map>
|
||||
#include <cassert>
|
||||
#include <algorithm>
|
||||
|
||||
static int nranksPerNode = 8;
|
||||
|
||||
@@ -48,29 +49,30 @@ static double getTime(void)
|
||||
return (tspec.tv_nsec / 1.0e9) + tspec.tv_sec;
|
||||
}
|
||||
|
||||
__constant__ mscclpp::SimpleDeviceConnection constDevConns[16];
|
||||
__constant__ mscclpp::channel::SimpleDeviceChannel constDevChans[16];
|
||||
|
||||
__device__ void allgather0(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int remoteRank, size_t nelemsPerGPU)
|
||||
__device__ void allgather0(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int world_size, int remoteRank,
|
||||
size_t nelemsPerGPU)
|
||||
{
|
||||
// this allgather is really simple and implemented as an alltoall
|
||||
|
||||
// this thread's role is a sender role
|
||||
// put your data asynchronously
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devConn.putWithSignal(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
|
||||
devChan.putWithSignal(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
|
||||
// make sure everyone is put their data before some thread randomly blocks everyone else in signal
|
||||
__syncthreads();
|
||||
// push with flag and sync to make sure the data is received
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devConn.flush();
|
||||
devChan.flush();
|
||||
|
||||
// this thread's role is a receiver role. wait on the semaphore to make sure the data is ready
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devConn.wait();
|
||||
devChan.wait();
|
||||
}
|
||||
|
||||
__device__ void localAllGather(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
|
||||
uint64_t offset, uint64_t size)
|
||||
__device__ void localAllGather(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int world_size, int nranksPerNode,
|
||||
int remoteRank, uint64_t offset, uint64_t size)
|
||||
{
|
||||
// this allgather algorithm works as follows:
|
||||
// Step 1: GPU rank i sends data to GPU rank (i+1) % nranksPerNode
|
||||
@@ -82,26 +84,29 @@ __device__ void localAllGather(mscclpp::SimpleDeviceConnection devConn, int rank
|
||||
if ((remoteRank % nranksPerNode) == ((rank + i) % nranksPerNode)) {
|
||||
// put your data to GPU (rank+i) % nranksPerNode and signal in one call
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devConn.putWithSignalAndFlush(offset, size);
|
||||
devChan.putWithSignal(offset, size);
|
||||
}
|
||||
// wait for the data from GPU (rank-i) % nranksPerNode to arrive
|
||||
if ((remoteRank % nranksPerNode) == ((rank - i + nranksPerNode) % nranksPerNode)) {
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devConn.wait();
|
||||
devChan.wait();
|
||||
}
|
||||
asm volatile("bar.sync %0, %1;" ::"r"(11), "r"((nranksPerNode - 1) * 32) : "memory");
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void allgather1(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
|
||||
size_t nelemsPerGPU)
|
||||
__device__ void allgather1(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int world_size, int nranksPerNode,
|
||||
int remoteRank, size_t nelemsPerGPU)
|
||||
{
|
||||
localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int),
|
||||
localAllGather(devChan, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int),
|
||||
nelemsPerGPU * sizeof(int));
|
||||
if (remoteRank / nranksPerNode == rank / nranksPerNode)
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devChan.flush();
|
||||
}
|
||||
|
||||
__device__ void allgather2(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
|
||||
size_t nelemsPerGPU)
|
||||
__device__ void allgather2(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int world_size, int nranksPerNode,
|
||||
int remoteRank, size_t nelemsPerGPU)
|
||||
{
|
||||
// this allgather is a pipelined and hierarchical one and only works for two nodes
|
||||
// it is implemented as follows:
|
||||
@@ -118,17 +123,17 @@ __device__ void allgather2(mscclpp::SimpleDeviceConnection devConn, int rank, in
|
||||
// Step 1
|
||||
// local allgather
|
||||
if (remoteRank / nranksPerNode == rank / nranksPerNode) {
|
||||
localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int),
|
||||
localAllGather(devChan, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int),
|
||||
nelemsPerGPU * sizeof(int));
|
||||
}
|
||||
// cross-node exchange
|
||||
if (remoteRank % nranksPerNode == rank % nranksPerNode) {
|
||||
// opposite side
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devConn.putWithSignalAndFlush(rank * nelemsPerGPU * sizeof(int),
|
||||
devChan.putWithSignal(rank * nelemsPerGPU * sizeof(int),
|
||||
(nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int));
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devConn.wait();
|
||||
devChan.wait();
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
@@ -137,7 +142,7 @@ __device__ void allgather2(mscclpp::SimpleDeviceConnection devConn, int rank, in
|
||||
// local allgather
|
||||
int otherNghr = (rank + nranksPerNode) % world_size;
|
||||
if (remoteRank / nranksPerNode == rank / nranksPerNode) {
|
||||
localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank, otherNghr * nelemsPerGPU * sizeof(int),
|
||||
localAllGather(devChan, rank, world_size, nranksPerNode, remoteRank, otherNghr * nelemsPerGPU * sizeof(int),
|
||||
(nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int));
|
||||
}
|
||||
|
||||
@@ -145,11 +150,11 @@ __device__ void allgather2(mscclpp::SimpleDeviceConnection devConn, int rank, in
|
||||
if (remoteRank % nranksPerNode == rank % nranksPerNode) {
|
||||
// opposite side
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devConn.putWithSignalAndFlush((rank * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) *
|
||||
devChan.putWithSignal((rank * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) *
|
||||
sizeof(int),
|
||||
nelemsPerGPU / pipelineSize * sizeof(int));
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devConn.wait();
|
||||
devChan.wait();
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
@@ -157,26 +162,31 @@ __device__ void allgather2(mscclpp::SimpleDeviceConnection devConn, int rank, in
|
||||
// Step 3
|
||||
// local allgather
|
||||
if (remoteRank / nranksPerNode == rank / nranksPerNode) {
|
||||
localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank,
|
||||
localAllGather(devChan, rank, world_size, nranksPerNode, remoteRank,
|
||||
(otherNghr * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * sizeof(int),
|
||||
nelemsPerGPU / pipelineSize * sizeof(int));
|
||||
}
|
||||
|
||||
if (remoteRank / nranksPerNode == rank / nranksPerNode || remoteRank % nranksPerNode == rank % nranksPerNode) {
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devChan.flush();
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void kernel(int rank, int world_size, int nranksPerNode, size_t nelemsPerGPU, int kernel)
|
||||
{
|
||||
// find the mapping between remoteRank and devConns
|
||||
// find the mapping between remoteRank and devChans
|
||||
int warpId = threadIdx.x / 32;
|
||||
int remoteRank = (warpId < rank) ? warpId : warpId + 1;
|
||||
// Each warp is responsible for one of the remote ranks
|
||||
mscclpp::SimpleDeviceConnection devConn = constDevConns[warpId];
|
||||
mscclpp::channel::SimpleDeviceChannel devChan = constDevChans[warpId];
|
||||
|
||||
if (kernel == 0)
|
||||
allgather0(devConn, rank, world_size, remoteRank, nelemsPerGPU);
|
||||
allgather0(devChan, rank, world_size, remoteRank, nelemsPerGPU);
|
||||
else if (kernel == 1)
|
||||
allgather1(devConn, rank, world_size, nranksPerNode, remoteRank, nelemsPerGPU);
|
||||
allgather1(devChan, rank, world_size, nranksPerNode, remoteRank, nelemsPerGPU);
|
||||
else if (kernel == 2)
|
||||
allgather2(devConn, rank, world_size, nranksPerNode, remoteRank, nelemsPerGPU);
|
||||
allgather2(devChan, rank, world_size, nranksPerNode, remoteRank, nelemsPerGPU);
|
||||
}
|
||||
|
||||
int rankToLocalRank(int rank)
|
||||
@@ -216,40 +226,44 @@ void initializeAndAllocateAllGatherData(int rank, int world_size, size_t dataSiz
|
||||
CUDACHECK(cudaMemcpy(*data_d, *data_h, dataSize, cudaMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& comm, int* data_d, size_t dataSize)
|
||||
void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& comm, mscclpp::channel::DeviceChannelService& channelService, int* data_d, size_t dataSize)
|
||||
{
|
||||
int thisNode = rankToNode(rank);
|
||||
int cudaNum = rankToLocalRank(rank);
|
||||
std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum);
|
||||
std::vector<std::shared_ptr<mscclpp::HostConnection>> hostConns;
|
||||
mscclpp::Transport ibTransport = mscclpp::getIBTransportByDeviceName(ibDevStr);
|
||||
std::vector<mscclpp::channel::ChannelId> channelIds;
|
||||
std::vector<mscclpp::RegisteredMemory> localMemories;
|
||||
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteMemories;
|
||||
|
||||
for (int r = 0; r < world_size; ++r) {
|
||||
if (r == rank)
|
||||
continue;
|
||||
mscclpp::TransportType transportType;
|
||||
const char* ibDev = ibDevStr.c_str();
|
||||
mscclpp::Transport transport;
|
||||
if (rankToNode(r) == thisNode) {
|
||||
ibDev = NULL;
|
||||
transportType = mscclpp::TransportType::P2P;
|
||||
transport = mscclpp::Transport::CudaIpc;
|
||||
} else {
|
||||
transportType = mscclpp::TransportType::IB;
|
||||
transport = ibTransport;
|
||||
}
|
||||
// Connect with all other ranks
|
||||
auto hostConn = comm.connect(r, 0, transportType, ibDev);
|
||||
hostConn->registerBuffer(data_d, dataSize);
|
||||
hostConns.push_back(hostConn);
|
||||
channelIds.push_back(channelService.addChannel(comm.connectOnSetup(r, 0, transport)));
|
||||
auto memory = comm.registerMemory(data_d, dataSize, mscclpp::Transport::CudaIpc | ibTransport);
|
||||
localMemories.push_back(memory);
|
||||
comm.sendMemoryOnSetup(memory, r, 0);
|
||||
remoteMemories.push_back(comm.recvMemoryOnSetup(r, 0));
|
||||
}
|
||||
|
||||
comm.connectionSetup();
|
||||
comm.setup();
|
||||
|
||||
std::vector<mscclpp::SimpleDeviceConnection> devConns;
|
||||
std::transform(hostConns.begin(), hostConns.end(), std::back_inserter(devConns),
|
||||
[](std::shared_ptr<mscclpp::HostConnection>& hostConn) {
|
||||
return mscclpp::SimpleDeviceConnection(*hostConn);
|
||||
});
|
||||
std::vector<mscclpp::channel::SimpleDeviceChannel> devChannels;
|
||||
for (size_t i = 0; i < channelIds.size(); ++i) {
|
||||
devChannels.push_back(mscclpp::channel::SimpleDeviceChannel(channelService.deviceChannel(channelIds[i]),
|
||||
channelService.addMemory(remoteMemories[i].get()), channelService.addMemory(localMemories[i])));
|
||||
}
|
||||
|
||||
assert(devConns.size() < sizeof(constDevConns) / sizeof(mscclpp::SimpleDeviceConnection));
|
||||
CUDACHECK(cudaMemcpyToSymbol(constDevConns, devConns.data(), sizeof(mscclpp::SimpleDeviceConnection) * devConns.size() ));
|
||||
assert(devChannels.size() < sizeof(constDevChans) / sizeof(mscclpp::channel::SimpleDeviceChannel));
|
||||
CUDACHECK(
|
||||
cudaMemcpyToSymbol(constDevChans, devChannels.data(), sizeof(mscclpp::channel::SimpleDeviceChannel) * devChannels.size()));
|
||||
}
|
||||
|
||||
void printUsage(const char* prog, bool isMpi)
|
||||
@@ -399,22 +413,25 @@ int main(int argc, const char* argv[])
|
||||
}
|
||||
size_t nelemsPerGPU = dataSize / sizeof(int) / world_size;
|
||||
|
||||
try{
|
||||
try {
|
||||
if (rank == 0)
|
||||
printf("Initializing MSCCL++\n");
|
||||
mscclpp::Communicator comm(world_size, ip_port, rank);
|
||||
printf("Initializing MSCCL++\n");
|
||||
auto bootstrapper = std::make_shared<mscclpp::Bootstrap>(rank, world_size);
|
||||
bootstrapper->initialize(ip_port);
|
||||
mscclpp::Communicator comm(bootstrapper);
|
||||
mscclpp::channel::DeviceChannelService channelService(comm);
|
||||
|
||||
if (rank == 0)
|
||||
printf("Initializing data for allgather test\n");
|
||||
printf("Initializing data for allgather test\n");
|
||||
initializeAndAllocateAllGatherData(rank, world_size, dataSize, nelemsPerGPU, &data_h, &data_d);
|
||||
|
||||
if (rank == 0)
|
||||
printf("Setting up the connection in MSCCL++\n");
|
||||
setupMscclppConnections(rank, world_size, comm, data_d, dataSize);
|
||||
printf("Setting up the connection in MSCCL++\n");
|
||||
setupMscclppConnections(rank, world_size, comm, channelService, data_d, dataSize);
|
||||
|
||||
if (rank == 0)
|
||||
printf("Launching MSCCL++ proxy threads\n");
|
||||
comm.startProxying();
|
||||
channelService.startProxy();
|
||||
|
||||
if (rank == 0)
|
||||
printf("Testing the correctness of AllGather implementation\n");
|
||||
@@ -434,7 +451,7 @@ int main(int argc, const char* argv[])
|
||||
}
|
||||
int tmp[16];
|
||||
// A simple barrier
|
||||
comm.bootstrapAllGather(tmp, sizeof(int));
|
||||
bootstrapper->allGather(tmp, sizeof(int));
|
||||
if (rank == 0)
|
||||
printf("Successfully checked the correctness\n");
|
||||
|
||||
@@ -443,12 +460,12 @@ int main(int argc, const char* argv[])
|
||||
if (rank == 0)
|
||||
printf("Running %d iterations of the kernel without CUDA graph\n", iterwithoutcudagraph);
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
comm.bootstrapAllGather(tmp, sizeof(int));
|
||||
bootstrapper->allGather(tmp, sizeof(int));
|
||||
for (int i = 0; i < iterwithoutcudagraph; ++i) {
|
||||
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
|
||||
}
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
comm.bootstrapAllGather(tmp, sizeof(int));
|
||||
bootstrapper->allGather(tmp, sizeof(int));
|
||||
|
||||
// cudaGraph Capture
|
||||
int cudagraphiter = 10;
|
||||
@@ -466,7 +483,7 @@ int main(int argc, const char* argv[])
|
||||
int cudagraphwarmup = 10;
|
||||
if (rank == 0)
|
||||
printf("Warming up %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphwarmup,
|
||||
cudagraphiter);
|
||||
cudagraphiter);
|
||||
for (int i = 0; i < cudagraphwarmup; ++i) {
|
||||
cudaGraphLaunch(instance, stream);
|
||||
}
|
||||
@@ -476,8 +493,8 @@ int main(int argc, const char* argv[])
|
||||
int cudagraphlaunch = 10;
|
||||
if (rank == 0)
|
||||
printf("Running %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphlaunch,
|
||||
cudagraphiter);
|
||||
comm.bootstrapAllGather(tmp, sizeof(int));
|
||||
cudagraphiter);
|
||||
bootstrapper->allGather(tmp, sizeof(int));
|
||||
double t0, t1, ms, time_in_us;
|
||||
t0 = getTime();
|
||||
for (int i = 0; i < cudagraphlaunch; ++i) {
|
||||
@@ -489,12 +506,12 @@ int main(int argc, const char* argv[])
|
||||
ms = (t1 - t0) * 1000.0;
|
||||
time_in_us = ms * 1000. / (float)cudagraphlaunch / (float)cudagraphiter;
|
||||
printf("Rank %d report: size %lu time: %f us/iter algBW %f GBps\n", rank, dataSize, time_in_us,
|
||||
(double)(dataSize) / 1e9 / (time_in_us / 1e6));
|
||||
comm.bootstrapAllGather(tmp, sizeof(int));
|
||||
(double)(dataSize) / 1e9 / (time_in_us / 1e6));
|
||||
bootstrapper->allGather(tmp, sizeof(int));
|
||||
|
||||
if (rank == 0)
|
||||
printf("Stopping MSCCL++ proxy threads\n");
|
||||
comm.stopProxying();
|
||||
channelService.stopProxy();
|
||||
|
||||
} catch (std::exception& e) {
|
||||
// todo: throw exceptions in the implementation and process them here
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
#include "mscclpp.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <mpi.h>
|
||||
|
||||
void test_allgather(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap){
|
||||
void test_allgather(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap)
|
||||
{
|
||||
std::vector<int> tmp(bootstrap->getNranks(), 0);
|
||||
tmp[bootstrap->getRank()] = bootstrap->getRank() + 1;
|
||||
bootstrap->allGather(tmp.data(), sizeof(int));
|
||||
@@ -16,15 +17,17 @@ void test_allgather(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap){
|
||||
std::cout << "AllGather test passed!" << std::endl;
|
||||
}
|
||||
|
||||
void test_barrier(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap){
|
||||
void test_barrier(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap)
|
||||
{
|
||||
bootstrap->barrier();
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "Barrier test passed!" << std::endl;
|
||||
}
|
||||
|
||||
void test_sendrecv(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap){
|
||||
void test_sendrecv(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap)
|
||||
{
|
||||
for (int i = 0; i < bootstrap->getNranks(); i++) {
|
||||
if (bootstrap->getRank() == 0)
|
||||
if (bootstrap->getRank() == i)
|
||||
continue;
|
||||
int msg1 = (bootstrap->getRank() + 1) * 3;
|
||||
int msg2 = (bootstrap->getRank() + 1) * 3 + 1;
|
||||
@@ -35,7 +38,7 @@ void test_sendrecv(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap){
|
||||
}
|
||||
|
||||
for (int i = 0; i < bootstrap->getNranks(); i++) {
|
||||
if (i == bootstrap->getRank())
|
||||
if (bootstrap->getRank() == i)
|
||||
continue;
|
||||
int msg1 = 0;
|
||||
int msg2 = 0;
|
||||
@@ -52,14 +55,16 @@ void test_sendrecv(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap){
|
||||
std::cout << "Send/Recv test passed!" << std::endl;
|
||||
}
|
||||
|
||||
void test_all(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap){
|
||||
void test_all(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap)
|
||||
{
|
||||
test_allgather(bootstrap);
|
||||
test_barrier(bootstrap);
|
||||
// test_sendrecv(bootstrap);
|
||||
test_sendrecv(bootstrap);
|
||||
}
|
||||
|
||||
void test_mscclpp_bootstrap_with_id(int rank, int worldSize){
|
||||
std::shared_ptr<mscclpp::Bootstrap> bootstrap(new mscclpp::Bootstrap(rank, worldSize));
|
||||
void test_mscclpp_bootstrap_with_id(int rank, int worldSize)
|
||||
{
|
||||
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(rank, worldSize);
|
||||
mscclpp::UniqueId id;
|
||||
if (bootstrap->getRank() == 0)
|
||||
id = bootstrap->createUniqueId();
|
||||
@@ -71,7 +76,8 @@ void test_mscclpp_bootstrap_with_id(int rank, int worldSize){
|
||||
std::cout << "--- MSCCLPP::Bootstrap test with unique id passed! ---" << std::endl;
|
||||
}
|
||||
|
||||
void test_mscclpp_bootstrap_with_ip_port_pair(int rank, int worldSize, char* ipPortPiar){
|
||||
void test_mscclpp_bootstrap_with_ip_port_pair(int rank, int worldSize, char* ipPortPiar)
|
||||
{
|
||||
std::shared_ptr<mscclpp::Bootstrap> bootstrap(new mscclpp::Bootstrap(rank, worldSize));
|
||||
bootstrap->initialize(ipPortPiar);
|
||||
|
||||
@@ -80,47 +86,57 @@ void test_mscclpp_bootstrap_with_ip_port_pair(int rank, int worldSize, char* ipP
|
||||
std::cout << "--- MSCCLPP::Bootstrap test with ip_port pair passed! ---" << std::endl;
|
||||
}
|
||||
|
||||
class MPIBootstrap : public mscclpp::BaseBootstrap {
|
||||
class MPIBootstrap : public mscclpp::BaseBootstrap
|
||||
{
|
||||
public:
|
||||
MPIBootstrap() : BaseBootstrap() {}
|
||||
int getRank() override {
|
||||
MPIBootstrap() : BaseBootstrap()
|
||||
{
|
||||
}
|
||||
int getRank() override
|
||||
{
|
||||
int rank;
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||||
return rank;
|
||||
}
|
||||
int getNranks() override {
|
||||
int getNranks() override
|
||||
{
|
||||
int worldSize;
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
|
||||
return worldSize;
|
||||
}
|
||||
void allGather(void *sendbuf, int size) override {
|
||||
void allGather(void* sendbuf, int size) override
|
||||
{
|
||||
MPI_Allgather(MPI_IN_PLACE, 0, MPI_BYTE, sendbuf, size, MPI_BYTE, MPI_COMM_WORLD);
|
||||
}
|
||||
void barrier() override {
|
||||
void barrier() override
|
||||
{
|
||||
MPI_Barrier(MPI_COMM_WORLD);
|
||||
}
|
||||
void send(void *sendbuf, int size, int dest, int tag) override {
|
||||
void send(void* sendbuf, int size, int dest, int tag) override
|
||||
{
|
||||
MPI_Send(sendbuf, size, MPI_BYTE, dest, tag, MPI_COMM_WORLD);
|
||||
}
|
||||
void recv(void *recvbuf, int size, int source, int tag) override {
|
||||
void recv(void* recvbuf, int size, int source, int tag) override
|
||||
{
|
||||
MPI_Recv(recvbuf, size, MPI_BYTE, source, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
|
||||
}
|
||||
};
|
||||
|
||||
void test_mpi_bootstrap(){
|
||||
void test_mpi_bootstrap()
|
||||
{
|
||||
std::shared_ptr<mscclpp::BaseBootstrap> bootstrap(new MPIBootstrap());
|
||||
test_all(bootstrap);
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "--- MPI Bootstrap test passed! ---" << std::endl;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
int rank, worldSize;
|
||||
MPI_Init(&argc, &argv);
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
|
||||
if (argc > 2){
|
||||
if (argc > 2) {
|
||||
if (rank == 0)
|
||||
std::cout << "Usage: " << argv[0] << " [ip:port]" << std::endl;
|
||||
MPI_Finalize();
|
||||
|
||||
272
tests/communicator_test_cpp.cu
Normal file
272
tests/communicator_test_cpp.cu
Normal file
@@ -0,0 +1,272 @@
|
||||
#include "mscclpp.hpp"
|
||||
#include "epoch.hpp"
|
||||
|
||||
#include <cassert>
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <mpi.h>
|
||||
#include <unordered_map>
|
||||
|
||||
// Evaluates a CUDA runtime call and throws std::runtime_error (carrying the
// text from cudaGetErrorString) when the call does not return cudaSuccess.
// The argument is parenthesized and the temporary has a macro-specific name so
// the expansion cannot shadow or capture a caller variable named `err`.
#define CUDATHROW(cmd) \
  do { \
    cudaError_t cudaThrowErr_ = (cmd); \
    if (cudaThrowErr_ != cudaSuccess) { \
      throw std::runtime_error(std::string("Cuda failure '") + cudaGetErrorString(cudaThrowErr_) + "'"); \
    } \
  } while (false)
|
||||
|
||||
// Maps a node-local rank to the IB transport with the same index
// (local rank 0 -> IB0, ..., local rank 7 -> IB7).
//
// Throws std::runtime_error when localRank has no entry in the table; the
// original indexed the array unchecked and read out of bounds in that case.
mscclpp::Transport findIb(int localRank)
{
  const mscclpp::Transport IBs[] = {mscclpp::Transport::IB0, mscclpp::Transport::IB1, mscclpp::Transport::IB2,
                                    mscclpp::Transport::IB3, mscclpp::Transport::IB4, mscclpp::Transport::IB5,
                                    mscclpp::Transport::IB6, mscclpp::Transport::IB7};
  const int numIbs = static_cast<int>(sizeof(IBs) / sizeof(IBs[0]));
  if (localRank < 0 || localRank >= numIbs) {
    throw std::runtime_error("findIb: no IB transport for local rank " + std::to_string(localRank));
  }
  return IBs[localRank];
}
|
||||
|
||||
// Registers `devicePtr` (deviceBufferSize bytes) with the communicator for the
// CudaIpc transport and `myIbDevice`, then exchanges the registration with
// every other rank.
//
// On return, `localMemory` holds this rank's registration and `remoteMemory`
// maps each peer rank to the registration received from it. Sends and receives
// are queued first, communicator.setup() is called once, and only then are the
// receive futures resolved.
void register_all_memories(mscclpp::Communicator& communicator, int rank, int worldSize, void* devicePtr,
                           size_t deviceBufferSize, mscclpp::Transport myIbDevice,
                           mscclpp::RegisteredMemory& localMemory,
                           std::unordered_map<int, mscclpp::RegisteredMemory>& remoteMemory)
{
  localMemory = communicator.registerMemory(devicePtr, deviceBufferSize, mscclpp::Transport::CudaIpc | myIbDevice);

  // Queue the exchange with every peer; nothing is resolved until setup().
  std::unordered_map<int, mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> pendingRemote;
  for (int peer = 0; peer < worldSize; ++peer) {
    if (peer == rank) {
      continue;
    }
    communicator.sendMemoryOnSetup(localMemory, peer, 0);
    pendingRemote[peer] = communicator.recvMemoryOnSetup(peer, 0);
  }

  communicator.setup();

  // Collect the peers' registrations now that setup has run.
  for (int peer = 0; peer < worldSize; ++peer) {
    if (peer == rank) {
      continue;
    }
    remoteMemory[peer] = pendingRemote[peer].get();
  }
}
|
||||
|
||||
// Queues one connection to every other rank and then runs communicator.setup().
//
// Peers for which `peer / nRanksPerNode` equals `rank / nRanksPerNode` are
// connected over CudaIpc; all other peers use `myIbDevice`. The resulting
// connection objects are stored in `connections`, keyed by peer rank.
void make_connections(mscclpp::Communicator& communicator, int rank, int worldSize, int nRanksPerNode,
                      mscclpp::Transport myIbDevice,
                      std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections)
{
  const int myNode = rank / nRanksPerNode;
  for (int peer = 0; peer < worldSize; ++peer) {
    if (peer == rank) {
      continue;
    }
    const bool sameNode = (peer / nRanksPerNode == myNode);
    const mscclpp::Transport transport = sameNode ? mscclpp::Transport::CudaIpc : myIbDevice;
    connections[peer] = communicator.connectOnSetup(peer, 0, transport);
  }
  communicator.setup();
}
|
||||
|
||||
void write_remote(int rank, int worldSize, std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory>& remoteRegisteredMemories, mscclpp::RegisteredMemory& registeredMemory, int dataCountPerRank){
|
||||
for (int i = 0; i < worldSize; i++) {
|
||||
if (i != rank) {
|
||||
auto& conn = connections.at(i);
|
||||
auto& peerMemory = remoteRegisteredMemories.at(i);
|
||||
conn->write(peerMemory, rank * dataCountPerRank * sizeof(int), registeredMemory, rank * dataCountPerRank*sizeof(int), dataCountPerRank*sizeof(int));
|
||||
conn->flush();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void device_buffer_init(int rank, int worldSize, int dataCount, std::vector<int*>& devicePtr){
|
||||
for (int n = 0; n < (int)devicePtr.size(); n++){
|
||||
std::vector<int> hostBuffer(dataCount, 0);
|
||||
for (int i = 0; i < dataCount; i++) {
|
||||
hostBuffer[i] = rank + n * worldSize;
|
||||
}
|
||||
CUDATHROW(cudaMemcpy(devicePtr[n], hostBuffer.data(), dataCount*sizeof(int), cudaMemcpyHostToDevice));
|
||||
}
|
||||
CUDATHROW(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
bool test_device_buffer_write_correctness(int worldSize, int dataCount, std::vector<int*>& devicePtr){
|
||||
for (int n = 0; n < (int)devicePtr.size(); n++){
|
||||
std::vector<int> hostBuffer(dataCount, 0);
|
||||
CUDATHROW(cudaMemcpy(hostBuffer.data(), devicePtr[n], dataCount*sizeof(int), cudaMemcpyDeviceToHost));
|
||||
for (int i = 0; i < worldSize; i++) {
|
||||
for (int j = i*dataCount/worldSize; j < (i+1)*dataCount/worldSize; j++) {
|
||||
if (hostBuffer[j] != i + n * worldSize) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void test_write(int rank, int worldSize, int deviceBufferSize, std::shared_ptr<mscclpp::BaseBootstrap> bootstrap, std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::vector<std::unordered_map<int, mscclpp::RegisteredMemory>>& remoteMemory, std::vector<mscclpp::RegisteredMemory>& localMemory, std::vector<int*>& devicePtr, int numBuffers){
|
||||
|
||||
assert((deviceBufferSize / sizeof(int)) % worldSize == 0);
|
||||
size_t dataCount = deviceBufferSize / sizeof(int);
|
||||
|
||||
device_buffer_init(rank, worldSize, dataCount, devicePtr);
|
||||
bootstrap->barrier();
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "CUDA memory initialization passed" << std::endl;
|
||||
|
||||
for (int n = 0; n < numBuffers; n++){
|
||||
write_remote(rank, worldSize, connections, remoteMemory[n], localMemory[n], dataCount / worldSize);
|
||||
}
|
||||
bootstrap->barrier();
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "RDMA write for " << std::to_string(numBuffers) << " buffers passed" << std::endl;
|
||||
|
||||
// polling until it becomes ready
|
||||
bool ready = false;
|
||||
int niter = 0;
|
||||
do {
|
||||
ready = test_device_buffer_write_correctness(worldSize, dataCount, devicePtr);
|
||||
niter++;
|
||||
if (niter == 10000){
|
||||
throw std::runtime_error("Polling is stuck.");
|
||||
}
|
||||
} while (!ready);
|
||||
|
||||
bootstrap->barrier();
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "Polling for " << std::to_string(numBuffers) << " buffers passed" << std::endl;
|
||||
}
|
||||
|
||||
// Kernel: thread `tid` calls epochIncrement() on deviceEpochs[tid].
// Threads whose index equals `rank` or is >= worldSize do nothing.
__global__ void increament_epochs(mscclpp::DeviceEpoch* deviceEpochs, int rank, int worldSize)
{
  const int peer = threadIdx.x;
  if (peer == rank || peer >= worldSize) {
    return;
  }
  deviceEpochs[peer].epochIncrement();
}
|
||||
|
||||
// Kernel: thread `tid` calls wait() on deviceEpochs[tid].
// Threads whose index equals `rank` or is >= worldSize do nothing.
__global__ void wait_epochs(mscclpp::DeviceEpoch* deviceEpochs, int rank, int worldSize)
{
  const int peer = threadIdx.x;
  if (peer == rank || peer >= worldSize) {
    return;
  }
  deviceEpochs[peer].wait();
}
|
||||
|
||||
void test_write_with_epochs(int rank, int worldSize, int deviceBufferSize, std::shared_ptr<mscclpp::BaseBootstrap> bootstrap, std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::vector<std::unordered_map<int, mscclpp::RegisteredMemory>>& remoteMemory, std::vector<mscclpp::RegisteredMemory>& localMemory, std::vector<int*>& devicePtr, std::unordered_map<int, std::shared_ptr<mscclpp::Epoch>> epochs, int numBuffers){
|
||||
|
||||
assert((deviceBufferSize / sizeof(int)) % worldSize == 0);
|
||||
size_t dataCount = deviceBufferSize / sizeof(int);
|
||||
|
||||
device_buffer_init(rank, worldSize, dataCount, devicePtr);
|
||||
bootstrap->barrier();
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "CUDA memory initialization passed" << std::endl;
|
||||
|
||||
mscclpp::DeviceEpoch* deviceEpochs;
|
||||
CUDATHROW(cudaMalloc(&deviceEpochs, sizeof(mscclpp::DeviceEpoch) * worldSize));
|
||||
for (int i = 0; i < worldSize; i++){
|
||||
if (i != rank){
|
||||
mscclpp::DeviceEpoch deviceEpoch = epochs[i]->deviceEpoch();
|
||||
CUDATHROW(cudaMemcpy(&deviceEpochs[i], &deviceEpoch, sizeof(mscclpp::DeviceEpoch), cudaMemcpyHostToDevice));
|
||||
}
|
||||
}
|
||||
CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
bootstrap->barrier();
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "CUDA device epochs are created" << std::endl;
|
||||
|
||||
|
||||
for (int n = 0; n < numBuffers; n++){
|
||||
write_remote(rank, worldSize, connections, remoteMemory[n], localMemory[n], dataCount / worldSize);
|
||||
}
|
||||
|
||||
increament_epochs<<<1, worldSize>>>(deviceEpochs, rank, worldSize);
|
||||
CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
for (int i = 0; i < worldSize; i++){
|
||||
if (i != rank){
|
||||
epochs[i]->signal();
|
||||
}
|
||||
}
|
||||
|
||||
wait_epochs<<<1, worldSize>>>(deviceEpochs, rank, worldSize);
|
||||
CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
if (!test_device_buffer_write_correctness(worldSize, dataCount, devicePtr)){
|
||||
throw std::runtime_error("unexpected result.");
|
||||
}
|
||||
|
||||
bootstrap->barrier();
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "--- Testing writes with singal for " << std::to_string(numBuffers) << " buffers passed ---" << std::endl;
|
||||
}
|
||||
|
||||
void test_communicator(int rank, int worldSize, int nranksPerNode)
|
||||
{
|
||||
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(rank, worldSize);
|
||||
mscclpp::UniqueId id;
|
||||
if (bootstrap->getRank() == 0)
|
||||
id = bootstrap->createUniqueId();
|
||||
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
|
||||
bootstrap->initialize(id);
|
||||
|
||||
mscclpp::Communicator communicator(bootstrap);
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "Communicator initialization passed" << std::endl;
|
||||
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> connections;
|
||||
auto myIbDevice = findIb(rank % nranksPerNode);
|
||||
|
||||
make_connections(communicator, rank, worldSize, nranksPerNode, myIbDevice, connections);
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "Connection setup passed" << std::endl;
|
||||
|
||||
int numBuffers = 10;
|
||||
std::vector<int*> devicePtr(numBuffers);
|
||||
int deviceBufferSize = 1024*1024;
|
||||
|
||||
std::vector<mscclpp::RegisteredMemory> localMemory(numBuffers);
|
||||
std::vector<std::unordered_map<int, mscclpp::RegisteredMemory>> remoteMemory(numBuffers);
|
||||
|
||||
for (int n = 0; n < numBuffers; n++) {
|
||||
if (n % 100 == 0)
|
||||
std::cout << "Registering memory for " << std::to_string(n) << " buffers" << std::endl;
|
||||
CUDATHROW(cudaMalloc(&devicePtr[n], deviceBufferSize));
|
||||
register_all_memories(communicator, rank, worldSize, devicePtr[n], deviceBufferSize, myIbDevice, localMemory[n], remoteMemory[n]);
|
||||
}
|
||||
bootstrap->barrier();
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "Memory registration for " << std::to_string(numBuffers) << " buffers passed" << std::endl;
|
||||
|
||||
test_write(rank, worldSize, deviceBufferSize, bootstrap, connections, remoteMemory, localMemory, devicePtr, numBuffers);
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "--- Testing vanialla writes passed ---" << std::endl;
|
||||
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Epoch>> epochs;
|
||||
for (auto entry : connections) {
|
||||
auto& conn = entry.second;
|
||||
epochs.insert({entry.first, std::make_shared<mscclpp::Epoch>(communicator, conn)});
|
||||
}
|
||||
communicator.setup();
|
||||
bootstrap->barrier();
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "Epochs are created" << std::endl;
|
||||
|
||||
test_write_with_epochs(rank, worldSize, deviceBufferSize, bootstrap, connections, remoteMemory, localMemory, devicePtr, epochs, numBuffers);
|
||||
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "--- MSCCLPP::Communicator tests passed! ---" << std::endl;
|
||||
|
||||
for (int n = 0; n < numBuffers; n++){
|
||||
CUDATHROW(cudaFree(devicePtr[n]));
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
int rank, worldSize;
|
||||
MPI_Init(&argc, &argv);
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
|
||||
MPI_Comm shmcomm;
|
||||
MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &shmcomm);
|
||||
int shmWorldSize;
|
||||
MPI_Comm_size(shmcomm, &shmWorldSize);
|
||||
int nranksPerNode = shmWorldSize;
|
||||
MPI_Comm_free(&shmcomm);
|
||||
|
||||
test_communicator(rank, worldSize, nranksPerNode);
|
||||
|
||||
MPI_Finalize();
|
||||
return 0;
|
||||
}
|
||||
2
tests/unittests/CMakeLists.txt
Normal file
2
tests/unittests/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
# ib_test: executable built from ib_test.cc and linked against the mscclpp
# library, MPI (C++) and the CUDA runtime. The libraries are linked PRIVATE
# because an executable has no consumers; the explicit keyword also avoids the
# legacy keyword-less target_link_libraries signature.
add_executable(ib_test ib_test.cc)
target_link_libraries(ib_test PRIVATE mscclpp MPI::MPI_CXX CUDA::cudart)
|
||||
@@ -1,7 +1,9 @@
|
||||
#include "alloc.h"
|
||||
#include "checks.h"
|
||||
#include "ib.h"
|
||||
#include <set>
|
||||
#include "ib.hpp"
|
||||
#include "infiniband/verbs.h"
|
||||
#include "mscclpp.hpp"
|
||||
#include <array>
|
||||
#include <string>
|
||||
|
||||
// Measure current time in second.
|
||||
@@ -24,8 +26,8 @@ int main(int argc, const char* argv[])
|
||||
printf("Usage: %s <ip:port> <0(recv)/1(send)> <gpu id> <ib id>\n", argv[0]);
|
||||
return 1;
|
||||
}
|
||||
const char* ip_port = argv[1];
|
||||
int is_send = atoi(argv[2]);
|
||||
const char* ipPortPair = argv[1];
|
||||
int isSend = atoi(argv[2]);
|
||||
int cudaDevId = atoi(argv[3]);
|
||||
std::string ibDevName = "mlx5_ib" + std::string(argv[4]);
|
||||
|
||||
@@ -35,51 +37,40 @@ int main(int argc, const char* argv[])
|
||||
int nelem = 1;
|
||||
MSCCLPPCHECK(mscclppCudaCalloc(&data, nelem));
|
||||
|
||||
mscclppComm_t comm;
|
||||
MSCCLPPCHECK(mscclppCommInitRank(&comm, 2, ip_port, is_send));
|
||||
std::shared_ptr<mscclpp::Bootstrap> bootstrap(new mscclpp::Bootstrap(isSend, 2));
|
||||
bootstrap->initialize(ipPortPair);
|
||||
|
||||
struct mscclppIbContext* ctx;
|
||||
struct mscclppIbQp* qp;
|
||||
struct mscclppIbMr* mr;
|
||||
MSCCLPPCHECK(mscclppIbContextCreate(&ctx, ibDevName.c_str()));
|
||||
MSCCLPPCHECK(mscclppIbContextCreateQp(ctx, &qp));
|
||||
MSCCLPPCHECK(mscclppIbContextRegisterMr(ctx, data, sizeof(int) * nelem, &mr));
|
||||
mscclpp::IbCtx ctx(ibDevName);
|
||||
mscclpp::IbQp* qp = ctx.createQp();
|
||||
const mscclpp::IbMr* mr = ctx.registerMr(data, sizeof(int) * nelem);
|
||||
|
||||
struct mscclppIbQpInfo* qpInfo;
|
||||
MSCCLPPCHECK(mscclppCalloc(&qpInfo, 2));
|
||||
qpInfo[is_send] = qp->info;
|
||||
std::array<mscclpp::IbQpInfo, 2> qpInfo;
|
||||
qpInfo[isSend] = qp->getInfo();
|
||||
|
||||
struct mscclppIbMrInfo* mrInfo;
|
||||
MSCCLPPCHECK(mscclppCalloc(&mrInfo, 2));
|
||||
mrInfo[is_send] = mr->info;
|
||||
std::array<mscclpp::IbMrInfo, 2> mrInfo;
|
||||
mrInfo[isSend] = mr->getInfo();
|
||||
|
||||
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, qpInfo, sizeof(struct mscclppIbQpInfo)));
|
||||
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, mrInfo, sizeof(struct mscclppIbMrInfo)));
|
||||
bootstrap->allGather(qpInfo.data(), sizeof(mscclpp::IbQpInfo));
|
||||
bootstrap->allGather(mrInfo.data(), sizeof(mscclpp::IbMrInfo));
|
||||
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
if (i == is_send)
|
||||
for (int i = 0; i < bootstrap->getNranks(); ++i) {
|
||||
if (i == isSend)
|
||||
continue;
|
||||
qp->rtr(&qpInfo[i]);
|
||||
qp->rtr(qpInfo[i]);
|
||||
qp->rts();
|
||||
break;
|
||||
}
|
||||
|
||||
printf("connection succeed\n");
|
||||
|
||||
// A simple barrier
|
||||
int* tmp;
|
||||
MSCCLPPCHECK(mscclppCalloc(&tmp, 2));
|
||||
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int)));
|
||||
bootstrap->barrier();
|
||||
|
||||
if (is_send) {
|
||||
if (isSend) {
|
||||
int maxIter = 100000;
|
||||
double start = getTime();
|
||||
for (int iter = 0; iter < maxIter; ++iter) {
|
||||
qp->stageSend(mr, &mrInfo[0], sizeof(int) * nelem, 0, 0, 0, true);
|
||||
if (qp->postSend() != 0) {
|
||||
WARN("postSend failed");
|
||||
return 1;
|
||||
}
|
||||
qp->stageSend(mr, mrInfo[0], sizeof(int) * nelem, 0, 0, 0, true);
|
||||
qp->postSend();
|
||||
bool waiting = true;
|
||||
while (waiting) {
|
||||
int wcNum = qp->pollCq();
|
||||
@@ -88,7 +79,7 @@ int main(int argc, const char* argv[])
|
||||
return 1;
|
||||
}
|
||||
for (int i = 0; i < wcNum; ++i) {
|
||||
struct ibv_wc* wc = &qp->wcs[i];
|
||||
const struct ibv_wc* wc = reinterpret_cast<const struct ibv_wc*>(qp->getWc(i));
|
||||
if (wc->status != IBV_WC_SUCCESS) {
|
||||
WARN("wc status %d", wc->status);
|
||||
return 1;
|
||||
@@ -103,10 +94,7 @@ int main(int argc, const char* argv[])
|
||||
}
|
||||
|
||||
// A simple barrier
|
||||
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int)));
|
||||
|
||||
MSCCLPPCHECK(mscclppIbContextDestroy(ctx));
|
||||
MSCCLPPCHECK(mscclppCommDestroy(comm));
|
||||
bootstrap->barrier();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user