mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Extend XDL kernel to Support RDNA3/4 - Part 3 (#2723)
Support Wave32/Wave64 in all XDL Kernels 1. Add following helper function/marocs in device_base.hpp - GET_NXDL_PER_WAVE_IMPL and GetNXdlPerWave2 - INVOKER_RUN_IMPL and INVOKER_RUN3_IMPL - IsValidGemmCompilationParameter and IS_VALID_COMPILATION_PARAMETER_IMPL 2. Replace GridwiseGemm to GridwiseGemm32 and GridwiseGemm64, and use one of them according to current GPU target 3. Move gridwise gemm related variable from Argument member to local variable in RunImp - It is to avoid duplicated GridwiseGemm::CheckValidity 4. Add IsValidGemmCompilationParameter to all XDL kernels. Know issues: - DeviceBatchedGemmXdl and DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle are incorrect on gfx11. - DeviceGemmMultipleDLayernorm_Xdl_CShuffle are incorrect on both gfx11 and gfx12.
This commit is contained in:
@@ -108,10 +108,8 @@ struct BlockwiseGemmXdlops_pipeline_v4
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
static_assert(MWaves > 0);
|
||||
static_assert(NWaves > 0);
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
|
||||
|
||||
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
|
||||
|
||||
@@ -49,10 +49,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
static_assert(MWaves > 0);
|
||||
static_assert(NWaves > 0);
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
|
||||
|
||||
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
|
||||
|
||||
Reference in New Issue
Block a user