Extend XDL kernel to Support RDNA3/4 - Part 2 (#2722)

Update Blockwise and Gridwise files to support both wave32 & wave64.

1. Calculate WaveSize from template parameter, instead of hard code it to 64, some "64" is also replace with WaveSize
2. Move BN0Shuffled and BK0Shuffled to device side. we can't get correct mfma inst info in host side.
3. Update b_thread_offset_n and b_thread_offset_k in gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp for gfx11. in gfx11, input data is duplicated for each 16 threads, it is different with all of others.
4. Modify a1_threadwise_copy in gridwise_batched_*gemm*gemm for gfx11.  for gfx11, we need duplicate input and swizzle A if transposeC isn't enabled.
This commit is contained in:
linqunAMD
2025-09-04 08:33:40 +08:00
committed by GitHub
parent 80ce6a573b
commit e2d28a92af
48 changed files with 605 additions and 300 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -38,13 +38,15 @@ struct BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
static constexpr index_t KPerBlock =
BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerDpp);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerDpp);
static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
@@ -54,9 +56,6 @@ struct BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2
static constexpr index_t KPerThread = KPerBlock / dpp_gemm.K0PerDpp;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerDpp);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerDpp);
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
AccDataType,
MRepeat * NRepeat,

View File

@@ -46,7 +46,9 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// Hardcode to 64, as HIP-provided "WarpSize" would return 32 on RDNA GPUs.
static constexpr index_t WaveSize = 64;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
@@ -77,9 +79,6 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
static constexpr index_t KRepeat = KPerThread / KPack;
static constexpr index_t KPerInnerLoop = KPack;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
// Hardcode to 2, for better 8-bit access pattern
static constexpr index_t MXdlPack = 2;
@@ -206,6 +205,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
Tuple5 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
#if defined(__HIP_DEVICE_COMPILE__)
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
@@ -213,6 +213,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
#endif
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -32,9 +32,9 @@ template <index_t BlockSize,
index_t KPerXDL>
struct BlockwiseGemmXdlops_pipeline_hotloop_inst
{
static constexpr index_t WaveSize = 64;
static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL);
static constexpr index_t WaveSize = BlockSize / (WaveNumM * WaveNumN);
static constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth);
@@ -108,7 +108,11 @@ struct BlockwiseGemmXdlops_pipeline_v4
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
static_assert(MWaves > 0);
static_assert(NWaves > 0);
static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
@@ -121,9 +125,6 @@ struct BlockwiseGemmXdlops_pipeline_v4
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t KRepeat = KPerThread / KPack;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
using HotLoopInstList = BlockwiseGemmXdlops_pipeline_hotloop_inst<BlockSize,
MPerBlock,
NPerBlock,
@@ -237,6 +238,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
#if defined(__HIP_DEVICE_COMPILE__)
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
@@ -245,7 +247,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
#endif
// HotLoopInstList::Print();
}

View File

@@ -141,6 +141,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1<BlockGemmPipelineSch
using Base::AMmaKStride;
using Base::BMmaKStride;
using Base::WaveSize;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
@@ -153,7 +154,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1<BlockGemmPipelineSch
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K1 = WaveSize / NPerXDL;
constexpr index_t K0 = KRepeat;
return transform_tensor_descriptor(

View File

@@ -143,6 +143,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3<BlockGemmPipelineSch
using Base::BMmaKStride;
using Base::MWaves;
using Base::WaveSize;
static constexpr auto xdlops_gemm =
XdlopsGemm<ComputeDataType, MPerXDL, NPerXDL, KPack, ComputeDataType>{};
@@ -159,7 +160,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3<BlockGemmPipelineSch
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K1 = WaveSize / NPerXDL;
constexpr index_t K0 = KRepeat;
return transform_tensor_descriptor(

View File

@@ -142,6 +142,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
using Base::AMmaKStride;
using Base::BMmaKStride;
using Base::c_thread_desc_;
using Base::WaveSize;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
@@ -154,7 +155,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K1 = WaveSize / NPerXDL;
constexpr index_t K0 = KRepeat;
return transform_tensor_descriptor(

View File

@@ -144,6 +144,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
using Base::BMmaKStride;
using Base::c_thread_desc_;
using Base::MWaves;
using Base::WaveSize;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
@@ -156,7 +157,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack / KGroup;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K1 = WaveSize / NPerXDL;
constexpr index_t K0 = KRepeat * KGroup;
return transform_tensor_descriptor(

View File

@@ -145,6 +145,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
using Base::BMmaKStride;
using Base::MWaves;
using Base::WaveSize;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
@@ -158,7 +159,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack / KGroup;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K1 = WaveSize / NPerXDL;
constexpr index_t K0 = KRepeat * KGroup;
return transform_tensor_descriptor(

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -143,6 +143,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
using Base::AMmaKStride;
using Base::BMmaKStride;
using Base::MWaves;
using Base::WaveSize;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
@@ -155,7 +156,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack / KGroup;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K1 = WaveSize / NPerXDL;
constexpr index_t K0 = KRepeat * KGroup;
return transform_tensor_descriptor(

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -142,6 +142,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
using Base::AMmaKStride;
using Base::BMmaKStride;
using Base::WaveSize;
static constexpr index_t PrefetchStages = 3;
static constexpr index_t PrefillStages = 2;
@@ -154,7 +155,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack / KGroup;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K1 = WaveSize / NPerXDL;
constexpr index_t K0 = KRepeat * KGroup;
return transform_tensor_descriptor(

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -151,6 +151,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MWaves;
using Base::WaveSize;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
@@ -164,7 +165,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack / KGroup;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K1 = WaveSize / NPerXDL;
constexpr index_t K0 = KRepeat * KGroup;
return transform_tensor_descriptor(

View File

@@ -93,8 +93,10 @@ struct BlockwiseGemmXdlops_pipeline_base
NPerXDL,
xdlops_gemm.KPerXdlops>;
#if defined(__HIP_DEVICE_COMPILE__)
static_assert(KPerThread % KPack == 0,
"Wrong KPack setting; try increasing KPerThread or decreasing KPack");
#endif
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
AccDataType,

View File

@@ -153,6 +153,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1<BlockGemmPipelineS
using Base::MWaves;
using Base::NWaves;
using Base::WaveSize;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
@@ -165,7 +166,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1<BlockGemmPipelineS
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack / KGroup;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K1 = WaveSize / NPerXDL;
constexpr index_t K0 = KRepeat * KGroup;
return transform_tensor_descriptor(

View File

@@ -152,6 +152,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MWaves;
using Base::WaveSize;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t LocalPrefetchStages = 2;
@@ -166,7 +167,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack / KGroup;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K1 = WaveSize / NPerXDL;
constexpr index_t K0 = KRepeat * KGroup;
return transform_tensor_descriptor(

View File

@@ -153,6 +153,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1<
using Base::MWaves;
using Base::NWaves;
using Base::WaveSize;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
@@ -165,7 +166,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1<
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack / KGroup;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K1 = WaveSize / NPerXDL;
constexpr index_t K0 = KRepeat * KGroup;
return transform_tensor_descriptor(

View File

@@ -152,6 +152,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MWaves;
using Base::WaveSize;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
@@ -165,7 +166,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack / KGroup;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K1 = WaveSize / NPerXDL;
constexpr index_t K0 = KRepeat * KGroup;
return transform_tensor_descriptor(

View File

@@ -153,6 +153,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
using Base::MWaves;
using Base::NWaves;
using Base::WaveSize;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
@@ -165,7 +166,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack / KGroup;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K1 = WaveSize / NPerXDL;
constexpr index_t K0 = KRepeat * KGroup;
return transform_tensor_descriptor(

View File

@@ -152,6 +152,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MWaves;
using Base::WaveSize;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
@@ -165,7 +166,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack / KGroup;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K1 = WaveSize / NPerXDL;
constexpr index_t K0 = KRepeat * KGroup;
return transform_tensor_descriptor(

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -49,7 +49,11 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
static_assert(MWaves > 0);
static_assert(NWaves > 0);
static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
@@ -66,9 +70,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
@@ -158,6 +159,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
__host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
{
#if defined(__HIP_DEVICE_COMPILE__)
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0NK1BlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
@@ -173,6 +175,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static_assert(
KPack % (16 * sizeof(ComputeTypeA)) == 0,
"KPack must be divisbile by number of elements processed in single smfmac instruction");
#endif
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -49,8 +49,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
static constexpr index_t KPerBlock =
@@ -61,14 +59,15 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
static constexpr auto xdlops_gemm =
XdlopsGemm<ComputeTypeA, MPerXDL, NPerXDL, KPack, ComputeTypeB>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
@@ -679,7 +678,9 @@ struct BlockwiseGemmXdlops_v2
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
@@ -691,9 +692,6 @@ struct BlockwiseGemmXdlops_v2
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
static_assert(KPerThread % KPack == 0,
"Wrong KPack setting; try increasing KPerThread or decreasing KPack");
@@ -790,6 +788,7 @@ struct BlockwiseGemmXdlops_v2
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
#if defined(__HIP_DEVICE_COMPILE__)
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
@@ -798,6 +797,7 @@ struct BlockwiseGemmXdlops_v2
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
#endif
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -30,8 +30,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr index_t WaveSize = 64;
static constexpr index_t KPerBlock = K0PerBlock * KPack;
static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
@@ -42,8 +40,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,