mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Extend XDL kernel to Support RDNA3/4 - Part 2 (#2722)
Update Blockwise and Gridwise files to support both wave32 and wave64. 1. Calculate WaveSize from a template parameter instead of hard-coding it to 64; some literal "64" values are also replaced with WaveSize. 2. Move BN0Shuffled and BK0Shuffled to the device side, because the correct mfma instruction info cannot be obtained on the host side. 3. Update b_thread_offset_n and b_thread_offset_k in gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp for gfx11: on gfx11 the input data is duplicated for each group of 16 threads, unlike on all other targets. 4. Modify a1_threadwise_copy in gridwise_batched_*gemm*gemm for gfx11: on gfx11 the input must be duplicated and A must be swizzled if transposeC isn't enabled.
This commit is contained in:
@@ -1384,25 +1384,31 @@ struct MfmaSelector
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 32, 32, f8_t, false, true>()
|
||||
constexpr auto GetMfma<f8_t, 32, 32, f8_t, is_single_rate_mfma, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 32, 32, f8_t, false, true>()
|
||||
constexpr auto GetMfma<bf8_t, 32, 32, f8_t, is_single_rate_mfma, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
|
||||
}
|
||||
template <>
|
||||
constexpr auto GetMfma<f4_t, 32, 32, f4_t, false, true>()
|
||||
constexpr auto GetMfma<f4_t, 32, 32, f4_t, is_single_rate_mfma, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
|
||||
}
|
||||
template <>
|
||||
constexpr auto GetMfma<f4_t, 16, 16, f4_t, false, true>()
|
||||
constexpr auto GetMfma<f4_t, 16, 16, f4_t, is_single_rate_mfma, true>()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#else
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -1432,48 +1438,84 @@ struct MfmaSelector
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16, f8_t, false, true>()
|
||||
constexpr auto GetMfma<f8_t, 16, 16, f8_t, is_single_rate_mfma, true>()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#else
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, false, true>()
|
||||
constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, is_single_rate_mfma, true>()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#else
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16, bf8_t, false, true>()
|
||||
constexpr auto GetMfma<f8_t, 16, 16, bf8_t, is_single_rate_mfma, true>()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#else
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 16, 16, f8_t, false, true>()
|
||||
constexpr auto GetMfma<bf8_t, 16, 16, f8_t, is_single_rate_mfma, true>()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#else
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f6_t, 32, 32, f6_t, false, true>()
|
||||
constexpr auto GetMfma<f6_t, 32, 32, f6_t, is_single_rate_mfma, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
|
||||
}
|
||||
template <>
|
||||
constexpr auto GetMfma<f6_t, 16, 16, f6_t, false, true>()
|
||||
constexpr auto GetMfma<f6_t, 16, 16, f6_t, is_single_rate_mfma, true>()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#else
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
#endif
|
||||
}
|
||||
template <>
|
||||
constexpr auto GetMfma<bf6_t, 32, 32, bf6_t, false, true>()
|
||||
constexpr auto GetMfma<bf6_t, 32, 32, bf6_t, is_single_rate_mfma, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
|
||||
}
|
||||
template <>
|
||||
constexpr auto GetMfma<bf6_t, 16, 16, bf6_t, false, true>()
|
||||
constexpr auto GetMfma<bf6_t, 16, 16, bf6_t, is_single_rate_mfma, true>()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx12;
|
||||
#elif defined(__gfx11__)
|
||||
return MfmaInstr::wmma_unsupport_16x16_gfx11;
|
||||
#else
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -1852,7 +1894,7 @@ struct XdlopsGemm
|
||||
Sequence<8>{}));
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetRegSizePerXdlops()
|
||||
__device__ __host__ static constexpr index_t GetRegSizePerXdlops()
|
||||
{
|
||||
return MPerXdlops * NPerXdlops / mfma_instr.wave_size;
|
||||
}
|
||||
@@ -1961,7 +2003,7 @@ struct XdlopsGemm
|
||||
{
|
||||
const auto laneId = GetLaneId();
|
||||
#if defined(__gfx11__)
|
||||
const auto blk_idx = GetGfx11InputBlkIdx<true>();
|
||||
const auto blk_idx = GetGfx11InputBlkIdx<!TransposeC>();
|
||||
#else
|
||||
const auto blk_idx = GetBlkIdx();
|
||||
#endif
|
||||
@@ -1983,7 +2025,7 @@ struct XdlopsGemm
|
||||
{
|
||||
const auto laneId = GetLaneId();
|
||||
#if defined(__gfx11__)
|
||||
const auto blk_idx = GetGfx11InputBlkIdx<false>();
|
||||
const auto blk_idx = GetGfx11InputBlkIdx<TransposeC>();
|
||||
#else
|
||||
const auto blk_idx = GetBlkIdx();
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user