mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 17:26:04 +00:00
Add connection write test
This commit is contained in:
@@ -21,6 +21,7 @@ Communicator::Impl::Impl(std::shared_ptr<BaseBootstrap> bootstrap) : bootstrap_(
|
||||
INFO(MSCCLPP_INIT, "Host hash: %lx", hostHash);
|
||||
rankToHash_[bootstrap->getRank()] = hostHash;
|
||||
bootstrap->allGather(rankToHash_.data(), sizeof(uint64_t));
|
||||
comm->rank = bootstrap->getRank();
|
||||
}
|
||||
|
||||
Communicator::Impl::~Impl()
|
||||
|
||||
@@ -13,8 +13,8 @@ namespace mscclpp {
|
||||
class ConnectionBase : public Connection
|
||||
{
|
||||
public:
|
||||
virtual void startSetup(std::shared_ptr<BaseBootstrap> bootstrap){};
|
||||
virtual void endSetup(std::shared_ptr<BaseBootstrap> bootstrap){};
|
||||
virtual void startSetup(std::shared_ptr<BaseBootstrap>){};
|
||||
virtual void endSetup(std::shared_ptr<BaseBootstrap>){};
|
||||
};
|
||||
|
||||
class CudaIpcConnection : public ConnectionBase
|
||||
|
||||
@@ -35,14 +35,17 @@ void test_communicator(int rank, int worldSize, int nranksPerNode)
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "Communicator initialization passed" << std::endl;
|
||||
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
|
||||
auto myIbDevice = findIb(rank % nranksPerNode);
|
||||
for (int i = 0; i < worldSize; i++) {
|
||||
if (i != rank) {
|
||||
std::shared_ptr<mscclpp::Connection> conn;
|
||||
if (i / nranksPerNode == rank / nranksPerNode) {
|
||||
auto connect = communicator->connect(i, 0, mscclpp::Transport::CudaIpc);
|
||||
conn = communicator->connect(i, 0, mscclpp::Transport::CudaIpc);
|
||||
} else {
|
||||
auto connect = communicator->connect(i, 0, myIbDevice);
|
||||
conn = communicator->connect(i, 0, myIbDevice);
|
||||
}
|
||||
connections.push_back(conn);
|
||||
}
|
||||
}
|
||||
communicator->connectionSetup();
|
||||
@@ -63,20 +66,52 @@ void test_communicator(int rank, int worldSize, int nranksPerNode)
|
||||
bootstrap->send(serialized.data(), serializedSize, i, 1);
|
||||
}
|
||||
}
|
||||
std::vector<mscclpp::RegisteredMemory> registeredMemories;
|
||||
for (int i = 0; i < worldSize; i++) {
|
||||
if (i != rank){
|
||||
int deserializedSize;
|
||||
bootstrap->recv(&deserializedSize, sizeof(int), i, 0);
|
||||
std::vector<char> deserialized(deserializedSize);
|
||||
bootstrap->recv(deserialized.data(), deserializedSize, i, 1);
|
||||
// auto deserializedRegisteredMemory = mscclpp::RegisteredMemory::deserialize(deserialized);
|
||||
auto deserializedRegisteredMemory = mscclpp::RegisteredMemory::deserialize(deserialized);
|
||||
registeredMemories.push_back(std::move(deserializedRegisteredMemory));
|
||||
}
|
||||
}
|
||||
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "Memory registration passed" << std::endl;
|
||||
|
||||
assert(size % worldSize == 0);
|
||||
size_t writeSize = size / worldSize;
|
||||
size_t dataCount = size / sizeof(int);
|
||||
// std::vector<int> hostBuffer(dataCount, 0);
|
||||
std::shared_ptr<int[]> hostBuffer(new int[dataCount]);
|
||||
for (int i = 0; i < dataCount; i++) {
|
||||
hostBuffer[i] = rank;
|
||||
}
|
||||
CUDATHROW(cudaMemcpy(devicePtr, hostBuffer.get(), size, cudaMemcpyHostToDevice));
|
||||
|
||||
for (int i = 0; i < worldSize; i++) {
|
||||
if (i != rank) {
|
||||
int peerRankIndex = i < rank ? i : i - 1;
|
||||
auto conn = connections[peerRankIndex];
|
||||
conn->write(registeredMemories[peerRankIndex], rank * writeSize, registeredMemory, rank * writeSize, writeSize);
|
||||
}
|
||||
}
|
||||
CUDATHROW(cudaDeviceSynchronize());
|
||||
MPI_Barrier(MPI_COMM_WORLD);
|
||||
CUDATHROW(cudaMemcpy(hostBuffer.get(), devicePtr, size, cudaMemcpyDeviceToHost));
|
||||
size_t dataPerRank = writeSize / sizeof(int);
|
||||
for (int i = 0; i < dataCount; i++) {
|
||||
if (hostBuffer[i] != i / dataPerRank) {
|
||||
throw std::runtime_error("Data mismatch, connection write failed");
|
||||
}
|
||||
}
|
||||
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "Memory registeration passed" << std::endl;
|
||||
std::cout << "Connection write passed" << std::endl;
|
||||
|
||||
CUDATHROW(cudaFree(devicePtr));
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "--- MSCCLPP::Communicator tests passed! ---" << std::endl;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user