Extend XDL kernel to Support RDNA3/4 - Part 1 (#2606)

This commit is contained in:
linqunAMD
2025-08-23 05:46:30 +08:00
committed by GitHub
parent 0db21053e6
commit d6e49c5fde
11 changed files with 683 additions and 127 deletions

View File

@@ -41,7 +41,9 @@ struct BlockwiseGemmXdlops_pipeline_base
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// Hardcode to 64, as HIP-provided "WarpSize" would return 32 on RDNA GPUs.
static constexpr index_t WaveSize = 64;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
@@ -74,9 +76,6 @@ struct BlockwiseGemmXdlops_pipeline_base
return 1;
}();
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
using HotLoopInstList =
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst<BlockSize,
MPerBlock,
@@ -219,6 +218,7 @@ struct BlockwiseGemmXdlops_pipeline_base
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
#if defined(__HIP_DEVICE_COMPILE__)
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
@@ -227,6 +227,7 @@ struct BlockwiseGemmXdlops_pipeline_base
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
#endif
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -139,9 +139,10 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
using Base::AMmaKStride;
using Base::BMmaKStride;
using Base::WaveSize;
static constexpr index_t WgpPerCU =
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
@@ -625,13 +626,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
using Base::a_block_desc_m0_m1_m2_k;
using Base::b_block_desc_n0_n1_n2_k;
using Base::WaveSize;
static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
static constexpr index_t WgpPerCU =
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -141,9 +141,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using Base::AMmaKStride;
using Base::BMmaKStride;
using Base::WaveSize;
static constexpr index_t WgpPerCU =
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -139,9 +139,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Intra
using Base::AMmaKStride;
using Base::BMmaKStride;
using Base::WaveSize;
static constexpr index_t WgpPerCU =
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
@@ -626,13 +627,14 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
using Base::a_block_desc_m0_m1_m2_k;
using Base::b_block_desc_n0_n1_n2_k;
using Base::WaveSize;
static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
static constexpr index_t WgpPerCU =
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -159,6 +159,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
__device__ static constexpr auto HotLoopScheduler()
{
#if !defined(__gfx11__) && !defined(__gfx12__)
// A/B split schedule
// the compiler is likely to use ds_read2 when the instruction width is smaller than 16 bytes
constexpr auto num_ds_read_inst_a =
@@ -260,6 +261,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
#endif
}
template <bool HasMainLoop,