Mirror of https://github.com/microsoft/mscclpp.git (synced 2026-04-20 06:49:29 +00:00).
Commit: "Test memory_channel_tests with raw cuMem API".
This commit is contained in:
@@ -2,9 +2,78 @@
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cuda.h>
|
||||
|
||||
#include "mp_unit_tests.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
// Helper struct to manage raw cuMem allocated memory
|
||||
struct CuMemBuffer {
|
||||
void* ptr = nullptr;
|
||||
size_t size = 0;
|
||||
CUmemGenericAllocationHandle handle;
|
||||
|
||||
CuMemBuffer() = default;
|
||||
CuMemBuffer(const CuMemBuffer&) = delete;
|
||||
CuMemBuffer& operator=(const CuMemBuffer&) = delete;
|
||||
|
||||
~CuMemBuffer() {
|
||||
if (ptr) {
|
||||
cuMemUnmap((CUdeviceptr)ptr, size);
|
||||
cuMemRelease(handle);
|
||||
cuMemAddressFree((CUdeviceptr)ptr, size);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Allocate GPU memory using raw cuMem API (cuMemCreate, cuMemAddressReserve, cuMemMap)
|
||||
std::shared_ptr<CuMemBuffer> allocateCuMemBuffer(size_t bytes) {
|
||||
auto buffer = std::make_shared<CuMemBuffer>();
|
||||
|
||||
int deviceId = -1;
|
||||
CUdevice currentDevice;
|
||||
MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId));
|
||||
MSCCLPP_CUTHROW(cuDeviceGet(¤tDevice, deviceId));
|
||||
|
||||
// Get allocation granularity
|
||||
CUmemAllocationProp prop = {};
|
||||
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
|
||||
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
||||
prop.location.id = currentDevice;
|
||||
prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
|
||||
|
||||
size_t gran = 0;
|
||||
MSCCLPP_CUTHROW(cuMemGetAllocationGranularity(&gran, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
|
||||
|
||||
// Round up to granularity
|
||||
size_t alignedBytes = (bytes + gran - 1) / gran * gran;
|
||||
buffer->size = alignedBytes;
|
||||
|
||||
// Create physical memory
|
||||
MSCCLPP_CUTHROW(cuMemCreate(&buffer->handle, alignedBytes, &prop, 0));
|
||||
|
||||
// Reserve virtual address space
|
||||
MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)&buffer->ptr, alignedBytes, gran, 0U, 0));
|
||||
|
||||
// Map the physical memory to the virtual address
|
||||
MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)buffer->ptr, alignedBytes, 0, buffer->handle, 0));
|
||||
|
||||
// Set memory access
|
||||
CUmemAccessDesc accessDesc = {};
|
||||
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
||||
accessDesc.location.id = currentDevice;
|
||||
accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
|
||||
MSCCLPP_CUTHROW(cuMemSetAccess((CUdeviceptr)buffer->ptr, alignedBytes, &accessDesc, 1));
|
||||
|
||||
// Zero out the memory
|
||||
MSCCLPP_CUDATHROW(cudaMemset(buffer->ptr, 0, alignedBytes));
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void MemoryChannelOneToOneTest::SetUp() {
|
||||
// Need at least two ranks within a node
|
||||
if (gEnv->nRanksPerNode < 2) {
|
||||
@@ -74,9 +143,12 @@ void MemoryChannelOneToOneTest::packetPingPongTest(const std::string testName,
|
||||
const int defaultNTries = 1000;
|
||||
|
||||
std::vector<mscclpp::MemoryChannel> memoryChannels;
|
||||
std::shared_ptr<int> buff = mscclpp::GpuBuffer<int>(nElem).memory();
|
||||
std::shared_ptr<int> intermBuff = mscclpp::GpuBuffer<int>(nElem * 2).memory();
|
||||
setupMeshConnections(memoryChannels, buff.get(), nElem * sizeof(int), intermBuff.get(), nElem * 2 * sizeof(int));
|
||||
// Use raw cuMem API instead of mscclpp::GpuBuffer
|
||||
auto buffCuMem = allocateCuMemBuffer(nElem * sizeof(int));
|
||||
auto intermBuffCuMem = allocateCuMemBuffer(nElem * 2 * sizeof(int));
|
||||
int* buff = static_cast<int*>(buffCuMem->ptr);
|
||||
int* intermBuff = static_cast<int*>(intermBuffCuMem->ptr);
|
||||
setupMeshConnections(memoryChannels, buff, nElem * sizeof(int), intermBuff, nElem * 2 * sizeof(int));
|
||||
std::vector<DeviceHandle<mscclpp::MemoryChannel>> deviceHandles(memoryChannels.size());
|
||||
std::transform(memoryChannels.begin(), memoryChannels.end(), deviceHandles.begin(),
|
||||
[](const mscclpp::MemoryChannel& memChan) { return mscclpp::deviceHandle(memChan); });
|
||||
@@ -88,23 +160,23 @@ void MemoryChannelOneToOneTest::packetPingPongTest(const std::string testName,
|
||||
std::shared_ptr<int> ret = mscclpp::detail::gpuCallocHostShared<int>();
|
||||
|
||||
// The least nelem is 2 for packet ping pong
|
||||
kernelWrapper(buff.get(), gEnv->rank, 2, ret.get(), defaultNTries);
|
||||
kernelWrapper(buff, gEnv->rank, 2, ret.get(), defaultNTries);
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
*ret = 0;
|
||||
|
||||
kernelWrapper(buff.get(), gEnv->rank, 1024, ret.get(), defaultNTries);
|
||||
kernelWrapper(buff, gEnv->rank, 1024, ret.get(), defaultNTries);
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
*ret = 0;
|
||||
|
||||
kernelWrapper(buff.get(), gEnv->rank, 1024 * 1024, ret.get(), defaultNTries);
|
||||
kernelWrapper(buff, gEnv->rank, 1024 * 1024, ret.get(), defaultNTries);
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
*ret = 0;
|
||||
|
||||
kernelWrapper(buff.get(), gEnv->rank, 4 * 1024 * 1024, ret.get(), defaultNTries);
|
||||
kernelWrapper(buff, gEnv->rank, 4 * 1024 * 1024, ret.get(), defaultNTries);
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
@@ -113,7 +185,7 @@ void MemoryChannelOneToOneTest::packetPingPongTest(const std::string testName,
|
||||
int nTries = 1000000;
|
||||
communicator->bootstrap()->barrier();
|
||||
mscclpp::Timer timer;
|
||||
kernelWrapper(buff.get(), gEnv->rank, 1024, ret.get(), nTries);
|
||||
kernelWrapper(buff, gEnv->rank, 1024, ret.get(), nTries);
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
communicator->bootstrap()->barrier();
|
||||
|
||||
@@ -175,8 +247,10 @@ TEST_F(MemoryChannelOneToOneTest, PutPingPong) {
|
||||
const int nElem = 4 * 1024 * 1024;
|
||||
|
||||
std::vector<mscclpp::MemoryChannel> memoryChannels;
|
||||
std::shared_ptr<int> buff = mscclpp::GpuBuffer<int>(nElem).memory();
|
||||
setupMeshConnections(memoryChannels, buff.get(), nElem * sizeof(int));
|
||||
// Use raw cuMem API instead of mscclpp::GpuBuffer
|
||||
auto buffCuMem = allocateCuMemBuffer(nElem * sizeof(int));
|
||||
int* buff = static_cast<int*>(buffCuMem->ptr);
|
||||
setupMeshConnections(memoryChannels, buff, nElem * sizeof(int));
|
||||
std::vector<DeviceHandle<mscclpp::MemoryChannel>> deviceHandles(memoryChannels.size());
|
||||
std::transform(memoryChannels.begin(), memoryChannels.end(), deviceHandles.begin(),
|
||||
[](const mscclpp::MemoryChannel& memChan) { return mscclpp::deviceHandle(memChan); });
|
||||
@@ -187,25 +261,25 @@ TEST_F(MemoryChannelOneToOneTest, PutPingPong) {
|
||||
|
||||
std::shared_ptr<int> ret = mscclpp::detail::gpuCallocHostShared<int>();
|
||||
|
||||
kernelMemPutPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, ret.get());
|
||||
kernelMemPutPingPong<<<1, 1024>>>(buff, gEnv->rank, 1, ret.get());
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
*ret = 0;
|
||||
|
||||
kernelMemPutPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024, ret.get());
|
||||
kernelMemPutPingPong<<<1, 1024>>>(buff, gEnv->rank, 1024, ret.get());
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
*ret = 0;
|
||||
|
||||
kernelMemPutPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024 * 1024, ret.get());
|
||||
kernelMemPutPingPong<<<1, 1024>>>(buff, gEnv->rank, 1024 * 1024, ret.get());
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
*ret = 0;
|
||||
|
||||
kernelMemPutPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 4 * 1024 * 1024, ret.get());
|
||||
kernelMemPutPingPong<<<1, 1024>>>(buff, gEnv->rank, 4 * 1024 * 1024, ret.get());
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
@@ -254,8 +328,10 @@ TEST_F(MemoryChannelOneToOneTest, GetPingPong) {
|
||||
const int nElem = 4 * 1024 * 1024;
|
||||
|
||||
std::vector<mscclpp::MemoryChannel> memoryChannels;
|
||||
std::shared_ptr<int> buff = mscclpp::GpuBuffer<int>(nElem).memory();
|
||||
setupMeshConnections(memoryChannels, buff.get(), nElem * sizeof(int));
|
||||
// Use raw cuMem API instead of mscclpp::GpuBuffer
|
||||
auto buffCuMem = allocateCuMemBuffer(nElem * sizeof(int));
|
||||
int* buff = static_cast<int*>(buffCuMem->ptr);
|
||||
setupMeshConnections(memoryChannels, buff, nElem * sizeof(int));
|
||||
std::vector<DeviceHandle<mscclpp::MemoryChannel>> deviceHandles(memoryChannels.size());
|
||||
std::transform(memoryChannels.begin(), memoryChannels.end(), deviceHandles.begin(),
|
||||
[](const mscclpp::MemoryChannel& memChan) { return mscclpp::deviceHandle(memChan); });
|
||||
@@ -266,25 +342,25 @@ TEST_F(MemoryChannelOneToOneTest, GetPingPong) {
|
||||
|
||||
std::shared_ptr<int> ret = mscclpp::detail::gpuCallocHostShared<int>();
|
||||
|
||||
kernelMemGetPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, ret.get());
|
||||
kernelMemGetPingPong<<<1, 1024>>>(buff, gEnv->rank, 1, ret.get());
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
*ret = 0;
|
||||
|
||||
kernelMemGetPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024, ret.get());
|
||||
kernelMemGetPingPong<<<1, 1024>>>(buff, gEnv->rank, 1024, ret.get());
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
*ret = 0;
|
||||
|
||||
kernelMemGetPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024 * 1024, ret.get());
|
||||
kernelMemGetPingPong<<<1, 1024>>>(buff, gEnv->rank, 1024 * 1024, ret.get());
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
*ret = 0;
|
||||
|
||||
kernelMemGetPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 4 * 1024 * 1024, ret.get());
|
||||
kernelMemGetPingPong<<<1, 1024>>>(buff, gEnv->rank, 4 * 1024 * 1024, ret.get());
|
||||
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
EXPECT_EQ(*ret, 0);
|
||||
|
||||
Reference in New Issue
Block a user