mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-20 06:49:29 +00:00
Allow binding allocated memory to NVLS multicast pointer (#290)
And change NVLS multimem instructions to static functions
This commit is contained in:
@@ -484,6 +484,11 @@ class NvlsConnection {
|
||||
};
|
||||
|
||||
std::shared_ptr<DeviceMulticastPointer> allocateAndBindCuda(size_t size);
|
||||
|
||||
/// The \p handle to the allocation (its lifetime is managed by the caller)
|
||||
/// and the \p size of the allocation.
|
||||
std::shared_ptr<char> bindAllocatedCuda(CUmemGenericAllocationHandle memHandle, size_t size);
|
||||
|
||||
size_t getMultiCastMinGranularity();
|
||||
|
||||
private:
|
||||
|
||||
@@ -22,15 +22,15 @@ struct DeviceMulticastPointerDeviceHandle {
|
||||
|
||||
#if defined(MSCCLPP_DEVICE_CUDA)
|
||||
template <int NElemPerThread = 4, typename TVaule = float4, typename T = float>
|
||||
MSCCLPP_DEVICE_INLINE void multimemLoad(TVaule& val, T* ptr) {
|
||||
MSCCLPP_DEVICE_INLINE static void multimemLoad(TVaule& val, T* ptr) {
|
||||
static_assert(NElemPerThread == 4, "Only support NElemPerThread == 4");
|
||||
if constexpr (std::is_same<T, float>::value) {
|
||||
asm("multimem.ld_reduce.global.add.v4.f32 {%0,%1,%2,%3}, [%4];"
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add.v4.f32 {%0,%1,%2,%3}, [%4];"
|
||||
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
|
||||
: "l"(ptr)
|
||||
: "memory");
|
||||
} else if constexpr (std::is_same<T, half2>::value) {
|
||||
asm("multimem.ld_reduce.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];"
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];"
|
||||
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
|
||||
: "l"(ptr)
|
||||
: "memory");
|
||||
@@ -40,15 +40,15 @@ struct DeviceMulticastPointerDeviceHandle {
|
||||
};
|
||||
|
||||
template <int NElemPerThread = 4, typename TVaule, typename T>
|
||||
MSCCLPP_DEVICE_INLINE void multimemStore(const TVaule& val, T* ptr) {
|
||||
MSCCLPP_DEVICE_INLINE static void multimemStore(const TVaule& val, T* ptr) {
|
||||
static_assert(NElemPerThread == 4, "Only support NElemPerThread == 4");
|
||||
if constexpr (std::is_same<T, float>::value) {
|
||||
asm volatile("multimem.st.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z),
|
||||
"r"(val.w)
|
||||
asm volatile("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y),
|
||||
"r"(val.z), "r"(val.w)
|
||||
: "memory");
|
||||
} else if constexpr (std::is_same<T, half2>::value) {
|
||||
asm volatile("multimem.st.global.v4.f16x2 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z),
|
||||
"r"(val.w)
|
||||
asm volatile("multimem.st.relaxed.sys.global.v4.f16x2 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y),
|
||||
"r"(val.z), "r"(val.w)
|
||||
: "memory");
|
||||
} else {
|
||||
static_assert(dependentFalse<T>, "Not supported type");
|
||||
|
||||
@@ -816,8 +816,8 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
|
||||
for (int idx = my_st + my_offset; idx < my_en; idx += my_step) {
|
||||
uint4 val;
|
||||
nvlsPtrs.multimemLoad(val, mc_ptr + idx);
|
||||
nvlsPtrs.multimemStore(val, mc_ptr + idx);
|
||||
DeviceMulticastPointerDeviceHandle::multimemLoad(val, mc_ptr + idx);
|
||||
DeviceMulticastPointerDeviceHandle::multimemStore(val, mc_ptr + idx);
|
||||
}
|
||||
|
||||
deviceSyncer.sync(gridDim.x);
|
||||
|
||||
@@ -41,8 +41,8 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
|
||||
for (int idx = my_st + my_offset; idx < my_en; idx += my_step) {
|
||||
uint4 val;
|
||||
nvlsPtrs.multimemLoad(val, mc_ptr + idx);
|
||||
nvlsPtrs.multimemStore(val, mc_ptr + idx);
|
||||
DeviceMulticastPointerDeviceHandle::multimemLoad(val, mc_ptr + idx);
|
||||
DeviceMulticastPointerDeviceHandle::multimemStore(val, mc_ptr + idx);
|
||||
}
|
||||
|
||||
deviceSyncer.sync(gridDim.x);
|
||||
|
||||
@@ -29,7 +29,7 @@ class NvlsConnection::Impl : public std::enable_shared_from_this<NvlsConnection:
|
||||
void addDevice(int cudaDeviceId);
|
||||
size_t allocateBuffer(size_t size);
|
||||
void freeBuffer(size_t offset, size_t size) noexcept;
|
||||
std::shared_ptr<char> bindMemory(std::shared_ptr<PhysicalCudaMemory<char>> physicalMem, size_t devBuffSize);
|
||||
std::shared_ptr<char> bindMemory(CUmemGenericAllocationHandle memHandle, size_t devBuffSize);
|
||||
|
||||
private:
|
||||
friend class NvlsConnection;
|
||||
@@ -185,11 +185,9 @@ void NvlsConnection::Impl::freeBuffer(size_t offset, size_t size) noexcept {
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<char> NvlsConnection::Impl::bindMemory(std::shared_ptr<PhysicalCudaMemory<char>> physicalMem,
|
||||
size_t devBuffSize) {
|
||||
std::shared_ptr<char> NvlsConnection::Impl::bindMemory(CUmemGenericAllocationHandle memHandle, size_t devBuffSize) {
|
||||
size_t offset = allocateBuffer(devBuffSize);
|
||||
MSCCLPP_CUTHROW(
|
||||
cuMulticastBindMem(mcHandle_, offset /*mcOffset*/, physicalMem->memHandle_, 0 /*memOffset*/, devBuffSize, 0));
|
||||
MSCCLPP_CUTHROW(cuMulticastBindMem(mcHandle_, offset /*mcOffset*/, memHandle, 0 /*memOffset*/, devBuffSize, 0));
|
||||
|
||||
char* mcPtr;
|
||||
|
||||
@@ -227,7 +225,7 @@ class NvlsConnection::Impl {
|
||||
std::vector<char> serialize() { throw notSupportedError; }
|
||||
size_t allocateBuffer(size_t) { throw notSupportedError; }
|
||||
void freeBuffer(size_t, size_t) { throw notSupportedError; }
|
||||
std::shared_ptr<char> bindMemory(std::shared_ptr<PhysicalCudaMemory<char>>, size_t) { throw notSupportedError; }
|
||||
std::shared_ptr<char> bindMemory(CUmemGenericAllocationHandle, size_t) { throw notSupportedError; }
|
||||
void addDevice(int) { throw notSupportedError; }
|
||||
size_t getMinMcGran() { throw notSupportedError; }
|
||||
|
||||
@@ -253,10 +251,14 @@ std::vector<char> NvlsConnection::serialize() { return pimpl_->serialize(); }
|
||||
|
||||
std::shared_ptr<NvlsConnection::DeviceMulticastPointer> NvlsConnection::allocateAndBindCuda(size_t size) {
|
||||
auto mem = allocSharedPhysicalCuda<char>(size, pimpl_->getMinMcGran());
|
||||
auto mcPtr = pimpl_->bindMemory(mem, size);
|
||||
auto mcPtr = pimpl_->bindMemory(mem->memHandle_, size);
|
||||
return std::make_shared<DeviceMulticastPointer>(mem, mcPtr, size);
|
||||
}
|
||||
|
||||
std::shared_ptr<char> NvlsConnection::bindAllocatedCuda(CUmemGenericAllocationHandle memHandle, size_t size) {
|
||||
return pimpl_->bindMemory(memHandle, size);
|
||||
}
|
||||
|
||||
NvlsConnection::DeviceMulticastPointer::DeviceHandle NvlsConnection::DeviceMulticastPointer::deviceHandle() {
|
||||
NvlsConnection::DeviceMulticastPointer::DeviceHandle device;
|
||||
device.devicePtr = this->deviceMem_->devicePtr_;
|
||||
|
||||
Reference in New Issue
Block a user