mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-20 06:49:29 +00:00
Fix bootstrapping mechanism (#278)
Co-authored-by: Binyang Li <binyli@microsoft.com> Co-authored-by: Pashupati Kumar <74680231+pash-msft@users.noreply.github.com>
This commit is contained in:
@@ -101,7 +101,7 @@ find_package(Threads REQUIRED)
|
||||
|
||||
add_library(mscclpp_obj OBJECT)
|
||||
target_include_directories(mscclpp_obj
|
||||
PRIVATE
|
||||
SYSTEM PRIVATE
|
||||
${GPU_INCLUDE_DIRS}
|
||||
${IBVERBS_INCLUDE_DIRS}
|
||||
${NUMA_INCLUDE_DIRS})
|
||||
|
||||
@@ -51,6 +51,10 @@ class Bootstrap {
|
||||
/// A native implementation of the bootstrap using TCP sockets.
|
||||
class TcpBootstrap : public Bootstrap {
|
||||
public:
|
||||
/// Create a random unique ID.
|
||||
/// @return The created unique ID.
|
||||
static UniqueId createUniqueId();
|
||||
|
||||
/// Constructor.
|
||||
/// @param rank The rank of the process.
|
||||
/// @param nRanks The total number of ranks.
|
||||
@@ -59,10 +63,6 @@ class TcpBootstrap : public Bootstrap {
|
||||
/// Destructor.
|
||||
~TcpBootstrap();
|
||||
|
||||
/// Create a random unique ID and store it in the @ref TcpBootstrap.
|
||||
/// @return The created unique ID.
|
||||
UniqueId createUniqueId();
|
||||
|
||||
/// Return the unique ID stored in the @ref TcpBootstrap.
|
||||
/// @return The unique ID stored in the @ref TcpBootstrap.
|
||||
UniqueId getUniqueId() const;
|
||||
|
||||
@@ -9,6 +9,6 @@ FetchContent_MakeAvailable(nanobind)
|
||||
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cpp)
|
||||
nanobind_add_module(mscclpp_py ${SOURCES})
|
||||
set_target_properties(mscclpp_py PROPERTIES OUTPUT_NAME _mscclpp)
|
||||
target_link_libraries(mscclpp_py PRIVATE ${GPU_LIBRARIES} mscclpp_static)
|
||||
target_include_directories(mscclpp_py PRIVATE ${GPU_INCLUDE_DIRS})
|
||||
target_link_libraries(mscclpp_py PRIVATE mscclpp_static ${GPU_LIBRARIES})
|
||||
target_include_directories(mscclpp_py SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
|
||||
install(TARGETS mscclpp_py LIBRARY DESTINATION .)
|
||||
|
||||
@@ -63,7 +63,7 @@ void register_core(nb::module_& m) {
|
||||
.def_static(
|
||||
"create", [](int rank, int nRanks) { return std::make_shared<TcpBootstrap>(rank, nRanks); }, nb::arg("rank"),
|
||||
nb::arg("nRanks"))
|
||||
.def("create_unique_id", &TcpBootstrap::createUniqueId)
|
||||
.def_static("create_unique_id", &TcpBootstrap::createUniqueId)
|
||||
.def("get_unique_id", &TcpBootstrap::getUniqueId)
|
||||
.def("initialize", static_cast<void (TcpBootstrap::*)(UniqueId, int64_t)>(&TcpBootstrap::initialize),
|
||||
nb::call_guard<nb::gil_scoped_release>(), nb::arg("uniqueId"), nb::arg("timeoutSec") = 30)
|
||||
|
||||
@@ -9,5 +9,5 @@ FetchContent_MakeAvailable(nanobind)
|
||||
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cpp)
|
||||
nanobind_add_module(mscclpp_py_test ${SOURCES})
|
||||
set_target_properties(mscclpp_py_test PROPERTIES OUTPUT_NAME _ext)
|
||||
target_link_libraries(mscclpp_py_test PRIVATE ${GPU_LIBRARIES} mscclpp_static)
|
||||
target_include_directories(mscclpp_py_test PRIVATE ${GPU_INCLUDE_DIRS})
|
||||
target_link_libraries(mscclpp_py_test PRIVATE mscclpp_static ${GPU_LIBRARIES})
|
||||
target_include_directories(mscclpp_py_test SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
|
||||
|
||||
@@ -70,12 +70,14 @@ static_assert(sizeof(UniqueIdInternal) <= sizeof(UniqueId), "UniqueIdInternal is
|
||||
|
||||
class TcpBootstrap::Impl {
|
||||
public:
|
||||
static UniqueId createUniqueId();
|
||||
static UniqueId getUniqueId(const UniqueIdInternal& uniqueId);
|
||||
|
||||
Impl(int rank, int nRanks);
|
||||
~Impl();
|
||||
void initialize(const UniqueId& uniqueId, int64_t timeoutSec);
|
||||
void initialize(const std::string& ifIpPortTrio, int64_t timeoutSec);
|
||||
void establishConnections(int64_t timeoutSec);
|
||||
UniqueId createUniqueId();
|
||||
UniqueId getUniqueId() const;
|
||||
int getRank();
|
||||
int getNranks();
|
||||
@@ -99,7 +101,6 @@ class TcpBootstrap::Impl {
|
||||
std::unique_ptr<uint32_t> abortFlagStorage_;
|
||||
volatile uint32_t* abortFlag_;
|
||||
std::thread rootThread_;
|
||||
char netIfName_[MAX_IF_NAME_SIZE + 1];
|
||||
SocketAddress netIfAddr_;
|
||||
std::unordered_map<std::pair<int, int>, std::shared_ptr<Socket>, PairHash> peerSendSockets_;
|
||||
std::unordered_map<std::pair<int, int>, std::shared_ptr<Socket>, PairHash> peerRecvSockets_;
|
||||
@@ -110,15 +111,33 @@ class TcpBootstrap::Impl {
|
||||
std::shared_ptr<Socket> getPeerSendSocket(int peer, int tag);
|
||||
std::shared_ptr<Socket> getPeerRecvSocket(int peer, int tag);
|
||||
|
||||
static void assignPortToUniqueId(UniqueIdInternal& uniqueId);
|
||||
static void netInit(std::string ipPortPair, std::string interface, SocketAddress& netIfAddr);
|
||||
|
||||
void bootstrapCreateRoot();
|
||||
void bootstrapRoot();
|
||||
void getRemoteAddresses(Socket* listenSock, std::vector<SocketAddress>& rankAddresses,
|
||||
std::vector<SocketAddress>& rankAddressesRoot, int& rank);
|
||||
void sendHandleToPeer(int peer, const std::vector<SocketAddress>& rankAddresses,
|
||||
const std::vector<SocketAddress>& rankAddressesRoot);
|
||||
void netInit(std::string ipPortPair, std::string interface);
|
||||
};
|
||||
|
||||
UniqueId TcpBootstrap::Impl::createUniqueId() {
|
||||
UniqueIdInternal uniqueId;
|
||||
SocketAddress netIfAddr;
|
||||
netInit("", "", netIfAddr);
|
||||
getRandomData(&uniqueId.magic, sizeof(uniqueId_.magic));
|
||||
std::memcpy(&uniqueId.addr, &netIfAddr, sizeof(SocketAddress));
|
||||
assignPortToUniqueId(uniqueId);
|
||||
return getUniqueId(uniqueId);
|
||||
}
|
||||
|
||||
UniqueId TcpBootstrap::Impl::getUniqueId(const UniqueIdInternal& uniqueId) {
|
||||
UniqueId ret;
|
||||
std::memcpy(&ret, &uniqueId, sizeof(uniqueId));
|
||||
return ret;
|
||||
}
|
||||
|
||||
TcpBootstrap::Impl::Impl(int rank, int nRanks)
|
||||
: rank_(rank),
|
||||
nRanks_(nRanks),
|
||||
@@ -128,29 +147,26 @@ TcpBootstrap::Impl::Impl(int rank, int nRanks)
|
||||
abortFlagStorage_(new uint32_t(0)),
|
||||
abortFlag_(abortFlagStorage_.get()) {}
|
||||
|
||||
UniqueId TcpBootstrap::Impl::getUniqueId() const {
|
||||
UniqueId ret;
|
||||
std::memcpy(&ret, &uniqueId_, sizeof(uniqueId_));
|
||||
return ret;
|
||||
}
|
||||
|
||||
UniqueId TcpBootstrap::Impl::createUniqueId() {
|
||||
netInit("", "");
|
||||
getRandomData(&uniqueId_.magic, sizeof(uniqueId_.magic));
|
||||
std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(SocketAddress));
|
||||
bootstrapCreateRoot();
|
||||
return getUniqueId();
|
||||
}
|
||||
UniqueId TcpBootstrap::Impl::getUniqueId() const { return getUniqueId(uniqueId_); }
|
||||
|
||||
int TcpBootstrap::Impl::getRank() { return rank_; }
|
||||
|
||||
int TcpBootstrap::Impl::getNranks() { return nRanks_; }
|
||||
|
||||
void TcpBootstrap::Impl::initialize(const UniqueId& uniqueId, int64_t timeoutSec) {
|
||||
netInit("", "");
|
||||
if (!netInitialized) {
|
||||
netInit("", "", netIfAddr_);
|
||||
netInitialized = true;
|
||||
}
|
||||
|
||||
std::memcpy(&uniqueId_, &uniqueId, sizeof(uniqueId_));
|
||||
if (rank_ == 0) {
|
||||
bootstrapCreateRoot();
|
||||
}
|
||||
|
||||
char line[MAX_IF_NAME_SIZE + 1];
|
||||
SocketToString(&uniqueId_.addr, line);
|
||||
INFO(MSCCLPP_INIT, "rank %d nranks %d - connecting to %s", rank_, nRanks_, line);
|
||||
establishConnections(timeoutSec);
|
||||
}
|
||||
|
||||
@@ -170,7 +186,10 @@ void TcpBootstrap::Impl::initialize(const std::string& ifIpPortTrio, int64_t tim
|
||||
ipPortPair = ifIpPortTrio.substr(ipPortPair.find_first_of(':') + 1);
|
||||
}
|
||||
|
||||
netInit(ipPortPair, interface);
|
||||
if (!netInitialized) {
|
||||
netInit(ipPortPair, interface, netIfAddr_);
|
||||
netInitialized = true;
|
||||
}
|
||||
|
||||
uniqueId_.magic = 0xdeadbeef;
|
||||
std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(SocketAddress));
|
||||
@@ -230,9 +249,15 @@ void TcpBootstrap::Impl::sendHandleToPeer(int peer, const std::vector<SocketAddr
|
||||
netSend(&sock, &rankAddresses[next], sizeof(SocketAddress));
|
||||
}
|
||||
|
||||
void TcpBootstrap::Impl::assignPortToUniqueId(UniqueIdInternal& uniqueId) {
|
||||
std::unique_ptr<Socket> socket = std::make_unique<Socket>(&uniqueId.addr, uniqueId.magic, SocketTypeBootstrap);
|
||||
socket->bind();
|
||||
uniqueId.addr = socket->getAddr();
|
||||
}
|
||||
|
||||
void TcpBootstrap::Impl::bootstrapCreateRoot() {
|
||||
listenSockRoot_ = std::make_unique<Socket>(&uniqueId_.addr, uniqueId_.magic, SocketTypeBootstrap, abortFlag_, 0);
|
||||
listenSockRoot_->listen();
|
||||
listenSockRoot_->bindAndListen();
|
||||
uniqueId_.addr = listenSockRoot_->getAddr();
|
||||
|
||||
rootThread_ = std::thread([this]() {
|
||||
@@ -279,34 +304,33 @@ void TcpBootstrap::Impl::bootstrapRoot() {
|
||||
TRACE(MSCCLPP_INIT, "DONE");
|
||||
}
|
||||
|
||||
void TcpBootstrap::Impl::netInit(std::string ipPortPair, std::string interface) {
|
||||
if (netInitialized) return;
|
||||
void TcpBootstrap::Impl::netInit(std::string ipPortPair, std::string interface, SocketAddress& netIfAddr) {
|
||||
char netIfName[MAX_IF_NAME_SIZE + 1];
|
||||
if (!ipPortPair.empty()) {
|
||||
if (interface != "") {
|
||||
// we know the <interface>
|
||||
int ret = FindInterfaces(netIfName_, &netIfAddr_, MAX_IF_NAME_SIZE, 1, interface.c_str());
|
||||
int ret = FindInterfaces(netIfName, &netIfAddr, MAX_IF_NAME_SIZE, 1, interface.c_str());
|
||||
if (ret <= 0) throw Error("NET/Socket : No interface named " + interface + " found.", ErrorCode::InternalError);
|
||||
} else {
|
||||
// we do not know the <interface> try to match it next
|
||||
SocketAddress remoteAddr;
|
||||
SocketGetAddrFromString(&remoteAddr, ipPortPair.c_str());
|
||||
if (FindInterfaceMatchSubnet(netIfName_, &netIfAddr_, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
|
||||
if (FindInterfaceMatchSubnet(netIfName, &netIfAddr, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
|
||||
throw Error("NET/Socket : No usable listening interface found", ErrorCode::InternalError);
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
int ret = FindInterfaces(netIfName_, &netIfAddr_, MAX_IF_NAME_SIZE, 1);
|
||||
int ret = FindInterfaces(netIfName, &netIfAddr, MAX_IF_NAME_SIZE, 1);
|
||||
if (ret <= 0) {
|
||||
throw Error("TcpBootstrap : no socket interface found", ErrorCode::InternalError);
|
||||
}
|
||||
}
|
||||
|
||||
char line[SOCKET_NAME_MAXLEN + MAX_IF_NAME_SIZE + 2];
|
||||
std::sprintf(line, " %s:", netIfName_);
|
||||
SocketToString(&netIfAddr_, line + strlen(line));
|
||||
std::sprintf(line, " %s:", netIfName);
|
||||
SocketToString(&netIfAddr, line + strlen(line));
|
||||
INFO(MSCCLPP_INIT, "TcpBootstrap : Using%s", line);
|
||||
netInitialized = true;
|
||||
}
|
||||
|
||||
#define TIMEOUT(__exp) \
|
||||
@@ -345,13 +369,13 @@ void TcpBootstrap::Impl::establishConnections(int64_t timeoutSec) {
|
||||
uint64_t magic = uniqueId_.magic;
|
||||
// Create socket for other ranks to contact me
|
||||
listenSock_ = std::make_unique<Socket>(&netIfAddr_, magic, SocketTypeBootstrap, abortFlag_);
|
||||
listenSock_->listen();
|
||||
listenSock_->bindAndListen();
|
||||
info.extAddressListen = listenSock_->getAddr();
|
||||
|
||||
{
|
||||
// Create socket for root to contact me
|
||||
Socket lsock(&netIfAddr_, magic, SocketTypeBootstrap, abortFlag_);
|
||||
lsock.listen();
|
||||
lsock.bindAndListen();
|
||||
info.extAddressListenRoot = lsock.getAddr();
|
||||
|
||||
// stagger connection times to avoid an overload of the root
|
||||
@@ -486,9 +510,9 @@ void TcpBootstrap::Impl::close() {
|
||||
peerRecvSockets_.clear();
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP TcpBootstrap::TcpBootstrap(int rank, int nRanks) { pimpl_ = std::make_unique<Impl>(rank, nRanks); }
|
||||
MSCCLPP_API_CPP UniqueId TcpBootstrap::createUniqueId() { return Impl::createUniqueId(); }
|
||||
|
||||
MSCCLPP_API_CPP UniqueId TcpBootstrap::createUniqueId() { return pimpl_->createUniqueId(); }
|
||||
MSCCLPP_API_CPP TcpBootstrap::TcpBootstrap(int rank, int nRanks) { pimpl_ = std::make_unique<Impl>(rank, nRanks); }
|
||||
|
||||
MSCCLPP_API_CPP UniqueId TcpBootstrap::getUniqueId() const { return pimpl_->getUniqueId(); }
|
||||
|
||||
|
||||
@@ -390,7 +390,7 @@ Socket::Socket(const SocketAddress* addr, uint64_t magic, enum SocketType type,
|
||||
|
||||
Socket::~Socket() { close(); }
|
||||
|
||||
void Socket::listen() {
|
||||
void Socket::bind() {
|
||||
if (fd_ == -1) {
|
||||
throw Error("file descriptor is -1", ErrorCode::InvalidUsage);
|
||||
}
|
||||
@@ -433,7 +433,11 @@ void Socket::listen() {
|
||||
if (::getsockname(fd_, &addr_.sa, &size) != 0) {
|
||||
throw SysError("getsockname failed", errno);
|
||||
}
|
||||
state_ = SocketStateBound;
|
||||
}
|
||||
|
||||
void Socket::bindAndListen() {
|
||||
bind();
|
||||
#ifdef ENABLE_TRACE
|
||||
char line[SOCKET_NAME_MAXLEN + 1];
|
||||
TRACE(MSCCLPP_INIT | MSCCLPP_NET, "Listening on socket %s", SocketToString(&addr_, line));
|
||||
|
||||
@@ -56,6 +56,7 @@ MSCCLPP_API_CPP void Fifo::pop() {
|
||||
MSCCLPP_API_CPP void Fifo::flushTail(bool sync) {
|
||||
// Flush the tail to device memory. This is either triggered every ProxyFlushPeriod to make sure that the fifo can
|
||||
// make progress even if there is no request mscclppSync. However, mscclppSync type is for flush request.
|
||||
AvoidCudaGraphCaptureGuard cgcGuard;
|
||||
MSCCLPP_CUDATHROW(cudaMemcpyAsync(pimpl->tailReplica.get(), &pimpl->hostTail, sizeof(uint64_t),
|
||||
cudaMemcpyHostToDevice, pimpl->stream));
|
||||
if (sync) {
|
||||
|
||||
@@ -35,10 +35,11 @@ enum SocketState {
|
||||
SocketStateConnecting = 4,
|
||||
SocketStateConnectPolling = 5,
|
||||
SocketStateConnected = 6,
|
||||
SocketStateReady = 7,
|
||||
SocketStateClosed = 8,
|
||||
SocketStateError = 9,
|
||||
SocketStateNum = 10
|
||||
SocketStateBound = 7,
|
||||
SocketStateReady = 8,
|
||||
SocketStateClosed = 9,
|
||||
SocketStateError = 10,
|
||||
SocketStateNum = 11
|
||||
};
|
||||
|
||||
enum SocketType {
|
||||
@@ -62,7 +63,8 @@ class Socket {
|
||||
enum SocketType type = SocketTypeUnknown, volatile uint32_t* abortFlag = nullptr, int asyncFlag = 0);
|
||||
~Socket();
|
||||
|
||||
void listen();
|
||||
void bind();
|
||||
void bindAndListen();
|
||||
void connect(int64_t timeout = -1);
|
||||
void accept(const Socket* listenSocket, int64_t timeout = -1);
|
||||
void send(void* ptr, int size);
|
||||
|
||||
@@ -5,7 +5,7 @@ find_package(MPI)
|
||||
|
||||
set(TEST_LIBS_COMMON mscclpp ${GPU_LIBRARIES} ${NUMA_LIBRARIES} ${IBVERBS_LIBRARIES} Threads::Threads)
|
||||
set(TEST_LIBS_GTEST GTest::gtest_main GTest::gmock_main)
|
||||
set(TEST_INC_COMMON PRIVATE ${PROJECT_SOURCE_DIR}/include ${GPU_INCLUDE_DIRS})
|
||||
set(TEST_INC_COMMON PRIVATE ${PROJECT_SOURCE_DIR}/include SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
|
||||
set(TEST_INC_INTERNAL PRIVATE ${PROJECT_SOURCE_DIR}/src/include)
|
||||
|
||||
if(USE_ROCM)
|
||||
|
||||
@@ -67,7 +67,7 @@ TEST_F(BootstrapTest, ResumeWithId) {
|
||||
// This test may take a few minutes.
|
||||
bootstrapTestTimer.set(300);
|
||||
|
||||
for (int i = 0; i < 3000; ++i) {
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(gEnv->rank, gEnv->worldSize);
|
||||
mscclpp::UniqueId id;
|
||||
if (bootstrap->getRank() == 0) id = bootstrap->createUniqueId();
|
||||
|
||||
@@ -17,7 +17,7 @@ TEST(Socket, ListenAndConnect) {
|
||||
ASSERT_NO_THROW(mscclpp::SocketGetAddrFromString(&listenAddr, ipPortPair.c_str()));
|
||||
|
||||
mscclpp::Socket listenSock(&listenAddr);
|
||||
listenSock.listen();
|
||||
listenSock.bindAndListen();
|
||||
|
||||
std::thread clientThread([&listenAddr]() {
|
||||
mscclpp::Socket sock(&listenAddr);
|
||||
|
||||
Reference in New Issue
Block a user