Ethernet support (#284)

Co-authored-by: Binyang Li <binyli@microsoft.com>
Co-authored-by: Caio Rocha <caiorocha@microsoft.com>
This commit is contained in:
Changho Hwang
2024-04-25 11:06:43 -07:00
committed by GitHub
parent 89896ff94f
commit d4ede480f4
13 changed files with 341 additions and 52 deletions

View File

@@ -130,25 +130,26 @@ class TcpBootstrap : public Bootstrap {
/// Enumerates the available transport types.
enum class Transport {
Unknown, // Unknown transport type.
CudaIpc, // CUDA IPC transport type.
Nvls, // NVLS transport type.
IB0, // InfiniBand device 0 transport type.
IB1, // InfiniBand device 1 transport type.
IB2, // InfiniBand device 2 transport type.
IB3, // InfiniBand device 3 transport type.
IB4, // InfiniBand device 4 transport type.
IB5, // InfiniBand device 5 transport type.
IB6, // InfiniBand device 6 transport type.
IB7, // InfiniBand device 7 transport type.
NumTransports // The number of transports.
Unknown, // Unknown transport type.
CudaIpc, // CUDA IPC transport type.
Nvls, // NVLS transport type.
IB0, // InfiniBand device 0 transport type.
IB1, // InfiniBand device 1 transport type.
IB2, // InfiniBand device 2 transport type.
IB3, // InfiniBand device 3 transport type.
IB4, // InfiniBand device 4 transport type.
IB5, // InfiniBand device 5 transport type.
IB6, // InfiniBand device 6 transport type.
IB7, // InfiniBand device 7 transport type.
Ethernet, // Ethernet transport type.
NumTransports, // The number of transports.
};
const std::string TransportNames[] = {"UNK", "IPC", "NVLS", "IB0", "IB1", "IB2",
"IB3", "IB4", "IB5", "IB6", "IB7", "NUM"};
const std::string TransportNames[] = {"UNK", "IPC", "NVLS", "IB0", "IB1", "IB2", "IB3",
"IB4", "IB5", "IB6", "IB7", "ETH", "NUM"};
namespace detail {
const size_t TransportFlagsSize = 11;
const size_t TransportFlagsSize = 12;
static_assert(TransportFlagsSize == static_cast<size_t>(Transport::NumTransports),
"TransportFlagsSize must match the number of transports");
/// Bitset for storing transport flags.
@@ -336,6 +337,11 @@ class RegisteredMemory {
/// @return A pointer to the memory block.
void* data() const;
/// Get a pointer to the original memory block.
///
/// @return A pointer to the original memory block.
void* originalDataPtr() const;
/// Get the size of the memory block.
///
/// @return The size of the memory block.

View File

@@ -543,6 +543,38 @@ void Socket::recv(void* ptr, int size) {
socketWait(MSCCLPP_SOCKET_RECV, ptr, size, &offset);
}
void Socket::recvUntilEnd(void* ptr, int size, int* closed) {
int offset = 0;
*closed = 0;
if (state_ != SocketStateReady) {
std::stringstream ss;
ss << "socket state (" << state_ << ") is not ready in recvUntilEnd";
throw Error(ss.str(), ErrorCode::InternalError);
}
int bytes = 0;
char* data = (char*)ptr;
do {
bytes = ::recv(fd_, data + (offset), size - (offset), 0);
if (bytes == 0) {
*closed = 1;
return;
}
if (bytes == -1) {
if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN && state_ != SocketStateClosed) {
throw SysError("recv until end failed", errno);
} else {
bytes = 0;
}
}
(offset) += bytes;
if (abortFlag_ && *abortFlag_ != 0) {
throw Error("aborted", ErrorCode::Aborted);
}
} while (bytes > 0 && (offset) < size);
}
void Socket::close() {
if (fd_ >= 0) ::close(fd_);
state_ = SocketStateClosed;

View File

@@ -5,6 +5,7 @@
#include <mscclpp/utils.hpp>
#include <sstream>
#include <thread>
#include "debug.h"
#include "endpoint.hpp"
@@ -180,4 +181,148 @@ void IBConnection::flush(int64_t timeoutUsec) {
// npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT);
}
// EthernetConnection
EthernetConnection::EthernetConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, uint64_t sendBufferSize,
uint64_t recvBufferSize)
: abortFlag_(0), sendBufferSize_(sendBufferSize), recvBufferSize_(recvBufferSize) {
// Validating Transport Protocol
if (localEndpoint.transport() != Transport::Ethernet || remoteEndpoint.transport() != Transport::Ethernet) {
throw mscclpp::Error("Ethernet connection can only be made from Ethernet endpoints", ErrorCode::InvalidUsage);
}
// Instanciating Buffers
sendBuffer_.resize(sendBufferSize_);
recvBuffer_.resize(recvBufferSize_);
// Creating Thread to Accept the Connection
auto parameter = (getImpl(localEndpoint)->socket_).get();
std::thread t([this, parameter]() {
recvSocket_ = std::make_unique<Socket>(nullptr, MSCCLPP_SOCKET_MAGIC, SocketTypeUnknown, abortFlag_);
recvSocket_->accept(parameter);
});
// Starting Connection
sendSocket_ = std::make_unique<Socket>(&(getImpl(remoteEndpoint)->socketAddress_), MSCCLPP_SOCKET_MAGIC,
SocketTypeBootstrap, abortFlag_);
sendSocket_->connect();
// Ensure the Connection was Established
t.join();
// Starting Thread to Receive Messages
threadRecvMessages_ = std::thread(&EthernetConnection::recvMessages, this);
INFO(MSCCLPP_NET, "Ethernet connection created");
}
EthernetConnection::~EthernetConnection() {
sendSocket_->close();
recvSocket_->close();
threadRecvMessages_.join();
}
Transport EthernetConnection::transport() { return Transport::Ethernet; }
Transport EthernetConnection::remoteTransport() { return Transport::Ethernet; }
void EthernetConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
uint64_t size) {
// Validating Transport Protocol
validateTransport(dst, remoteTransport());
validateTransport(src, transport());
// Initializing Variables
char* srcPtr = reinterpret_cast<char*>(src.data()) + srcOffset / sizeof(char);
char* dstPtr = reinterpret_cast<char*>(dst.originalDataPtr()) + dstOffset / sizeof(char);
uint64_t sentDataSize = 0;
uint64_t headerSize = 0;
// Copying Meta Data to Send Buffer
char* dstPtrBytes = reinterpret_cast<char*>(&dstPtr);
std::copy(dstPtrBytes, dstPtrBytes + sizeof(dstPtr), sendBuffer_.data() + headerSize / sizeof(char));
headerSize += sizeof(dstPtr);
char* sizeBytes = reinterpret_cast<char*>(&size);
std::copy(sizeBytes, sizeBytes + sizeof(size), sendBuffer_.data() + headerSize / sizeof(char));
headerSize += sizeof(size);
// Getting Data From GPU and Sending Message
while (sentDataSize < size) {
uint64_t dataSize =
std::min(sendBufferSize_ - headerSize / sizeof(char), (size - sentDataSize) / sizeof(char)) * sizeof(char);
uint64_t messageSize = dataSize + headerSize;
mscclpp::memcpyCuda<char>(sendBuffer_.data() + headerSize / sizeof(char),
(char*)srcPtr + (sentDataSize / sizeof(char)), dataSize, cudaMemcpyDeviceToHost);
sendSocket_->send(sendBuffer_.data(), messageSize);
sentDataSize += messageSize;
headerSize = 0;
}
INFO(MSCCLPP_NET, "EthernetConnection write: from %p to %p, size %lu", srcPtr, dstPtr, size);
}
void EthernetConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) {
// Validating Transport Protocol
validateTransport(dst, remoteTransport());
// Initializing Variables
uint64_t oldValue = *src;
uint64_t* dstPtr = reinterpret_cast<uint64_t*>(reinterpret_cast<char*>(dst.originalDataPtr()) + dstOffset);
uint64_t dataSize = sizeof(uint64_t);
uint64_t messageSize = 0;
*src = newValue;
// Copying Data to Send Buffer
char* dstPtrBytes = reinterpret_cast<char*>(&dstPtr);
std::copy(dstPtrBytes, dstPtrBytes + sizeof(dstPtr), sendBuffer_.data() + messageSize / sizeof(char));
messageSize += sizeof(dstPtr);
char* sizeBytes = reinterpret_cast<char*>(&dataSize);
std::copy(sizeBytes, sizeBytes + sizeof(dataSize), sendBuffer_.data() + messageSize / sizeof(char));
messageSize += sizeof(dataSize);
char* dataBytes = reinterpret_cast<char*>(src);
std::copy(dataBytes, dataBytes + dataSize, sendBuffer_.data() + messageSize / sizeof(char));
messageSize += dataSize;
// Sending Message
sendSocket_->send(sendBuffer_.data(), messageSize);
INFO(MSCCLPP_NET, "EthernetConnection atomic write: from %p to %p, %lu -> %lu", src, dstPtr + dstOffset, oldValue,
newValue);
}
void EthernetConnection::flush(int64_t timeoutUsec) { INFO(MSCCLPP_NET, "EthernetConnection flushing connection"); }
void EthernetConnection::recvMessages() {
// Declarating Variables
char* ptr;
uint64_t size;
uint64_t recvSize;
int closed = 0;
bool received = true;
// Receiving Messages Until Connection is Closed
while (recvSocket_->getState() != SocketStateClosed) {
// Receiving Data Address
if (closed == 0) recvSocket_->recvUntilEnd(&ptr, sizeof(char*), &closed);
received &= !closed;
// Receiving data size
if (closed == 0) recvSocket_->recvUntilEnd(&size, sizeof(uint64_t), &closed);
received &= !closed;
// Receiving Data and Copying Data yo GPU
recvSize = 0;
while (recvSize < size && closed == 0) {
uint64_t messageSize = std::min(recvBufferSize_, (size - recvSize) / sizeof(char)) * sizeof(char);
recvSocket_->recvUntilEnd(recvBuffer_.data(), messageSize, &closed);
received &= !closed;
if (received)
mscclpp::memcpyCuda<char>((char*)ptr + (recvSize / sizeof(char)), recvBuffer_.data(), messageSize,
cudaMemcpyHostToDevice);
recvSize += messageSize;
}
}
}
} // namespace mscclpp

View File

@@ -49,9 +49,15 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Context::connect(Endpoint localEndpo
throw mscclpp::Error("Local transport is IB but remote is not", ErrorCode::InvalidUsage);
}
conn = std::make_shared<IBConnection>(localEndpoint, remoteEndpoint, *this);
} else if (localEndpoint.transport() == Transport::Ethernet) {
if (remoteEndpoint.transport() != Transport::Ethernet) {
throw mscclpp::Error("Local transport is Ethernet but remote is not", ErrorCode::InvalidUsage);
}
conn = std::make_shared<EthernetConnection>(localEndpoint, remoteEndpoint);
} else {
throw mscclpp::Error("Unsupported transport", ErrorCode::InternalError);
}
pimpl_->connections_.push_back(conn);
return conn;
}

View File

@@ -87,7 +87,7 @@ const TransportFlags NoTransports = TransportFlags();
const TransportFlags AllIBTransports = Transport::IB0 | Transport::IB1 | Transport::IB2 | Transport::IB3 |
Transport::IB4 | Transport::IB5 | Transport::IB6 | Transport::IB7;
const TransportFlags AllTransports = AllIBTransports | Transport::CudaIpc;
const TransportFlags AllTransports = AllIBTransports | Transport::CudaIpc | Transport::Ethernet;
void Setuppable::beginSetup(std::shared_ptr<Bootstrap>) {}

View File

@@ -4,6 +4,7 @@
#include "api.h"
#include "context.hpp"
#include "socket.h"
#include "utils_internal.hpp"
namespace mscclpp {
@@ -15,6 +16,16 @@ Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl)
ibQp_ = contextImpl.getIbContext(transport_)
->createQp(config.ibMaxCqSize, config.ibMaxCqPollNum, config.ibMaxSendWr, 0, config.ibMaxWrPerSend);
ibQpInfo_ = ibQp_->getInfo();
} else if (transport_ == Transport::Ethernet) {
// Configuring Ethernet Interfaces
abortFlag_ = 0;
int ret = FindInterfaces(netIfName_, &socketAddress_, MAX_IF_NAME_SIZE, 1, "");
if (ret <= 0) throw Error("NET/Socket", ErrorCode::InternalError);
// Starting Server Socket
socket_ = std::make_unique<Socket>(&socketAddress_, MSCCLPP_SOCKET_MAGIC, SocketTypeBootstrap, abortFlag_);
socket_->bindAndListen();
socketAddress_ = socket_->getAddr();
}
}
@@ -27,6 +38,10 @@ MSCCLPP_API_CPP std::vector<char> Endpoint::serialize() {
if (AllIBTransports.has(pimpl_->transport_)) {
std::copy_n(reinterpret_cast<char*>(&pimpl_->ibQpInfo_), sizeof(pimpl_->ibQpInfo_), std::back_inserter(data));
}
if ((pimpl_->transport_) == Transport::Ethernet) {
std::copy_n(reinterpret_cast<char*>(&pimpl_->socketAddress_), sizeof(pimpl_->socketAddress_),
std::back_inserter(data));
}
return data;
}
@@ -45,6 +60,10 @@ Endpoint::Impl::Impl(const std::vector<char>& serialization) {
std::copy_n(it, sizeof(ibQpInfo_), reinterpret_cast<char*>(&ibQpInfo_));
it += sizeof(ibQpInfo_);
}
if (transport_ == Transport::Ethernet) {
std::copy_n(it, sizeof(socketAddress_), reinterpret_cast<char*>(&socketAddress_));
it += sizeof(socketAddress_);
}
}
MSCCLPP_API_CPP Endpoint::Endpoint(std::shared_ptr<mscclpp::Endpoint::Impl> pimpl) : pimpl_(pimpl) {}

View File

@@ -11,6 +11,7 @@
#include "context.hpp"
#include "ib.hpp"
#include "registered_memory.hpp"
#include "socket.h"
namespace mscclpp {
@@ -53,6 +54,38 @@ class IBConnection : public Connection {
void flush(int64_t timeoutUsec) override;
};
class EthernetConnection : public Connection {
std::unique_ptr<Socket> sendSocket_;
std::unique_ptr<Socket> recvSocket_;
std::thread threadRecvMessages_;
volatile uint32_t* abortFlag_;
const uint64_t sendBufferSize_;
const uint64_t recvBufferSize_;
std::vector<char> sendBuffer_;
std::vector<char> recvBuffer_;
public:
EthernetConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, uint64_t sendBufferSize = 256 * 1024 * 1024,
uint64_t recvBufferSize = 256 * 1024 * 1024);
~EthernetConnection();
Transport transport() override;
Transport remoteTransport() override;
void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
uint64_t size) override;
void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) override;
void flush(int64_t timeoutUsec) override;
private:
void recvMessages();
void sendMessage();
};
} // namespace mscclpp
#endif // MSCCLPP_CONNECTION_HPP_

View File

@@ -8,6 +8,9 @@
#include <vector>
#include "ib.hpp"
#include "socket.h"
#define MAX_IF_NAME_SIZE 16
namespace mscclpp {
@@ -22,6 +25,12 @@ struct Endpoint::Impl {
bool ibLocal_;
IbQp* ibQp_;
IbQpInfo ibQpInfo_;
// The following are only used for Ethernet and are undefined for other transports.
std::unique_ptr<Socket> socket_;
SocketAddress socketAddress_;
volatile uint32_t* abortFlag_;
char netIfName_[MAX_IF_NAME_SIZE + 1];
};
} // namespace mscclpp

View File

@@ -69,6 +69,7 @@ class Socket {
void accept(const Socket* listenSocket, int64_t timeout = -1);
void send(void* ptr, int size);
void recv(void* ptr, int size);
void recvUntilEnd(void* ptr, int size, int* closed);
void close();
int getFd() const { return fd_; }

View File

@@ -62,6 +62,8 @@ MSCCLPP_API_CPP RegisteredMemory::~RegisteredMemory() = default;
MSCCLPP_API_CPP void* RegisteredMemory::data() const { return pimpl_->data; }
MSCCLPP_API_CPP void* RegisteredMemory::originalDataPtr() const { return pimpl_->originalDataPtr; }
MSCCLPP_API_CPP size_t RegisteredMemory::size() { return pimpl_->size; }
MSCCLPP_API_CPP TransportFlags RegisteredMemory::transports() { return pimpl_->transports; }

View File

@@ -42,14 +42,16 @@ void CommunicatorTestBase::TearDown() {
void CommunicatorTestBase::setNumRanksToUse(int num) { numRanksToUse = num; }
void CommunicatorTestBase::connectMesh(bool useIbOnly) {
void CommunicatorTestBase::connectMesh(bool useIpc, bool useIb, bool useEthernet) {
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connectionFutures(numRanksToUse);
for (int i = 0; i < numRanksToUse; i++) {
if (i != gEnv->rank) {
if ((rankToNode(i) == rankToNode(gEnv->rank)) && !useIbOnly) {
if ((rankToNode(i) == rankToNode(gEnv->rank)) && useIpc) {
connectionFutures[i] = communicator->connectOnSetup(i, 0, mscclpp::Transport::CudaIpc);
} else {
} else if (useIb) {
connectionFutures[i] = communicator->connectOnSetup(i, 0, ibTransport);
} else if (useEthernet) {
connectionFutures[i] = communicator->connectOnSetup(i, 0, mscclpp::Transport::Ethernet);
}
}
}
@@ -97,7 +99,7 @@ void CommunicatorTest::SetUp() {
ASSERT_EQ((deviceBufferSize / sizeof(int)) % gEnv->worldSize, 0);
connectMesh();
connectMesh(true, true, false);
devicePtr.resize(numBuffers);
localMemory.resize(numBuffers);
@@ -281,4 +283,4 @@ TEST_F(CommunicatorTest, WriteWithHostSemaphores) {
ASSERT_TRUE(testWriteCorrectness());
communicator->bootstrap()->barrier();
}
}

View File

@@ -93,7 +93,7 @@ class CommunicatorTestBase : public MultiProcessTest {
void TearDown() override;
void setNumRanksToUse(int num);
void connectMesh(bool useIbOnly = false);
void connectMesh(bool useIpc = true, bool useIb = true, bool useEthernet = false);
// Register a local memory and receive corresponding remote memories
void registerMemoryPairs(void* buff, size_t buffSize, mscclpp::TransportFlags transport, int tag,
@@ -130,13 +130,21 @@ using DeviceHandle = mscclpp::DeviceHandle<T>;
class ProxyChannelOneToOneTest : public CommunicatorTestBase {
protected:
struct PingPongTestParams {
bool useIPC;
bool useIB;
bool useEthernet;
bool waitWithPoll;
};
void SetUp() override;
void TearDown() override;
void setupMeshConnections(std::vector<mscclpp::SimpleProxyChannel>& proxyChannels, bool useIbOnly, void* sendBuff,
size_t sendBuffBytes, void* recvBuff = nullptr, size_t recvBuffBytes = 0);
void testPingPong(bool useIbOnly, bool waitWithPoll);
void testPingPongPerf(bool useIbOnly, bool waitWithPoll);
void setupMeshConnections(std::vector<mscclpp::SimpleProxyChannel>& proxyChannels, bool useIPC, bool useIb,
bool useEthernet, void* sendBuff, size_t sendBuffBytes, void* recvBuff = nullptr,
size_t recvBuffBytes = 0);
void testPingPong(PingPongTestParams params);
void testPingPongPerf(PingPongTestParams params);
void testPacketPingPong(bool useIbOnly);
void testPacketPingPongPerf(bool useIbOnly);

View File

@@ -16,12 +16,16 @@ void ProxyChannelOneToOneTest::SetUp() {
void ProxyChannelOneToOneTest::TearDown() { CommunicatorTestBase::TearDown(); }
void ProxyChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::SimpleProxyChannel>& proxyChannels,
bool useIbOnly, void* sendBuff, size_t sendBuffBytes,
void* recvBuff, size_t recvBuffBytes) {
bool useIPC, bool useIb, bool useEthernet, void* sendBuff,
size_t sendBuffBytes, void* recvBuff, size_t recvBuffBytes) {
const int rank = communicator->bootstrap()->getRank();
const int worldSize = communicator->bootstrap()->getNranks();
const bool isInPlace = (recvBuff == nullptr);
mscclpp::TransportFlags transport = (useIbOnly) ? ibTransport : (mscclpp::Transport::CudaIpc | ibTransport);
mscclpp::TransportFlags transport;
if (useIPC) transport |= mscclpp::Transport::CudaIpc;
if (useIb) transport |= ibTransport;
if (useEthernet) transport |= mscclpp::Transport::Ethernet;
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connectionFutures(worldSize);
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteMemFutures(worldSize);
@@ -36,10 +40,12 @@ void ProxyChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::SimpleP
if (r == rank) {
continue;
}
if ((rankToNode(r) == rankToNode(gEnv->rank)) && !useIbOnly) {
if ((rankToNode(r) == rankToNode(gEnv->rank)) && useIPC) {
connectionFutures[r] = communicator->connectOnSetup(r, 0, mscclpp::Transport::CudaIpc);
} else {
} else if (useIb) {
connectionFutures[r] = communicator->connectOnSetup(r, 0, ibTransport);
} else if (useEthernet) {
connectionFutures[r] = communicator->connectOnSetup(r, 0, mscclpp::Transport::Ethernet);
}
if (isInPlace) {
@@ -145,14 +151,14 @@ __global__ void kernelProxyPingPong(int* buff, int rank, int nElem, bool waitWit
}
}
void ProxyChannelOneToOneTest::testPingPong(bool useIbOnly, bool waitWithPoll) {
void ProxyChannelOneToOneTest::testPingPong(PingPongTestParams params) {
if (gEnv->rank >= numRanksToUse) return;
const int nElem = 4 * 1024 * 1024;
std::vector<mscclpp::SimpleProxyChannel> proxyChannels;
std::shared_ptr<int> buff = mscclpp::allocExtSharedCuda<int>(nElem);
setupMeshConnections(proxyChannels, useIbOnly, buff.get(), nElem * sizeof(int));
setupMeshConnections(proxyChannels, params.useIPC, params.useIB, params.useEthernet, buff.get(), nElem * sizeof(int));
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannelHandles;
for (auto& ch : proxyChannels) proxyChannelHandles.push_back(ch.deviceHandle());
@@ -167,22 +173,22 @@ void ProxyChannelOneToOneTest::testPingPong(bool useIbOnly, bool waitWithPoll) {
const int nTries = 1000;
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, waitWithPoll, nTries, ret.get());
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, params.waitWithPoll, nTries, ret.get());
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
EXPECT_EQ(*ret, 0);
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024, waitWithPoll, nTries, ret.get());
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024, params.waitWithPoll, nTries, ret.get());
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
EXPECT_EQ(*ret, 0);
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024 * 1024, waitWithPoll, nTries, ret.get());
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024 * 1024, params.waitWithPoll, nTries, ret.get());
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
EXPECT_EQ(*ret, 0);
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 4 * 1024 * 1024, waitWithPoll, nTries, ret.get());
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 4 * 1024 * 1024, params.waitWithPoll, nTries, ret.get());
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
EXPECT_EQ(*ret, 0);
@@ -190,14 +196,14 @@ void ProxyChannelOneToOneTest::testPingPong(bool useIbOnly, bool waitWithPoll) {
proxyService->stopProxy();
}
void ProxyChannelOneToOneTest::testPingPongPerf(bool useIbOnly, bool waitWithPoll) {
void ProxyChannelOneToOneTest::testPingPongPerf(PingPongTestParams params) {
if (gEnv->rank >= numRanksToUse) return;
const int nElem = 4 * 1024 * 1024;
std::vector<mscclpp::SimpleProxyChannel> proxyChannels;
std::shared_ptr<int> buff = mscclpp::allocExtSharedCuda<int>(nElem);
setupMeshConnections(proxyChannels, useIbOnly, buff.get(), nElem * sizeof(int));
setupMeshConnections(proxyChannels, params.useIPC, params.useIB, params.useEthernet, buff.get(), nElem * sizeof(int));
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannelHandles;
for (auto& ch : proxyChannels) proxyChannelHandles.push_back(ch.deviceHandle());
@@ -212,17 +218,17 @@ void ProxyChannelOneToOneTest::testPingPongPerf(bool useIbOnly, bool waitWithPol
auto* testInfo = ::testing::UnitTest::GetInstance()->current_test_info();
const std::string testName = std::string(testInfo->test_suite_name()) + "." + std::string(testInfo->name());
const int nTries = 1000000;
const int nTries = 1000;
// Warm-up
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, waitWithPoll, nTries, ret.get());
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, params.waitWithPoll, nTries, ret.get());
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
communicator->bootstrap()->barrier();
// Measure latency
mscclpp::Timer timer;
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, waitWithPoll, nTries, ret.get());
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, params.waitWithPoll, nTries, ret.get());
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
communicator->bootstrap()->barrier();
@@ -234,17 +240,37 @@ void ProxyChannelOneToOneTest::testPingPongPerf(bool useIbOnly, bool waitWithPol
proxyService->stopProxy();
}
TEST_F(ProxyChannelOneToOneTest, PingPong) { testPingPong(false, false); }
TEST_F(ProxyChannelOneToOneTest, PingPong) {
testPingPong(PingPongTestParams{.useIPC = true, .useIB = true, .useEthernet = false, .waitWithPoll = false});
}
TEST_F(ProxyChannelOneToOneTest, PingPongIb) { testPingPong(true, false); }
TEST_F(ProxyChannelOneToOneTest, PingPongIb) {
testPingPong(PingPongTestParams{.useIPC = false, .useIB = true, .useEthernet = false, .waitWithPoll = false});
}
TEST_F(ProxyChannelOneToOneTest, PingPongWithPoll) { testPingPong(false, true); }
TEST_F(ProxyChannelOneToOneTest, PingPongEthernet) {
testPingPong(PingPongTestParams{.useIPC = false, .useIB = false, .useEthernet = true, .waitWithPoll = false});
}
TEST_F(ProxyChannelOneToOneTest, PingPongIbWithPoll) { testPingPong(true, true); }
TEST_F(ProxyChannelOneToOneTest, PingPongWithPoll) {
testPingPong(PingPongTestParams{.useIPC = true, .useIB = true, .useEthernet = false, .waitWithPoll = true});
}
TEST_F(ProxyChannelOneToOneTest, PingPongPerf) { testPingPongPerf(false, false); }
TEST_F(ProxyChannelOneToOneTest, PingPongIbWithPoll) {
testPingPong(PingPongTestParams{.useIPC = false, .useIB = true, .useEthernet = false, .waitWithPoll = true});
}
TEST_F(ProxyChannelOneToOneTest, PingPongPerfIb) { testPingPongPerf(true, false); }
TEST_F(ProxyChannelOneToOneTest, PingPongPerf) {
testPingPongPerf(PingPongTestParams{.useIPC = true, .useIB = true, .useEthernet = false, .waitWithPoll = false});
}
TEST_F(ProxyChannelOneToOneTest, PingPongPerfIb) {
testPingPongPerf(PingPongTestParams{.useIPC = false, .useIB = true, .useEthernet = false, .waitWithPoll = false});
}
TEST_F(ProxyChannelOneToOneTest, PingPongPerfEthernet) {
testPingPongPerf(PingPongTestParams{.useIPC = false, .useIB = false, .useEthernet = true, .waitWithPoll = false});
}
__device__ mscclpp::DeviceSyncer gChannelOneToOneTestProxyChansSyncer;
@@ -324,8 +350,8 @@ void ProxyChannelOneToOneTest::testPacketPingPong(bool useIbOnly) {
auto putPacketBuffer = mscclpp::allocExtSharedCuda<mscclpp::LLPacket>(nPacket);
auto getPacketBuffer = mscclpp::allocExtSharedCuda<mscclpp::LLPacket>(nPacket);
setupMeshConnections(proxyChannels, useIbOnly, putPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket),
getPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket));
setupMeshConnections(proxyChannels, !useIbOnly, true, false, putPacketBuffer.get(),
nPacket * sizeof(mscclpp::LLPacket), getPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket));
ASSERT_EQ(proxyChannels.size(), 1);
@@ -391,8 +417,8 @@ void ProxyChannelOneToOneTest::testPacketPingPongPerf(bool useIbOnly) {
auto putPacketBuffer = mscclpp::allocExtSharedCuda<mscclpp::LLPacket>(nPacket);
auto getPacketBuffer = mscclpp::allocExtSharedCuda<mscclpp::LLPacket>(nPacket);
setupMeshConnections(proxyChannels, useIbOnly, putPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket),
getPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket));
setupMeshConnections(proxyChannels, !useIbOnly, true, false, putPacketBuffer.get(),
nPacket * sizeof(mscclpp::LLPacket), getPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket));
ASSERT_EQ(proxyChannels.size(), 1);