mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 01:10:22 +00:00
Ethernet support (#284)
Co-authored-by: Binyang Li <binyli@microsoft.com> Co-authored-by: Caio Rocha <caiorocha@microsoft.com>
This commit is contained in:
@@ -130,25 +130,26 @@ class TcpBootstrap : public Bootstrap {
|
||||
|
||||
/// Enumerates the available transport types.
|
||||
enum class Transport {
|
||||
Unknown, // Unknown transport type.
|
||||
CudaIpc, // CUDA IPC transport type.
|
||||
Nvls, // NVLS transport type.
|
||||
IB0, // InfiniBand device 0 transport type.
|
||||
IB1, // InfiniBand device 1 transport type.
|
||||
IB2, // InfiniBand device 2 transport type.
|
||||
IB3, // InfiniBand device 3 transport type.
|
||||
IB4, // InfiniBand device 4 transport type.
|
||||
IB5, // InfiniBand device 5 transport type.
|
||||
IB6, // InfiniBand device 6 transport type.
|
||||
IB7, // InfiniBand device 7 transport type.
|
||||
NumTransports // The number of transports.
|
||||
Unknown, // Unknown transport type.
|
||||
CudaIpc, // CUDA IPC transport type.
|
||||
Nvls, // NVLS transport type.
|
||||
IB0, // InfiniBand device 0 transport type.
|
||||
IB1, // InfiniBand device 1 transport type.
|
||||
IB2, // InfiniBand device 2 transport type.
|
||||
IB3, // InfiniBand device 3 transport type.
|
||||
IB4, // InfiniBand device 4 transport type.
|
||||
IB5, // InfiniBand device 5 transport type.
|
||||
IB6, // InfiniBand device 6 transport type.
|
||||
IB7, // InfiniBand device 7 transport type.
|
||||
Ethernet, // Ethernet transport type.
|
||||
NumTransports, // The number of transports.
|
||||
};
|
||||
|
||||
const std::string TransportNames[] = {"UNK", "IPC", "NVLS", "IB0", "IB1", "IB2",
|
||||
"IB3", "IB4", "IB5", "IB6", "IB7", "NUM"};
|
||||
const std::string TransportNames[] = {"UNK", "IPC", "NVLS", "IB0", "IB1", "IB2", "IB3",
|
||||
"IB4", "IB5", "IB6", "IB7", "ETH", "NUM"};
|
||||
|
||||
namespace detail {
|
||||
const size_t TransportFlagsSize = 11;
|
||||
const size_t TransportFlagsSize = 12;
|
||||
static_assert(TransportFlagsSize == static_cast<size_t>(Transport::NumTransports),
|
||||
"TransportFlagsSize must match the number of transports");
|
||||
/// Bitset for storing transport flags.
|
||||
@@ -336,6 +337,11 @@ class RegisteredMemory {
|
||||
/// @return A pointer to the memory block.
|
||||
void* data() const;
|
||||
|
||||
/// Get a pointer to the original memory block.
|
||||
///
|
||||
/// @return A pointer to the original memory block.
|
||||
void* originalDataPtr() const;
|
||||
|
||||
/// Get the size of the memory block.
|
||||
///
|
||||
/// @return The size of the memory block.
|
||||
|
||||
@@ -543,6 +543,38 @@ void Socket::recv(void* ptr, int size) {
|
||||
socketWait(MSCCLPP_SOCKET_RECV, ptr, size, &offset);
|
||||
}
|
||||
|
||||
void Socket::recvUntilEnd(void* ptr, int size, int* closed) {
|
||||
int offset = 0;
|
||||
*closed = 0;
|
||||
if (state_ != SocketStateReady) {
|
||||
std::stringstream ss;
|
||||
ss << "socket state (" << state_ << ") is not ready in recvUntilEnd";
|
||||
throw Error(ss.str(), ErrorCode::InternalError);
|
||||
}
|
||||
|
||||
int bytes = 0;
|
||||
char* data = (char*)ptr;
|
||||
|
||||
do {
|
||||
bytes = ::recv(fd_, data + (offset), size - (offset), 0);
|
||||
if (bytes == 0) {
|
||||
*closed = 1;
|
||||
return;
|
||||
}
|
||||
if (bytes == -1) {
|
||||
if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN && state_ != SocketStateClosed) {
|
||||
throw SysError("recv until end failed", errno);
|
||||
} else {
|
||||
bytes = 0;
|
||||
}
|
||||
}
|
||||
(offset) += bytes;
|
||||
if (abortFlag_ && *abortFlag_ != 0) {
|
||||
throw Error("aborted", ErrorCode::Aborted);
|
||||
}
|
||||
} while (bytes > 0 && (offset) < size);
|
||||
}
|
||||
|
||||
void Socket::close() {
|
||||
if (fd_ >= 0) ::close(fd_);
|
||||
state_ = SocketStateClosed;
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include <mscclpp/utils.hpp>
|
||||
#include <sstream>
|
||||
#include <thread>
|
||||
|
||||
#include "debug.h"
|
||||
#include "endpoint.hpp"
|
||||
@@ -180,4 +181,148 @@ void IBConnection::flush(int64_t timeoutUsec) {
|
||||
// npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT);
|
||||
}
|
||||
|
||||
// EthernetConnection
|
||||
|
||||
/// Establish an Ethernet connection between two Ethernet endpoints.
///
/// Two sockets are set up: `recvSocket_` is accepted from the local endpoint's listening socket
/// on a helper thread, while `sendSocket_` connects to the remote endpoint's socket address on
/// the calling thread. Once both are up, a background thread running recvMessages() is started.
///
/// @param localEndpoint Local endpoint; must use Transport::Ethernet.
/// @param remoteEndpoint Remote endpoint; must use Transport::Ethernet.
/// @param sendBufferSize Size in bytes of the host send staging buffer.
/// @param recvBufferSize Size in bytes of the host receive staging buffer.
/// @throws mscclpp::Error with ErrorCode::InvalidUsage if either endpoint is not Ethernet.
EthernetConnection::EthernetConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, uint64_t sendBufferSize,
                                       uint64_t recvBufferSize)
    : abortFlag_(0), sendBufferSize_(sendBufferSize), recvBufferSize_(recvBufferSize) {
  // Reject anything that is not an Ethernet-to-Ethernet pairing.
  if (!(localEndpoint.transport() == Transport::Ethernet && remoteEndpoint.transport() == Transport::Ethernet)) {
    throw mscclpp::Error("Ethernet connection can only be made from Ethernet endpoints", ErrorCode::InvalidUsage);
  }

  // Allocate the host staging buffers.
  sendBuffer_.resize(sendBufferSize_);
  recvBuffer_.resize(recvBufferSize_);

  // Accept the incoming connection on a helper thread so that accept() and connect() below
  // can make progress at the same time.
  // NOTE(review): an exception thrown by accept() inside this thread, or by connect() below
  // before join(), would end in std::terminate — confirm whether that is acceptable.
  auto* listenSocket = (getImpl(localEndpoint)->socket_).get();
  std::thread acceptThread([this, listenSocket]() {
    recvSocket_ = std::make_unique<Socket>(nullptr, MSCCLPP_SOCKET_MAGIC, SocketTypeUnknown, abortFlag_);
    recvSocket_->accept(listenSocket);
  });

  // Open the outgoing connection towards the remote endpoint's address.
  sendSocket_ = std::make_unique<Socket>(&(getImpl(remoteEndpoint)->socketAddress_), MSCCLPP_SOCKET_MAGIC,
                                         SocketTypeBootstrap, abortFlag_);
  sendSocket_->connect();

  // Wait for the accept side to finish before using recvSocket_.
  acceptThread.join();

  // Start the background receiver.
  threadRecvMessages_ = std::thread(&EthernetConnection::recvMessages, this);

  INFO(MSCCLPP_NET, "Ethernet connection created");
}
|
||||
|
||||
/// Close both sockets, then wait for the background receiver thread to finish.
EthernetConnection::~EthernetConnection() {
  // Outgoing socket first, then the incoming one.
  sendSocket_->close();
  recvSocket_->close();

  // recvMessages() loops while recvSocket_ is not in SocketStateClosed, so it can only be
  // joined after the close above.
  threadRecvMessages_.join();
}
|
||||
|
||||
/// @return Transport::Ethernet, the transport used on the local side.
Transport EthernetConnection::transport() {
  return Transport::Ethernet;
}
|
||||
|
||||
/// @return Transport::Ethernet, the transport used on the remote side.
Transport EthernetConnection::remoteTransport() {
  return Transport::Ethernet;
}
|
||||
|
||||
/// Copy `size` bytes from `src` to the remote memory `dst` over the send socket.
///
/// Wire format: a header made of the destination address (as held by the remote side) followed
/// by the total payload size, then the payload itself. The payload is staged through
/// `sendBuffer_` in chunks; only the first chunk is prefixed with the header.
///
/// @param dst Remote registered memory; must support the remote transport.
/// @param dstOffset Byte offset into `dst`.
/// @param src Local registered memory; must support the local transport.
/// @param srcOffset Byte offset into `src`.
/// @param size Number of bytes to transfer. Nothing is sent when it is 0.
void EthernetConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
                               uint64_t size) {
  // Validating Transport Protocol
  validateTransport(dst, remoteTransport());
  validateTransport(src, transport());

  // Initializing Variables
  char* srcPtr = reinterpret_cast<char*>(src.data()) + srcOffset;
  char* dstPtr = reinterpret_cast<char*>(dst.originalDataPtr()) + dstOffset;
  uint64_t sentDataSize = 0;  // payload bytes handed to the socket so far (header excluded)
  uint64_t headerSize = 0;

  // Copying Meta Data to Send Buffer: destination address, then total payload size.
  // NOTE(review): assumes sendBufferSize_ is at least sizeof(dstPtr) + sizeof(size) — confirm.
  char* dstPtrBytes = reinterpret_cast<char*>(&dstPtr);
  std::copy(dstPtrBytes, dstPtrBytes + sizeof(dstPtr), sendBuffer_.data() + headerSize);
  headerSize += sizeof(dstPtr);
  char* sizeBytes = reinterpret_cast<char*>(&size);
  std::copy(sizeBytes, sizeBytes + sizeof(size), sendBuffer_.data() + headerSize);
  headerSize += sizeof(size);

  // Getting Data From GPU and Sending Message
  while (sentDataSize < size) {
    // The first chunk shares the buffer with the header; later chunks may use all of it.
    uint64_t dataSize = std::min(sendBufferSize_ - headerSize, size - sentDataSize);
    uint64_t messageSize = dataSize + headerSize;
    mscclpp::memcpyCuda<char>(sendBuffer_.data() + headerSize, srcPtr + sentDataSize, dataSize,
                              cudaMemcpyDeviceToHost);
    sendSocket_->send(sendBuffer_.data(), messageSize);
    // Advance by the payload only. Counting the header here would skip source bytes and end
    // the transfer short of the size announced in the header.
    sentDataSize += dataSize;
    headerSize = 0;
  }

  INFO(MSCCLPP_NET, "EthernetConnection write: from %p to %p, size %lu", srcPtr, dstPtr, size);
}
|
||||
|
||||
/// Store `newValue` into `*src` and send that 64-bit value to the remote memory `dst`.
///
/// The message uses the same layout as write(): destination address, payload size
/// (here `sizeof(uint64_t)`), then the payload.
///
/// @param dst Remote registered memory; must support the remote transport.
/// @param dstOffset Byte offset into `dst` of the 64-bit slot to update.
/// @param src Local 64-bit value; overwritten with `newValue` before it is sent.
/// @param newValue Value to store locally and transmit.
void EthernetConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) {
  // Validating Transport Protocol
  validateTransport(dst, remoteTransport());

  // Initializing Variables
  uint64_t oldValue = *src;
  uint64_t* dstPtr = reinterpret_cast<uint64_t*>(reinterpret_cast<char*>(dst.originalDataPtr()) + dstOffset);
  uint64_t dataSize = sizeof(uint64_t);
  uint64_t messageSize = 0;
  *src = newValue;

  // Copying Data to Send Buffer: destination address, payload size, payload.
  char* dstPtrBytes = reinterpret_cast<char*>(&dstPtr);
  std::copy(dstPtrBytes, dstPtrBytes + sizeof(dstPtr), sendBuffer_.data() + messageSize);
  messageSize += sizeof(dstPtr);
  char* sizeBytes = reinterpret_cast<char*>(&dataSize);
  std::copy(sizeBytes, sizeBytes + sizeof(dataSize), sendBuffer_.data() + messageSize);
  messageSize += sizeof(dataSize);
  char* dataBytes = reinterpret_cast<char*>(src);
  std::copy(dataBytes, dataBytes + dataSize, sendBuffer_.data() + messageSize);
  messageSize += dataSize;

  // Sending Message
  sendSocket_->send(sendBuffer_.data(), messageSize);

  // dstPtr already includes dstOffset, so it is logged as-is.
  INFO(MSCCLPP_NET, "EthernetConnection atomic write: from %p to %p, %lu -> %lu", src, dstPtr, oldValue,
       newValue);
}
|
||||
|
||||
/// Log that a flush was requested; no other work is done here.
///
/// @param timeoutUsec Unused by this implementation.
void EthernetConnection::flush(int64_t timeoutUsec) {
  INFO(MSCCLPP_NET, "EthernetConnection flushing connection");
}
|
||||
|
||||
void EthernetConnection::recvMessages() {
|
||||
// Declarating Variables
|
||||
char* ptr;
|
||||
uint64_t size;
|
||||
uint64_t recvSize;
|
||||
int closed = 0;
|
||||
bool received = true;
|
||||
|
||||
// Receiving Messages Until Connection is Closed
|
||||
while (recvSocket_->getState() != SocketStateClosed) {
|
||||
// Receiving Data Address
|
||||
if (closed == 0) recvSocket_->recvUntilEnd(&ptr, sizeof(char*), &closed);
|
||||
received &= !closed;
|
||||
|
||||
// Receiving data size
|
||||
if (closed == 0) recvSocket_->recvUntilEnd(&size, sizeof(uint64_t), &closed);
|
||||
received &= !closed;
|
||||
|
||||
// Receiving Data and Copying Data yo GPU
|
||||
recvSize = 0;
|
||||
while (recvSize < size && closed == 0) {
|
||||
uint64_t messageSize = std::min(recvBufferSize_, (size - recvSize) / sizeof(char)) * sizeof(char);
|
||||
recvSocket_->recvUntilEnd(recvBuffer_.data(), messageSize, &closed);
|
||||
received &= !closed;
|
||||
|
||||
if (received)
|
||||
mscclpp::memcpyCuda<char>((char*)ptr + (recvSize / sizeof(char)), recvBuffer_.data(), messageSize,
|
||||
cudaMemcpyHostToDevice);
|
||||
recvSize += messageSize;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -49,9 +49,15 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Context::connect(Endpoint localEndpo
|
||||
throw mscclpp::Error("Local transport is IB but remote is not", ErrorCode::InvalidUsage);
|
||||
}
|
||||
conn = std::make_shared<IBConnection>(localEndpoint, remoteEndpoint, *this);
|
||||
} else if (localEndpoint.transport() == Transport::Ethernet) {
|
||||
if (remoteEndpoint.transport() != Transport::Ethernet) {
|
||||
throw mscclpp::Error("Local transport is Ethernet but remote is not", ErrorCode::InvalidUsage);
|
||||
}
|
||||
conn = std::make_shared<EthernetConnection>(localEndpoint, remoteEndpoint);
|
||||
} else {
|
||||
throw mscclpp::Error("Unsupported transport", ErrorCode::InternalError);
|
||||
}
|
||||
|
||||
pimpl_->connections_.push_back(conn);
|
||||
return conn;
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ const TransportFlags NoTransports = TransportFlags();
|
||||
const TransportFlags AllIBTransports = Transport::IB0 | Transport::IB1 | Transport::IB2 | Transport::IB3 |
|
||||
Transport::IB4 | Transport::IB5 | Transport::IB6 | Transport::IB7;
|
||||
|
||||
const TransportFlags AllTransports = AllIBTransports | Transport::CudaIpc;
|
||||
const TransportFlags AllTransports = AllIBTransports | Transport::CudaIpc | Transport::Ethernet;
|
||||
|
||||
/// Default implementation: does nothing with the given bootstrap.
void Setuppable::beginSetup(std::shared_ptr<Bootstrap>) {
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "api.h"
|
||||
#include "context.hpp"
|
||||
#include "socket.h"
|
||||
#include "utils_internal.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
@@ -15,6 +16,16 @@ Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl)
|
||||
ibQp_ = contextImpl.getIbContext(transport_)
|
||||
->createQp(config.ibMaxCqSize, config.ibMaxCqPollNum, config.ibMaxSendWr, 0, config.ibMaxWrPerSend);
|
||||
ibQpInfo_ = ibQp_->getInfo();
|
||||
} else if (transport_ == Transport::Ethernet) {
|
||||
// Configuring Ethernet Interfaces
|
||||
abortFlag_ = 0;
|
||||
int ret = FindInterfaces(netIfName_, &socketAddress_, MAX_IF_NAME_SIZE, 1, "");
|
||||
if (ret <= 0) throw Error("NET/Socket", ErrorCode::InternalError);
|
||||
|
||||
// Starting Server Socket
|
||||
socket_ = std::make_unique<Socket>(&socketAddress_, MSCCLPP_SOCKET_MAGIC, SocketTypeBootstrap, abortFlag_);
|
||||
socket_->bindAndListen();
|
||||
socketAddress_ = socket_->getAddr();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,6 +38,10 @@ MSCCLPP_API_CPP std::vector<char> Endpoint::serialize() {
|
||||
if (AllIBTransports.has(pimpl_->transport_)) {
|
||||
std::copy_n(reinterpret_cast<char*>(&pimpl_->ibQpInfo_), sizeof(pimpl_->ibQpInfo_), std::back_inserter(data));
|
||||
}
|
||||
if ((pimpl_->transport_) == Transport::Ethernet) {
|
||||
std::copy_n(reinterpret_cast<char*>(&pimpl_->socketAddress_), sizeof(pimpl_->socketAddress_),
|
||||
std::back_inserter(data));
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
@@ -45,6 +60,10 @@ Endpoint::Impl::Impl(const std::vector<char>& serialization) {
|
||||
std::copy_n(it, sizeof(ibQpInfo_), reinterpret_cast<char*>(&ibQpInfo_));
|
||||
it += sizeof(ibQpInfo_);
|
||||
}
|
||||
if (transport_ == Transport::Ethernet) {
|
||||
std::copy_n(it, sizeof(socketAddress_), reinterpret_cast<char*>(&socketAddress_));
|
||||
it += sizeof(socketAddress_);
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct an Endpoint that shares ownership of the given implementation object.
MSCCLPP_API_CPP Endpoint::Endpoint(std::shared_ptr<mscclpp::Endpoint::Impl> pimpl)
    : pimpl_(pimpl) {
}
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "context.hpp"
|
||||
#include "ib.hpp"
|
||||
#include "registered_memory.hpp"
|
||||
#include "socket.h"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
@@ -53,6 +54,38 @@ class IBConnection : public Connection {
|
||||
void flush(int64_t timeoutUsec) override;
|
||||
};
|
||||
|
||||
class EthernetConnection : public Connection {
|
||||
std::unique_ptr<Socket> sendSocket_;
|
||||
std::unique_ptr<Socket> recvSocket_;
|
||||
std::thread threadRecvMessages_;
|
||||
volatile uint32_t* abortFlag_;
|
||||
const uint64_t sendBufferSize_;
|
||||
const uint64_t recvBufferSize_;
|
||||
std::vector<char> sendBuffer_;
|
||||
std::vector<char> recvBuffer_;
|
||||
|
||||
public:
|
||||
EthernetConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, uint64_t sendBufferSize = 256 * 1024 * 1024,
|
||||
uint64_t recvBufferSize = 256 * 1024 * 1024);
|
||||
|
||||
~EthernetConnection();
|
||||
|
||||
Transport transport() override;
|
||||
|
||||
Transport remoteTransport() override;
|
||||
|
||||
void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
|
||||
uint64_t size) override;
|
||||
void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) override;
|
||||
|
||||
void flush(int64_t timeoutUsec) override;
|
||||
|
||||
private:
|
||||
void recvMessages();
|
||||
|
||||
void sendMessage();
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_CONNECTION_HPP_
|
||||
|
||||
@@ -8,6 +8,9 @@
|
||||
#include <vector>
|
||||
|
||||
#include "ib.hpp"
|
||||
#include "socket.h"
|
||||
|
||||
#define MAX_IF_NAME_SIZE 16
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
@@ -22,6 +25,12 @@ struct Endpoint::Impl {
|
||||
bool ibLocal_;
|
||||
IbQp* ibQp_;
|
||||
IbQpInfo ibQpInfo_;
|
||||
|
||||
// The following are only used for Ethernet and are undefined for other transports.
|
||||
std::unique_ptr<Socket> socket_;
|
||||
SocketAddress socketAddress_;
|
||||
volatile uint32_t* abortFlag_;
|
||||
char netIfName_[MAX_IF_NAME_SIZE + 1];
|
||||
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -69,6 +69,7 @@ class Socket {
|
||||
void accept(const Socket* listenSocket, int64_t timeout = -1);
|
||||
void send(void* ptr, int size);
|
||||
void recv(void* ptr, int size);
|
||||
void recvUntilEnd(void* ptr, int size, int* closed);
|
||||
void close();
|
||||
|
||||
/// @return The value of `fd_`, this socket's file descriptor.
int getFd() const {
  return fd_;
}
|
||||
|
||||
@@ -62,6 +62,8 @@ MSCCLPP_API_CPP RegisteredMemory::~RegisteredMemory() = default;
|
||||
|
||||
/// @return The `data` pointer held by the implementation object.
MSCCLPP_API_CPP void* RegisteredMemory::data() const {
  return pimpl_->data;
}
|
||||
|
||||
/// @return The `originalDataPtr` pointer held by the implementation object.
MSCCLPP_API_CPP void* RegisteredMemory::originalDataPtr() const {
  return pimpl_->originalDataPtr;
}
|
||||
|
||||
/// @return The `size` recorded in the implementation object.
MSCCLPP_API_CPP size_t RegisteredMemory::size() {
  return pimpl_->size;
}
|
||||
|
||||
/// @return The `transports` flags recorded in the implementation object.
MSCCLPP_API_CPP TransportFlags RegisteredMemory::transports() {
  return pimpl_->transports;
}
|
||||
|
||||
@@ -42,14 +42,16 @@ void CommunicatorTestBase::TearDown() {
|
||||
|
||||
/// Record how many ranks the test should use.
///
/// @param num New value for `numRanksToUse`.
void CommunicatorTestBase::setNumRanksToUse(int num) {
  numRanksToUse = num;
}
|
||||
|
||||
void CommunicatorTestBase::connectMesh(bool useIbOnly) {
|
||||
void CommunicatorTestBase::connectMesh(bool useIpc, bool useIb, bool useEthernet) {
|
||||
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connectionFutures(numRanksToUse);
|
||||
for (int i = 0; i < numRanksToUse; i++) {
|
||||
if (i != gEnv->rank) {
|
||||
if ((rankToNode(i) == rankToNode(gEnv->rank)) && !useIbOnly) {
|
||||
if ((rankToNode(i) == rankToNode(gEnv->rank)) && useIpc) {
|
||||
connectionFutures[i] = communicator->connectOnSetup(i, 0, mscclpp::Transport::CudaIpc);
|
||||
} else {
|
||||
} else if (useIb) {
|
||||
connectionFutures[i] = communicator->connectOnSetup(i, 0, ibTransport);
|
||||
} else if (useEthernet) {
|
||||
connectionFutures[i] = communicator->connectOnSetup(i, 0, mscclpp::Transport::Ethernet);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -97,7 +99,7 @@ void CommunicatorTest::SetUp() {
|
||||
|
||||
ASSERT_EQ((deviceBufferSize / sizeof(int)) % gEnv->worldSize, 0);
|
||||
|
||||
connectMesh();
|
||||
connectMesh(true, true, false);
|
||||
|
||||
devicePtr.resize(numBuffers);
|
||||
localMemory.resize(numBuffers);
|
||||
@@ -281,4 +283,4 @@ TEST_F(CommunicatorTest, WriteWithHostSemaphores) {
|
||||
|
||||
ASSERT_TRUE(testWriteCorrectness());
|
||||
communicator->bootstrap()->barrier();
|
||||
}
|
||||
}
|
||||
@@ -93,7 +93,7 @@ class CommunicatorTestBase : public MultiProcessTest {
|
||||
void TearDown() override;
|
||||
|
||||
void setNumRanksToUse(int num);
|
||||
void connectMesh(bool useIbOnly = false);
|
||||
void connectMesh(bool useIpc = true, bool useIb = true, bool useEthernet = false);
|
||||
|
||||
// Register a local memory and receive corresponding remote memories
|
||||
void registerMemoryPairs(void* buff, size_t buffSize, mscclpp::TransportFlags transport, int tag,
|
||||
@@ -130,13 +130,21 @@ using DeviceHandle = mscclpp::DeviceHandle<T>;
|
||||
|
||||
class ProxyChannelOneToOneTest : public CommunicatorTestBase {
|
||||
protected:
|
||||
/// Switches that select the transports and waiting mode for the ping-pong tests.
struct PingPongTestParams {
  bool useIPC;        // forwarded to setupMeshConnections() as `useIPC`
  bool useIB;         // forwarded to setupMeshConnections() as `useIb`
  bool useEthernet;   // forwarded to setupMeshConnections() as `useEthernet`
  bool waitWithPoll;  // forwarded to the ping-pong kernel launch
};
|
||||
|
||||
void SetUp() override;
|
||||
void TearDown() override;
|
||||
|
||||
void setupMeshConnections(std::vector<mscclpp::SimpleProxyChannel>& proxyChannels, bool useIbOnly, void* sendBuff,
|
||||
size_t sendBuffBytes, void* recvBuff = nullptr, size_t recvBuffBytes = 0);
|
||||
void testPingPong(bool useIbOnly, bool waitWithPoll);
|
||||
void testPingPongPerf(bool useIbOnly, bool waitWithPoll);
|
||||
void setupMeshConnections(std::vector<mscclpp::SimpleProxyChannel>& proxyChannels, bool useIPC, bool useIb,
|
||||
bool useEthernet, void* sendBuff, size_t sendBuffBytes, void* recvBuff = nullptr,
|
||||
size_t recvBuffBytes = 0);
|
||||
void testPingPong(PingPongTestParams params);
|
||||
void testPingPongPerf(PingPongTestParams params);
|
||||
void testPacketPingPong(bool useIbOnly);
|
||||
void testPacketPingPongPerf(bool useIbOnly);
|
||||
|
||||
|
||||
@@ -16,12 +16,16 @@ void ProxyChannelOneToOneTest::SetUp() {
|
||||
/// Delegates to the base fixture's TearDown().
void ProxyChannelOneToOneTest::TearDown() {
  CommunicatorTestBase::TearDown();
}
|
||||
|
||||
void ProxyChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::SimpleProxyChannel>& proxyChannels,
|
||||
bool useIbOnly, void* sendBuff, size_t sendBuffBytes,
|
||||
void* recvBuff, size_t recvBuffBytes) {
|
||||
bool useIPC, bool useIb, bool useEthernet, void* sendBuff,
|
||||
size_t sendBuffBytes, void* recvBuff, size_t recvBuffBytes) {
|
||||
const int rank = communicator->bootstrap()->getRank();
|
||||
const int worldSize = communicator->bootstrap()->getNranks();
|
||||
const bool isInPlace = (recvBuff == nullptr);
|
||||
mscclpp::TransportFlags transport = (useIbOnly) ? ibTransport : (mscclpp::Transport::CudaIpc | ibTransport);
|
||||
mscclpp::TransportFlags transport;
|
||||
|
||||
if (useIPC) transport |= mscclpp::Transport::CudaIpc;
|
||||
if (useIb) transport |= ibTransport;
|
||||
if (useEthernet) transport |= mscclpp::Transport::Ethernet;
|
||||
|
||||
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connectionFutures(worldSize);
|
||||
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteMemFutures(worldSize);
|
||||
@@ -36,10 +40,12 @@ void ProxyChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::SimpleP
|
||||
if (r == rank) {
|
||||
continue;
|
||||
}
|
||||
if ((rankToNode(r) == rankToNode(gEnv->rank)) && !useIbOnly) {
|
||||
if ((rankToNode(r) == rankToNode(gEnv->rank)) && useIPC) {
|
||||
connectionFutures[r] = communicator->connectOnSetup(r, 0, mscclpp::Transport::CudaIpc);
|
||||
} else {
|
||||
} else if (useIb) {
|
||||
connectionFutures[r] = communicator->connectOnSetup(r, 0, ibTransport);
|
||||
} else if (useEthernet) {
|
||||
connectionFutures[r] = communicator->connectOnSetup(r, 0, mscclpp::Transport::Ethernet);
|
||||
}
|
||||
|
||||
if (isInPlace) {
|
||||
@@ -145,14 +151,14 @@ __global__ void kernelProxyPingPong(int* buff, int rank, int nElem, bool waitWit
|
||||
}
|
||||
}
|
||||
|
||||
void ProxyChannelOneToOneTest::testPingPong(bool useIbOnly, bool waitWithPoll) {
|
||||
void ProxyChannelOneToOneTest::testPingPong(PingPongTestParams params) {
|
||||
if (gEnv->rank >= numRanksToUse) return;
|
||||
|
||||
const int nElem = 4 * 1024 * 1024;
|
||||
|
||||
std::vector<mscclpp::SimpleProxyChannel> proxyChannels;
|
||||
std::shared_ptr<int> buff = mscclpp::allocExtSharedCuda<int>(nElem);
|
||||
setupMeshConnections(proxyChannels, useIbOnly, buff.get(), nElem * sizeof(int));
|
||||
setupMeshConnections(proxyChannels, params.useIPC, params.useIB, params.useEthernet, buff.get(), nElem * sizeof(int));
|
||||
|
||||
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannelHandles;
|
||||
for (auto& ch : proxyChannels) proxyChannelHandles.push_back(ch.deviceHandle());
|
||||
@@ -167,22 +173,22 @@ void ProxyChannelOneToOneTest::testPingPong(bool useIbOnly, bool waitWithPoll) {
|
||||
|
||||
const int nTries = 1000;
|
||||
|
||||
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, waitWithPoll, nTries, ret.get());
|
||||
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, params.waitWithPoll, nTries, ret.get());
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
|
||||
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024, waitWithPoll, nTries, ret.get());
|
||||
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024, params.waitWithPoll, nTries, ret.get());
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
|
||||
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024 * 1024, waitWithPoll, nTries, ret.get());
|
||||
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024 * 1024, params.waitWithPoll, nTries, ret.get());
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
|
||||
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 4 * 1024 * 1024, waitWithPoll, nTries, ret.get());
|
||||
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 4 * 1024 * 1024, params.waitWithPoll, nTries, ret.get());
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
@@ -190,14 +196,14 @@ void ProxyChannelOneToOneTest::testPingPong(bool useIbOnly, bool waitWithPoll) {
|
||||
proxyService->stopProxy();
|
||||
}
|
||||
|
||||
void ProxyChannelOneToOneTest::testPingPongPerf(bool useIbOnly, bool waitWithPoll) {
|
||||
void ProxyChannelOneToOneTest::testPingPongPerf(PingPongTestParams params) {
|
||||
if (gEnv->rank >= numRanksToUse) return;
|
||||
|
||||
const int nElem = 4 * 1024 * 1024;
|
||||
|
||||
std::vector<mscclpp::SimpleProxyChannel> proxyChannels;
|
||||
std::shared_ptr<int> buff = mscclpp::allocExtSharedCuda<int>(nElem);
|
||||
setupMeshConnections(proxyChannels, useIbOnly, buff.get(), nElem * sizeof(int));
|
||||
setupMeshConnections(proxyChannels, params.useIPC, params.useIB, params.useEthernet, buff.get(), nElem * sizeof(int));
|
||||
|
||||
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannelHandles;
|
||||
for (auto& ch : proxyChannels) proxyChannelHandles.push_back(ch.deviceHandle());
|
||||
@@ -212,17 +218,17 @@ void ProxyChannelOneToOneTest::testPingPongPerf(bool useIbOnly, bool waitWithPol
|
||||
|
||||
auto* testInfo = ::testing::UnitTest::GetInstance()->current_test_info();
|
||||
const std::string testName = std::string(testInfo->test_suite_name()) + "." + std::string(testInfo->name());
|
||||
const int nTries = 1000000;
|
||||
const int nTries = 1000;
|
||||
|
||||
// Warm-up
|
||||
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, waitWithPoll, nTries, ret.get());
|
||||
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, params.waitWithPoll, nTries, ret.get());
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
communicator->bootstrap()->barrier();
|
||||
|
||||
// Measure latency
|
||||
mscclpp::Timer timer;
|
||||
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, waitWithPoll, nTries, ret.get());
|
||||
kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, params.waitWithPoll, nTries, ret.get());
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
communicator->bootstrap()->barrier();
|
||||
@@ -234,17 +240,37 @@ void ProxyChannelOneToOneTest::testPingPongPerf(bool useIbOnly, bool waitWithPol
|
||||
proxyService->stopProxy();
|
||||
}
|
||||
|
||||
TEST_F(ProxyChannelOneToOneTest, PingPong) { testPingPong(false, false); }
|
||||
// Ping-pong with CUDA IPC and IB enabled, Ethernet disabled, waitWithPoll off.
TEST_F(ProxyChannelOneToOneTest, PingPong) {
  PingPongTestParams params{};
  params.useIPC = true;
  params.useIB = true;
  params.useEthernet = false;
  params.waitWithPoll = false;
  testPingPong(params);
}
|
||||
|
||||
TEST_F(ProxyChannelOneToOneTest, PingPongIb) { testPingPong(true, false); }
|
||||
// Ping-pong with only IB enabled, waitWithPoll off.
TEST_F(ProxyChannelOneToOneTest, PingPongIb) {
  PingPongTestParams params{};
  params.useIPC = false;
  params.useIB = true;
  params.useEthernet = false;
  params.waitWithPoll = false;
  testPingPong(params);
}
|
||||
|
||||
TEST_F(ProxyChannelOneToOneTest, PingPongWithPoll) { testPingPong(false, true); }
|
||||
// Ping-pong with only Ethernet enabled, waitWithPoll off.
TEST_F(ProxyChannelOneToOneTest, PingPongEthernet) {
  PingPongTestParams params{};
  params.useIPC = false;
  params.useIB = false;
  params.useEthernet = true;
  params.waitWithPoll = false;
  testPingPong(params);
}
|
||||
|
||||
TEST_F(ProxyChannelOneToOneTest, PingPongIbWithPoll) { testPingPong(true, true); }
|
||||
// Ping-pong with CUDA IPC and IB enabled, Ethernet disabled, waitWithPoll on.
TEST_F(ProxyChannelOneToOneTest, PingPongWithPoll) {
  PingPongTestParams params{};
  params.useIPC = true;
  params.useIB = true;
  params.useEthernet = false;
  params.waitWithPoll = true;
  testPingPong(params);
}
|
||||
|
||||
TEST_F(ProxyChannelOneToOneTest, PingPongPerf) { testPingPongPerf(false, false); }
|
||||
// Ping-pong with only IB enabled, waitWithPoll on.
TEST_F(ProxyChannelOneToOneTest, PingPongIbWithPoll) {
  PingPongTestParams params{};
  params.useIPC = false;
  params.useIB = true;
  params.useEthernet = false;
  params.waitWithPoll = true;
  testPingPong(params);
}
|
||||
|
||||
TEST_F(ProxyChannelOneToOneTest, PingPongPerfIb) { testPingPongPerf(true, false); }
|
||||
// Ping-pong latency measurement with CUDA IPC and IB enabled, Ethernet disabled, waitWithPoll off.
TEST_F(ProxyChannelOneToOneTest, PingPongPerf) {
  PingPongTestParams params{};
  params.useIPC = true;
  params.useIB = true;
  params.useEthernet = false;
  params.waitWithPoll = false;
  testPingPongPerf(params);
}
|
||||
|
||||
// Ping-pong latency measurement with only IB enabled, waitWithPoll off.
TEST_F(ProxyChannelOneToOneTest, PingPongPerfIb) {
  PingPongTestParams params{};
  params.useIPC = false;
  params.useIB = true;
  params.useEthernet = false;
  params.waitWithPoll = false;
  testPingPongPerf(params);
}
|
||||
|
||||
// Ping-pong latency measurement with only Ethernet enabled, waitWithPoll off.
TEST_F(ProxyChannelOneToOneTest, PingPongPerfEthernet) {
  PingPongTestParams params{};
  params.useIPC = false;
  params.useIB = false;
  params.useEthernet = true;
  params.waitWithPoll = false;
  testPingPongPerf(params);
}
|
||||
|
||||
__device__ mscclpp::DeviceSyncer gChannelOneToOneTestProxyChansSyncer;
|
||||
|
||||
@@ -324,8 +350,8 @@ void ProxyChannelOneToOneTest::testPacketPingPong(bool useIbOnly) {
|
||||
auto putPacketBuffer = mscclpp::allocExtSharedCuda<mscclpp::LLPacket>(nPacket);
|
||||
auto getPacketBuffer = mscclpp::allocExtSharedCuda<mscclpp::LLPacket>(nPacket);
|
||||
|
||||
setupMeshConnections(proxyChannels, useIbOnly, putPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket),
|
||||
getPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket));
|
||||
setupMeshConnections(proxyChannels, !useIbOnly, true, false, putPacketBuffer.get(),
|
||||
nPacket * sizeof(mscclpp::LLPacket), getPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket));
|
||||
|
||||
ASSERT_EQ(proxyChannels.size(), 1);
|
||||
|
||||
@@ -391,8 +417,8 @@ void ProxyChannelOneToOneTest::testPacketPingPongPerf(bool useIbOnly) {
|
||||
auto putPacketBuffer = mscclpp::allocExtSharedCuda<mscclpp::LLPacket>(nPacket);
|
||||
auto getPacketBuffer = mscclpp::allocExtSharedCuda<mscclpp::LLPacket>(nPacket);
|
||||
|
||||
setupMeshConnections(proxyChannels, useIbOnly, putPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket),
|
||||
getPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket));
|
||||
setupMeshConnections(proxyChannels, !useIbOnly, true, false, putPacketBuffer.get(),
|
||||
nPacket * sizeof(mscclpp::LLPacket), getPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket));
|
||||
|
||||
ASSERT_EQ(proxyChannels.size(), 1);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user