mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Extend XDL kernel to Support RDNA3/4 - Part 5 (#2725)
* Enable xdl in gfx11 & gfx12
* update cmake file
* fix all instance build (cmake)
* fix batched_gemm_gemm (cmake)
* rebase cmake files
* fix cmake build error
* remove CK_ENABLE_DYNAMIC_WARP_SIZE
* update cmake build error2
* fix gfx11 build
CK_USE_XDL is enabled on gfx11 and gfx12
* fix gfx10 build
* fix gfx11 error
---------
Co-authored-by: Lin, Qun <Quentin.Lin+amdeng@amd.com>
[ROCm/composable_kernel commit: f22740df82]
This commit is contained in:
@@ -68,11 +68,8 @@ inline bool is_gfx11_supported()
|
||||
inline bool is_xdl_supported()
|
||||
{
|
||||
return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"
|
||||
#if defined(CK_ENABLE_DYNAMIC_WARP_SIZE)
|
||||
|| is_gfx12_supported() || is_gfx11_supported()
|
||||
#endif
|
||||
;
|
||||
ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" ||
|
||||
is_gfx12_supported() || is_gfx11_supported();
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, index_t MPerXDL, index_t NPerXDL>
|
||||
@@ -83,7 +80,6 @@ inline bool is_xdl_wmma_supported()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
#if defined(CK_ENABLE_DYNAMIC_WARP_SIZE)
|
||||
else if(is_gfx12_supported() || is_gfx11_supported())
|
||||
{
|
||||
if constexpr((MPerXDL != 16) || (NPerXDL != 16))
|
||||
@@ -96,7 +92,6 @@ inline bool is_xdl_wmma_supported()
|
||||
}
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
return false;
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
#if defined(CK_ENABLE_DYNAMIC_WARP_SIZE)
|
||||
__device__ constexpr index_t get_warp_size()
|
||||
{
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
@@ -38,16 +37,6 @@ inline __host__ index_t get_warp_size()
|
||||
#endif
|
||||
return 64;
|
||||
}
|
||||
#else
|
||||
__host__ __device__ constexpr index_t get_warp_size()
|
||||
{
|
||||
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
|
||||
return 64;
|
||||
#else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
__device__ index_t get_thread_local_1d_id() { return threadIdx.x; }
|
||||
|
||||
|
||||
@@ -359,7 +359,7 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
{
|
||||
static_assert((std::is_same<T, int32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, uint32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(std::is_same<T, double>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
@@ -369,6 +369,8 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
constexpr auto I2 = number<2>{};
|
||||
constexpr auto I3 = number<3>{};
|
||||
|
||||
if constexpr(std::is_same<T, float>::value)
|
||||
{
|
||||
@@ -381,6 +383,13 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 2, x.template get_as<float>()[I2]);
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 3, x.template get_as<float>()[I3]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, double>::value)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user