mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 01:10:22 +00:00
Do not check value of __HIP_PLATFORM_AMD__ (#240)
According to the [document](https://rocm.docs.amd.com/projects/HIP/en/docs-6.0.0/user_guide/hip_porting_guide.html#compiler-defines-summary), `__HIP_PLATFORM_AMD__` takes effect merely by being defined; its value is not meaningful, so code should test it with `defined(__HIP_PLATFORM_AMD__)` rather than comparing its value.
This commit is contained in:
@@ -4,20 +4,20 @@
|
||||
#ifndef MSCCLPP_DEVICE_HPP_
|
||||
#define MSCCLPP_DEVICE_HPP_
|
||||
|
||||
#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
|
||||
#if defined(__HIP_PLATFORM_AMD__)
|
||||
#include <hip/hip_runtime.h>
|
||||
#endif // defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
|
||||
#endif // defined(__HIP_PLATFORM_AMD__)
|
||||
|
||||
#if (defined(__NVCC__) || defined(__HIP_PLATFORM_AMD__))
|
||||
|
||||
#define MSCCLPP_DEVICE_COMPILE
|
||||
#define MSCCLPP_DEVICE_INLINE __forceinline__ __device__
|
||||
#define MSCCLPP_HOST_DEVICE_INLINE __forceinline__ __host__ __device__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
|
||||
#if defined(__HIP_PLATFORM_AMD__)
|
||||
#define MSCCLPP_DEVICE_HIP
|
||||
#else // !(defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1))
|
||||
#else // !(defined(__HIP_PLATFORM_AMD__)
|
||||
#define MSCCLPP_DEVICE_CUDA
|
||||
#endif // !(defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1))
|
||||
#endif // !(defined(__HIP_PLATFORM_AMD__))
|
||||
|
||||
#else // !(defined(__NVCC__) || defined(__HIP_PLATFORM_AMD__))
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#ifndef MSCCLPP_GPU_HPP_
|
||||
#define MSCCLPP_GPU_HPP_
|
||||
|
||||
#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
|
||||
#if defined(__HIP_PLATFORM_AMD__)
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ T* cudaExtCalloc(size_t nelem) {
|
||||
AvoidCudaGraphCaptureGuard cgcGuard;
|
||||
T* ptr;
|
||||
CudaStreamWithFlags stream(cudaStreamNonBlocking);
|
||||
#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
|
||||
#if defined(__HIP_PLATFORM_AMD__)
|
||||
MSCCLPP_CUDATHROW(hipExtMallocWithFlags((void**)&ptr, nelem * sizeof(T), hipDeviceMallocUncached));
|
||||
#else
|
||||
MSCCLPP_CUDATHROW(cudaMalloc(&ptr, nelem * sizeof(T)));
|
||||
|
||||
@@ -74,7 +74,7 @@ __device__ void localAllGather(DeviceHandle<mscclpp::SimpleProxyChannel> proxyCh
|
||||
if ((remoteRank % nranksPerNode) == ((rank - i + nranksPerNode) % nranksPerNode)) {
|
||||
if ((threadIdx.x % 32) == 0) proxyChan.wait();
|
||||
}
|
||||
#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
|
||||
#if defined(__HIP_PLATFORM_AMD__)
|
||||
// NOTE: we actually need a group barrier here for better performance, but __syncthreads() is still correct.
|
||||
__syncthreads();
|
||||
#else
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
|
||||
#if defined(__HIP_PLATFORM_AMD__)
|
||||
#define WARP_SIZE 64
|
||||
#else
|
||||
#define WARP_SIZE 32
|
||||
@@ -65,7 +65,7 @@ __device__ void localAllGather(DeviceHandle<mscclpp::SimpleProxyChannel> proxyCh
|
||||
if ((remoteRank % nRanksPerNode) == ((rank - i + nRanksPerNode) % nRanksPerNode)) {
|
||||
if ((threadIdx.x % WARP_SIZE) == 0) proxyChan.wait();
|
||||
}
|
||||
#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
|
||||
#if defined(__HIP_PLATFORM_AMD__)
|
||||
// NOTE: we actually need a group barrier here for better performance, but __syncthreads() is still correct.
|
||||
__syncthreads();
|
||||
#else
|
||||
|
||||
@@ -956,7 +956,7 @@ __global__ void allreduce4(int* buff, int* scratch, int rank, int nRanksPerNode,
|
||||
}
|
||||
|
||||
__global__ void allreduce5(int* buff, int rank, int nRanksPerNode, int worldSize, size_t nelems) {
|
||||
#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
|
||||
#if defined(__HIP_PLATFORM_AMD__)
|
||||
localReduceScatterSm3(buff, rank, nRanksPerNode, nelems / worldSize, nelems / worldSize, gridDim.x);
|
||||
deviceSyncer.sync(gridDim.x);
|
||||
localRingAllGatherSm2(rank, nRanksPerNode, nelems / worldSize * sizeof(int), gridDim.x);
|
||||
|
||||
Reference in New Issue
Block a user