mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Extend XDL kernel to Support RDNA3/4 - Part 2 (#2722)
Update Blockwise and Gridwise files to support both wave32 & wave64. 1. Calculate WaveSize from template parameter, instead of hard code it to 64, some "64" is also replace with WaveSize 2. Move BN0Shuffled and BK0Shuffled to device side. we can't get correct mfma inst info in host side. 3. Update b_thread_offset_n and b_thread_offset_k in gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp for gfx11. in gfx11, input data is duplicated for each 16 threads, it is different with all of others. 4. Modify a1_threadwise_copy in gridwise_batched_*gemm*gemm for gfx11. for gfx11, we need duplicate input and swizzle A if transposeC isn't enabled.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -38,13 +38,15 @@ struct BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
static constexpr index_t WaveSize = get_warp_size();
|
||||
|
||||
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t KPerBlock =
|
||||
BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
|
||||
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerDpp);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerDpp);
|
||||
static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
|
||||
|
||||
static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
|
||||
static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
|
||||
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
|
||||
@@ -54,9 +56,6 @@ struct BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2
|
||||
|
||||
static constexpr index_t KPerThread = KPerBlock / dpp_gemm.K0PerDpp;
|
||||
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerDpp);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerDpp);
|
||||
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
AccDataType,
|
||||
MRepeat * NRepeat,
|
||||
|
||||
@@ -46,7 +46,9 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
// Hardcode to 64, as HIP-provided "WarpSize" would return 32 on RDNA GPUs.
|
||||
static constexpr index_t WaveSize = 64;
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
|
||||
|
||||
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
|
||||
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
|
||||
@@ -77,9 +79,6 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
|
||||
static constexpr index_t KRepeat = KPerThread / KPack;
|
||||
static constexpr index_t KPerInnerLoop = KPack;
|
||||
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
|
||||
// Hardcode to 2, for better 8-bit access pattern
|
||||
|
||||
static constexpr index_t MXdlPack = 2;
|
||||
@@ -206,6 +205,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
|
||||
Tuple5 b_origin = CalculateBThreadOriginDataIndex())
|
||||
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
|
||||
{
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
|
||||
@@ -213,6 +213,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
|
||||
|
||||
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
|
||||
"wrong!");
|
||||
#endif
|
||||
}
|
||||
|
||||
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -32,9 +32,9 @@ template <index_t BlockSize,
|
||||
index_t KPerXDL>
|
||||
struct BlockwiseGemmXdlops_pipeline_hotloop_inst
|
||||
{
|
||||
static constexpr index_t WaveSize = 64;
|
||||
static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL);
|
||||
static constexpr index_t WaveSize = BlockSize / (WaveNumM * WaveNumN);
|
||||
|
||||
static constexpr index_t A_Buffer_Load_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth);
|
||||
@@ -108,7 +108,11 @@ struct BlockwiseGemmXdlops_pipeline_v4
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
static constexpr index_t WaveSize = get_warp_size();
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
static_assert(MWaves > 0);
|
||||
static_assert(NWaves > 0);
|
||||
static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
|
||||
|
||||
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
|
||||
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
|
||||
@@ -121,9 +125,6 @@ struct BlockwiseGemmXdlops_pipeline_v4
|
||||
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
|
||||
static constexpr index_t KRepeat = KPerThread / KPack;
|
||||
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
|
||||
using HotLoopInstList = BlockwiseGemmXdlops_pipeline_hotloop_inst<BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
@@ -237,6 +238,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
|
||||
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
|
||||
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
|
||||
{
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
@@ -245,7 +247,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
|
||||
|
||||
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
|
||||
"wrong!");
|
||||
|
||||
#endif
|
||||
// HotLoopInstList::Print();
|
||||
}
|
||||
|
||||
|
||||
@@ -141,6 +141,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1<BlockGemmPipelineSch
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
using Base::WaveSize;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
@@ -153,7 +154,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1<BlockGemmPipelineSch
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K1 = WaveSize / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
|
||||
@@ -143,6 +143,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3<BlockGemmPipelineSch
|
||||
using Base::BMmaKStride;
|
||||
|
||||
using Base::MWaves;
|
||||
using Base::WaveSize;
|
||||
|
||||
static constexpr auto xdlops_gemm =
|
||||
XdlopsGemm<ComputeDataType, MPerXDL, NPerXDL, KPack, ComputeDataType>{};
|
||||
@@ -159,7 +160,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3<BlockGemmPipelineSch
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K1 = WaveSize / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
|
||||
@@ -142,6 +142,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
using Base::c_thread_desc_;
|
||||
using Base::WaveSize;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
@@ -154,7 +155,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K1 = WaveSize / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
|
||||
@@ -144,6 +144,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
|
||||
using Base::BMmaKStride;
|
||||
using Base::c_thread_desc_;
|
||||
using Base::MWaves;
|
||||
using Base::WaveSize;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
@@ -156,7 +157,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack / KGroup;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K1 = WaveSize / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat * KGroup;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
|
||||
@@ -145,6 +145,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
|
||||
using Base::BMmaKStride;
|
||||
|
||||
using Base::MWaves;
|
||||
using Base::WaveSize;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
@@ -158,7 +159,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack / KGroup;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K1 = WaveSize / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat * KGroup;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -143,6 +143,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
using Base::MWaves;
|
||||
using Base::WaveSize;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
@@ -155,7 +156,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack / KGroup;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K1 = WaveSize / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat * KGroup;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -142,6 +142,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
using Base::WaveSize;
|
||||
|
||||
static constexpr index_t PrefetchStages = 3;
|
||||
static constexpr index_t PrefillStages = 2;
|
||||
@@ -154,7 +155,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack / KGroup;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K1 = WaveSize / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat * KGroup;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -151,6 +151,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MWaves;
|
||||
using Base::WaveSize;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
@@ -164,7 +165,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack / KGroup;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K1 = WaveSize / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat * KGroup;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
|
||||
@@ -93,8 +93,10 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
NPerXDL,
|
||||
xdlops_gemm.KPerXdlops>;
|
||||
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
static_assert(KPerThread % KPack == 0,
|
||||
"Wrong KPack setting; try increasing KPerThread or decreasing KPack");
|
||||
#endif
|
||||
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
AccDataType,
|
||||
|
||||
@@ -153,6 +153,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1<BlockGemmPipelineS
|
||||
|
||||
using Base::MWaves;
|
||||
using Base::NWaves;
|
||||
using Base::WaveSize;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
@@ -165,7 +166,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1<BlockGemmPipelineS
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack / KGroup;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K1 = WaveSize / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat * KGroup;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
|
||||
@@ -152,6 +152,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MWaves;
|
||||
using Base::WaveSize;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t LocalPrefetchStages = 2;
|
||||
@@ -166,7 +167,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack / KGroup;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K1 = WaveSize / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat * KGroup;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
|
||||
@@ -153,6 +153,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1<
|
||||
|
||||
using Base::MWaves;
|
||||
using Base::NWaves;
|
||||
using Base::WaveSize;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
@@ -165,7 +166,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1<
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack / KGroup;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K1 = WaveSize / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat * KGroup;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
|
||||
@@ -152,6 +152,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MWaves;
|
||||
using Base::WaveSize;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
@@ -165,7 +166,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack / KGroup;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K1 = WaveSize / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat * KGroup;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
|
||||
@@ -153,6 +153,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
|
||||
|
||||
using Base::MWaves;
|
||||
using Base::NWaves;
|
||||
using Base::WaveSize;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
@@ -165,7 +166,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack / KGroup;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K1 = WaveSize / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat * KGroup;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
|
||||
@@ -152,6 +152,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MWaves;
|
||||
using Base::WaveSize;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
@@ -165,7 +166,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack / KGroup;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K1 = WaveSize / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat * KGroup;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -49,7 +49,11 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
static constexpr index_t WaveSize = get_warp_size();
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
static_assert(MWaves > 0);
|
||||
static_assert(NWaves > 0);
|
||||
static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
|
||||
|
||||
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
|
||||
@@ -66,9 +70,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
|
||||
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
|
||||
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
FloatAcc,
|
||||
MRepeat * NRepeat,
|
||||
@@ -158,6 +159,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
|
||||
__host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
|
||||
{
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
|
||||
BK0NK1BlockDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
@@ -173,6 +175,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
static_assert(
|
||||
KPack % (16 * sizeof(ComputeTypeA)) == 0,
|
||||
"KPack must be divisbile by number of elements processed in single smfmac instruction");
|
||||
#endif
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -49,8 +49,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
static constexpr index_t WaveSize = get_warp_size();
|
||||
|
||||
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t KPerBlock =
|
||||
@@ -61,14 +59,15 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
|
||||
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
|
||||
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
|
||||
|
||||
static constexpr auto xdlops_gemm =
|
||||
XdlopsGemm<ComputeTypeA, MPerXDL, NPerXDL, KPack, ComputeTypeB>{};
|
||||
|
||||
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
|
||||
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
FloatAcc,
|
||||
MRepeat * NRepeat,
|
||||
@@ -679,7 +678,9 @@ struct BlockwiseGemmXdlops_v2
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
static constexpr index_t WaveSize = get_warp_size();
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
|
||||
|
||||
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
|
||||
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
|
||||
@@ -691,9 +692,6 @@ struct BlockwiseGemmXdlops_v2
|
||||
|
||||
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
|
||||
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
|
||||
static_assert(KPerThread % KPack == 0,
|
||||
"Wrong KPack setting; try increasing KPerThread or decreasing KPack");
|
||||
|
||||
@@ -790,6 +788,7 @@ struct BlockwiseGemmXdlops_v2
|
||||
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
|
||||
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
|
||||
{
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
@@ -798,6 +797,7 @@ struct BlockwiseGemmXdlops_v2
|
||||
|
||||
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
|
||||
"wrong!");
|
||||
#endif
|
||||
}
|
||||
|
||||
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -30,8 +30,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr index_t WaveSize = 64;
|
||||
|
||||
static constexpr index_t KPerBlock = K0PerBlock * KPack;
|
||||
|
||||
static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
|
||||
@@ -42,8 +40,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
|
||||
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
|
||||
static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops;
|
||||
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
|
||||
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
FloatAcc,
|
||||
|
||||
Reference in New Issue
Block a user