mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
Add MoE & FP8 Blockscale WP Kernels for GFX950 (#2297)
* [fix] align v3 gufusion pipeline * fix device kernel selection. * Add .co direct asm support by CK_USE_ASM_MOE_STAGE2_BLOCKSCALE * experimental optimization for scale load in blkscale gemm * Add asm for no-loop v3_128x128x128 * fix bugs * tune fp8 example * Update v1_128x128x128 to 2x2 instead of 4x1 * wip * add warmup to asm launch * wip2 * 16x16 function merged to moe * temp save, a performant version. * wip3 * Update .co binary to 16x16 * 16x16x128 correct; 64x64x128 failed * update * use mem_op::set when topk=1 * add mx fp8 b_preshuffle support, function not yet tested. * Spilt the fp4 target. Fix the known bugs. 128x128x128 sanity checked; remove prints * some fixes * fix update * remove some unnecessary hacky; enable 256x256x256 tilesize * update for function debug * Add pipeline v3. Have some runtime issue and register spill * Fix pipe v3 correctness issue * remove unnecessary hacky * clang format * fix a bug * fix the bug, functional test passed * tempsave; buggy at passed 4 e8m0 to scaled mfma * added fp4_bpreshuffle example, build failures * fixed some bugs * implement shuffled scale mxfp4gemm, blocker: opsel not effect * hotfix * fix bugs, build passed * (M, N, K)=(128, 128, 128) function failed. * temp save for gemm1. Function not ready * fix compile error. Gemm2 pass. Gemm1 WIP * fix bug for a lds read * update moe * Compile pass. Gemm1 function WIP * update moe * fix fp8; fix even/odd * tempsave * update moe * Revert "update" This reverts commit960b2bce1c. * Revert "use mem_op::set when topk=1" This reverts commitdef952a178. * Add v3 128x128x128_4x4_16x16.co for gfx950 * temp cmake flag suppression for aiter test * add code for mxfp4 gemm, blockscale not supported yet * gemm1 up-only pass. GU WIP * function pass with inline asm hacky * revert unexpected file change * updated and build passed * update CE elementOP * added code for debug * Gemm1 GUFusion function pass. Perf WIP * Fix fp8/bf8; remove duplicated code * disable the scheduler in v3; bring it back when compiler feature ready. * update moe v1 pipeline * Add gemm1 v1 32x128x128 * remove schedule barrier * updated * Fix fp8/bf8 B-row * mfma using asm, device result correct, host result need to check * gemm1 v3 64x128x128 debug * fix cpu ref * a/b thread_desc stride fix * Use random scale for init1 * 16x16x128 input size blockscale function passed * fix blockscale gemm bug * tempsave. Almost all instances passed. * v1 fix for mi350. * temp save * debug save * update debug * fix the bug, 128x128x256 tile function passed * v3 * rename moe block selector and pipeline * Add gemm1 v1 * Add gemm1 v1 to selector * added mx moe block v3 support, function passed * compile error fix * Improve the pipeline * Pack e8m0 as int32_t * v1 compile pass. Function not ready * debug synchronize issue over different GPU/ROCm * minor fix * Add profiler filter * Add f4 ckProfiler * Fix example compile error * Add f4 profiler examples * tempsave * v1 function pass. * v3 function pass * align file and function name * mx_moe_fp4 ready for aiter with clang-format. * modify the way we represent fp4 * generalize the pipeline scheduling. * init moe mx f4 scale shuffle * Cmakelist diable compiler-bound flags * mx_fp4 default parameter change * Moe blockscale gemm1&gemm2 asm support for aiter. Suppression cmkae flag til new compler. * update code * tempsave; modify the way we represent fp4 * generalize the pipeline scheduling. * Add gemm1 gfx942 .co support * updated code, build passed. * Update gemm2 asm with latest compiler flag * Fix mx f4 ckProfiler * Fix blockwise gemm mx v1 * lds conflict free + buffer load lds * Add gemm2 v3 64x128x128 * fix a, b scale loading bugs, a, b scale loading now correctly * Add gemm2 v3 64x128x128 * commit with debug info * fix fp4 profiler * Add mx fp4 pileline v1 instances * Fix v2 topk_weight cal. Add silu asm. * v2 tok_weight WIP * init mx fp4 B no preshuffle version * tempsave. compile pass, function wrong * enable fp4 moe no weigth preshuffle, function pass * update the TFlops calculation in the example * Add gemm2 64x128x128 asm. Fix BF16 ref. * fix 2 typos in fp4_preshuffle * Better kernel selection in device classes * correct preShuffleBuffer we should used packed k to do shuffle. * lds conflict free + buffer load lds * optimize offset math in dma * Fix fp4 ckProfiler * Fix MX MFMA tests * fix f4 pipeline issues * gemm1 func pass * update mx moe gemm1_bns tile size to 64x128x256 * update mx moe gemm1 gemm2 TF and BW calculation * fix typo * temp save * Fix example_gemm_mx build * rename the block pipeline * correct a typo in tail * Add rotating to mx examples * fix the correctness issue * Fix v1; use M padding * Add NT flag to B/BScale buffer * Merge gemm_mx_common.hpp * temp save, 4.4~4.5 * Fix 'Merge gemm_mx_common.hpp' * refactor the pipeline * Pad the M for scale buffer unconditionaly * update MX moe GEMM1 hotloopscheduling * change the gemm1 tile from 64x128x128 to 128x64x128 * Unconditional Ascale padding * Pad shuffled a scale only * pad ascale * add vmcnt guard for async copy * Profiler add f4 wp * Merge preshuffle device * Add more fp4 wp instances * Fix do_weight in gemm1. Fix cshuffle_datatype. Clang-format * Clang-format after 2 merges * Remove rocm6.3 workaround flags and macro * Fix fp8 config * Fix bf8 config * flag and barrier fix for copmiler branch MainOpSelV3 * Add fp8 profiler instances * Remove debug infos; Enable flags for blockscale f8 * No asm ver. for merging moe blocksale fp8 into mainline * update the flag name for f8blockscale * recover example * fix performance bug of bpreshuffle f8 gemm * clang format, remove single rate mfma restriction for f8 * remove single rate mfma restriction for f8 blockscale gemm * Fix moe blockscale gemm1 barrier 0x800 for new compiler * add pipeline v1 for MOE Gemm2 * Use v1 pipeline for example_moe_gemm2_xdl_mx_fp4_bns * Fix OOB; add MB96 instances * remove unnecessary files * fix the cmake issue * Enable splitk for mxfp4; clang format; * Generate random tensor values with multiple threads * Use packed_size_v for A/BPackedSize * Fix warning * Fix target_compile_options for disabled target on gfx942 * fix moe pki4 on gfx950 * doc the kGroup definition * Fix ThreadwiseTensorSliceTransfer_v4::Run (Fuse scale) * Refactor thread_copy_lds_direct_load; fix gfx942 direct lds load example; fix f16_pki4 example * Fix unknown compiler flag * fix two failed examples. * fix some failure tile size in gfx950 universal gemm. fix test_gemm_fp16 * workaround fix for test_gemm_f32; * We have very limited support for lds direct load if input matrix is not K major * fix test_gemm_splitk; * Fix compile for mx_mfma_op * add mfma selection logic for multipled_v3 * Clean up * Fix device gemm mx link error * improve the global atomic pattern * Revert unnecessary copyright updates * restore minimum_occupancy logic * Avoid data race in moe gemm2 ref * Build fp8 gemm_multiply_multiply and moe only on gfx94/95 * update the instance in device_mx_gemm * Resolve comments * Copyright 2025 * Remove unused code * fix library linking issue --------- Co-authored-by: OscarXu <huaiguxu@amd.com> Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com> Co-authored-by: mtgu0705 <mtgu@amd.com> Co-authored-by: aska-0096 <haocwang@amd.com> Co-authored-by: Your Name <you@example.com> Co-authored-by: valarLip <340077269@qq.com> Co-authored-by: feifei14119 <feiw@amd.com> Co-authored-by: Lin, Qun <qlin@amd.com> Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com> Co-authored-by: joye <joye@amd.com> Co-authored-by: asleepzzz <hanwen.chang@amd.com>
This commit is contained in:
@@ -122,6 +122,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
using Base::B_K1;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KGroup;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
using typename Base::HotLoopInstList;
|
||||
@@ -153,9 +154,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack;
|
||||
constexpr index_t K2 = KPack / KGroup;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat;
|
||||
constexpr index_t K0 = KRepeat * KGroup;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
TileDesc_M0_M1_M2_K{},
|
||||
@@ -290,12 +291,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
block_sync_lds();
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
@@ -388,12 +391,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
@@ -477,12 +483,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
@@ -588,7 +596,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
ComputeDataType,
|
||||
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, 1, 1, KPack>,
|
||||
Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
A_K1,
|
||||
|
||||
@@ -122,6 +122,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
|
||||
using Base::B_K1;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KGroup;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
using typename Base::HotLoopInstList;
|
||||
@@ -154,9 +155,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
|
||||
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack;
|
||||
constexpr index_t K2 = KPack / KGroup;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat;
|
||||
constexpr index_t K0 = KRepeat * KGroup;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
TileDesc_M0_M1_M2_K{},
|
||||
@@ -298,12 +299,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
|
||||
block_sync_lds();
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -382,12 +385,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -458,12 +464,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -556,7 +565,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
|
||||
ComputeDataType,
|
||||
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, 1, 1, KPack>,
|
||||
Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
A_K1,
|
||||
|
||||
@@ -0,0 +1,952 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Compute optimized pipeline
|
||||
// GlobalPrefetchStages: 2
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPacks>
|
||||
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack
|
||||
// ,bool TransposeC //disable transposec right now...
|
||||
>
|
||||
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
using Base::A_K1;
|
||||
using Base::B_K1;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::I2;
|
||||
using Base::KGroup;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
using typename Base::HotLoopInstList;
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::CalculateCThreadOriginDataIndex;
|
||||
using Base::CalculateCThreadOriginDataIndex8D;
|
||||
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetCThreadBuffer;
|
||||
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
|
||||
using Base::MWaves;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1;
|
||||
|
||||
template <typename TileDesc_M0_M1_M2_K>
|
||||
__host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
|
||||
{
|
||||
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack / KGroup;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat * KGroup;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
TileDesc_M0_M1_M2_K{},
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<M0>{}),
|
||||
make_pass_through_transform(Number<M1>{}),
|
||||
make_pass_through_transform(Number<M2>{}),
|
||||
make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
|
||||
}
|
||||
|
||||
static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
|
||||
MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);
|
||||
|
||||
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
|
||||
}
|
||||
|
||||
__device__ static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
// A/B split schedule
|
||||
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
|
||||
constexpr auto num_ds_read_inst_a =
|
||||
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
|
||||
? HotLoopInstList::A_LDS_Read_Inst_Num
|
||||
: HotLoopInstList::A_LDS_Read_Inst_Num / 2;
|
||||
|
||||
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
|
||||
|
||||
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * 2;
|
||||
|
||||
static_assert(num_buffer_load_inst_a == num_ds_write_inst_a);
|
||||
|
||||
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * 2;
|
||||
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
|
||||
|
||||
constexpr auto ds_read_a_issue_cycle =
|
||||
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_a_mfma_rate =
|
||||
math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle);
|
||||
|
||||
// constexpr auto num_dsread_a_mfma =
|
||||
// (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
|
||||
|
||||
constexpr auto num_total_stages = MRepeat;
|
||||
|
||||
// Group num_mfma_perstage num_ds_read_a_perstage
|
||||
// since we want to reuse a local register buffer
|
||||
constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
|
||||
constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
|
||||
|
||||
constexpr auto num_ds_read_a_mfma_perstage =
|
||||
math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
|
||||
|
||||
constexpr auto num_ds_read_a_prefetch_stages = 2;
|
||||
|
||||
constexpr auto buffer_load_perstage_more = math::integer_divide_ceil(
|
||||
(num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2));
|
||||
constexpr auto buffer_load_perstage_less = math::integer_divide_floor(
|
||||
(num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2));
|
||||
|
||||
constexpr auto buffer_load_stages_more =
|
||||
(num_buffer_load_inst_a + num_buffer_load_inst_b) -
|
||||
math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b),
|
||||
(num_total_stages - 2)) *
|
||||
((num_total_stages - 2));
|
||||
|
||||
constexpr auto buffer_load_b_stages =
|
||||
buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b
|
||||
? num_buffer_load_inst_b / buffer_load_perstage_more
|
||||
: (buffer_load_stages_more +
|
||||
(num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) /
|
||||
buffer_load_perstage_less);
|
||||
|
||||
constexpr auto buffer_load_a_stages =
|
||||
num_total_stages - num_ds_read_a_prefetch_stages - buffer_load_b_stages;
|
||||
|
||||
constexpr auto buffer_load_issue_point_b = 0;
|
||||
constexpr auto buffer_load_issue_point_interval_more =
|
||||
num_mfma_perstage / buffer_load_perstage_more;
|
||||
constexpr auto buffer_load_issue_point_interval_less =
|
||||
num_mfma_perstage / buffer_load_perstage_less;
|
||||
constexpr auto ds_write_issue_point = 0;
|
||||
constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0;
|
||||
|
||||
// B global read
|
||||
static_for<0, buffer_load_b_stages, 1>{}([&](auto i) {
|
||||
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
|
||||
if constexpr(((i < buffer_load_stages_more) &&
|
||||
(imfma % buffer_load_issue_point_interval_more ==
|
||||
buffer_load_issue_point_b)) ||
|
||||
((i >= buffer_load_stages_more) &&
|
||||
(imfma % buffer_load_issue_point_interval_less ==
|
||||
buffer_load_issue_point_b)))
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
|
||||
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// A global read + A local write
|
||||
static_for<0, buffer_load_a_stages, 1>{}([&](auto i) {
|
||||
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
|
||||
(imfma % buffer_load_issue_point_interval_more ==
|
||||
ds_write_issue_point)) ||
|
||||
(((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
|
||||
(imfma % buffer_load_issue_point_interval_less ==
|
||||
ds_write_issue_point)))
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
}
|
||||
if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
|
||||
(imfma % buffer_load_issue_point_interval_more ==
|
||||
buffer_load_issue_point_a)) ||
|
||||
(((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
|
||||
(imfma % buffer_load_issue_point_interval_less ==
|
||||
buffer_load_issue_point_a)))
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// lds synchronization, prefetch next loop local A
|
||||
static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Stage>
|
||||
__device__ static constexpr auto EpilogueScheduler_1(Stage stage)
|
||||
{
|
||||
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
|
||||
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_b =
|
||||
MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num * 2;
|
||||
|
||||
constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num * 2;
|
||||
|
||||
constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
|
||||
constexpr auto staged_num_mfma = num_mfma / MRepeat;
|
||||
|
||||
constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;
|
||||
|
||||
if constexpr(stage.value == 0)
|
||||
{
|
||||
constexpr auto staged_num_buffer_load_b_per_ds_read_a =
|
||||
num_buffer_load_inst_b / staged_num_ds_read_inst_a;
|
||||
constexpr auto staged_num_mfma_per_buffer_load_b =
|
||||
staged_num_mfma / num_buffer_load_inst_b;
|
||||
// B global
|
||||
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
|
||||
ignore = i_inst;
|
||||
|
||||
static_for<0, staged_num_buffer_load_b_per_ds_read_a, 1>{}([&](auto ibuf_inst) {
|
||||
ignore = ibuf_inst;
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else if constexpr(stage.value == 1)
|
||||
{
|
||||
constexpr auto staged_num_mfma_per_ds_write_a =
|
||||
math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a);
|
||||
|
||||
constexpr auto stage_more_mfma =
|
||||
staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a;
|
||||
|
||||
// A local write
|
||||
static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) {
|
||||
if constexpr(i_inst.value < stage_more_mfma)
|
||||
{
|
||||
if(i_inst.value < staged_num_ds_read_inst_a)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(i_inst.value < staged_num_ds_read_inst_a)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
|
||||
}
|
||||
}
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
// A local Read
|
||||
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
|
||||
ignore = i_inst;
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static constexpr auto EpilogueScheduler_2()
|
||||
{
|
||||
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
|
||||
|
||||
constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num * 2;
|
||||
|
||||
constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
|
||||
constexpr auto staged_num_mfma = num_mfma / MRepeat;
|
||||
|
||||
constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;
|
||||
|
||||
// A local Read
|
||||
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
|
||||
ignore = i_inst;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
BBlockTransfer& b_blockwise_copy_up,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
const BGridBuffer& b_grid_buf_up,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
CThreadBuffer& c_thread_buf_up,
|
||||
index_t num_loop) const
|
||||
{
|
||||
ignore = b_block_buf;
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs_up;
|
||||
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
|
||||
|
||||
// Global prefetch A1 B1
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I0));
|
||||
|
||||
b_blockwise_copy_up.Run(b_grid_desc,
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(I0));
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// // Local prefill A1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
|
||||
|
||||
// // Global prefetch A2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
// Local prefetch A1
|
||||
block_sync_lds();
|
||||
static_for<0, 2, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
|
||||
a_block_buf.At(I0),
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
c_thread_buf_up.Clear();
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(local_read_buf));
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
b_blockwise_copy_up.Run(b_grid_desc,
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(local_read_buf));
|
||||
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf));
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple((m0 + HotloopLocalBufSwitch * mfma_reg_buf) %
|
||||
2,
|
||||
I0,
|
||||
I0,
|
||||
k0,
|
||||
I0,
|
||||
ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[mfma_reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
|
||||
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs_up[mfma_reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
|
||||
xdlops_gemm.Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(m0.value == MRepeat - 2)
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(Number<(m0 + 2) % MRepeat>{},
|
||||
I0,
|
||||
I0,
|
||||
Number<k0 * KGroup + kg0>{},
|
||||
I0,
|
||||
I0),
|
||||
a_block_buf.At(local_read_buf),
|
||||
a_thread_desc_,
|
||||
make_tuple(
|
||||
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
|
||||
2>{},
|
||||
I0,
|
||||
I0,
|
||||
k0,
|
||||
I0,
|
||||
Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
}
|
||||
else if constexpr(m0.value == (MRepeat - 1))
|
||||
{
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(Number<(m0 + 2) % MRepeat>{},
|
||||
I0,
|
||||
I0,
|
||||
Number<k0 * KGroup + kg0>{},
|
||||
I0,
|
||||
I0),
|
||||
a_block_buf.At(local_read_buf),
|
||||
a_thread_desc_,
|
||||
make_tuple(
|
||||
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
|
||||
2>{},
|
||||
I0,
|
||||
I0,
|
||||
k0,
|
||||
I0,
|
||||
Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(Number<(m0 + 2) % MRepeat>{},
|
||||
I0,
|
||||
I0,
|
||||
Number<k0 * KGroup + kg0>{},
|
||||
I0,
|
||||
I0),
|
||||
a_block_buf.At(mfma_reg_buf),
|
||||
a_thread_desc_,
|
||||
make_tuple(
|
||||
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
|
||||
2>{},
|
||||
I0,
|
||||
I0,
|
||||
k0,
|
||||
I0,
|
||||
Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
}
|
||||
});
|
||||
HotLoopScheduler();
|
||||
};
|
||||
|
||||
LoopFunc(I0, I1);
|
||||
LoopFunc(I1, I0);
|
||||
|
||||
i += 2;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I1));
|
||||
|
||||
b_blockwise_copy_up.Run(b_grid_desc,
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(I1));
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
|
||||
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
if constexpr(m0.value == (MRepeat - 2))
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(Number<(m0 + 2) % MRepeat>{},
|
||||
I0,
|
||||
I0,
|
||||
Number<k0 * KGroup + kg0>{},
|
||||
I0,
|
||||
I0),
|
||||
a_block_buf.At(I1),
|
||||
a_thread_desc_,
|
||||
make_tuple(
|
||||
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
}
|
||||
else if constexpr(m0.value == MRepeat - 1)
|
||||
{
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(Number<(m0 + 2) % MRepeat>{},
|
||||
I0,
|
||||
I0,
|
||||
Number<k0 * KGroup + kg0>{},
|
||||
I0,
|
||||
I0),
|
||||
a_block_buf.At(I1),
|
||||
a_thread_desc_,
|
||||
make_tuple(
|
||||
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(Number<(m0 + 2) % MRepeat>{},
|
||||
I0,
|
||||
I0,
|
||||
Number<k0 * KGroup + kg0>{},
|
||||
I0,
|
||||
I0),
|
||||
a_block_buf.At(I0),
|
||||
a_thread_desc_,
|
||||
make_tuple(
|
||||
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
HotLoopScheduler();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
|
||||
(m0 + HotloopLocalBufSwitch) % 2, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(m0.value < (MRepeat - 2))
|
||||
{
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(
|
||||
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
|
||||
a_block_buf.At(I1),
|
||||
a_thread_desc_,
|
||||
make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{},
|
||||
I0,
|
||||
I0,
|
||||
k0,
|
||||
I0,
|
||||
Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
HotLoopScheduler();
|
||||
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
|
||||
// latency
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(m0.value < (MRepeat - 2))
|
||||
{
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(
|
||||
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
|
||||
a_block_buf.At(I0),
|
||||
a_thread_desc_,
|
||||
make_tuple(
|
||||
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
// MRepeat MWave MLane KRepeat KLane KPack
|
||||
// KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack
|
||||
// Reduce the vgpr usage here.
|
||||
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I2, I1, I1, Number<KRepeat>{}, I1, Number<KPack>{}));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
|
||||
ComputeDataType,
|
||||
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
|
||||
AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()};
|
||||
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
|
||||
|
||||
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
|
||||
|
||||
using Base::c_thread_desc_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,919 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Naive pipeline with lowest resource request per WGP
|
||||
// GlobalPrefetchStages: 2
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t ThreadBlockSize,
|
||||
index_t ScaleBlockSize,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat, // MXdlPerWave
|
||||
index_t NRepeat, // NXdlPerWave
|
||||
index_t KPack>
|
||||
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v1
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t ThreadBlockSize,
|
||||
index_t ScaleBlockSize,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat, // MXdlPerWave
|
||||
index_t NRepeat, // NXdlPerWave
|
||||
index_t KPack>
|
||||
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v1<
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
ThreadBlockSize,
|
||||
ScaleBlockSize,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack> : BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
|
||||
{
|
||||
|
||||
using Base = BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KRepeat;
|
||||
using Base::MWaves;
|
||||
using Base::NWaves;
|
||||
using Base::WaveSize;
|
||||
using Base::xdlops_gemm;
|
||||
|
||||
using Base::CalculateCThreadOriginDataIndex;
|
||||
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetCThreadBuffer;
|
||||
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetWaveIdx;
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::b_block_desc_n0_n1_n2_k;
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
using Base::KThreadChunk;
|
||||
|
||||
using Base::APackedSize;
|
||||
using Base::BPackedSize;
|
||||
using Base::ComputePackedSize;
|
||||
|
||||
using AccType = typename Base::AccType;
|
||||
using Tuple4 = typename Base::Tuple4;
|
||||
using ComputeTypeA = typename Base::ComputeTypeA;
|
||||
using ComputeTypeB = typename Base::ComputeTypeB;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 2;
|
||||
|
||||
template <typename TileDesc_M0_M1_M2_K>
|
||||
__host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
|
||||
{
|
||||
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
TileDesc_M0_M1_M2_K{},
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<M0>{}),
|
||||
make_pass_through_transform(Number<M1>{}),
|
||||
make_pass_through_transform(Number<M2>{}),
|
||||
make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
|
||||
}
|
||||
|
||||
static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
|
||||
MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);
|
||||
|
||||
static constexpr auto ScalesPerKBlockSize =
|
||||
KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
|
||||
|
||||
//> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
|
||||
static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
|
||||
|
||||
//> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
|
||||
static constexpr auto ScalesPerXdlopsRunPerThread =
|
||||
ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
|
||||
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer,
|
||||
typename AScaleGridBuffer,
|
||||
typename AScaleGridDesc,
|
||||
typename AScaleThreadTransfer,
|
||||
typename BScaleGridBuffer,
|
||||
typename BScaleGridDesc,
|
||||
typename BScaleThreadTransfer>
|
||||
__device__ void Run(
|
||||
// ABlockCopy
|
||||
const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
// BBlockCopy
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
BBlockTransfer& b_blockwise_copy_up,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
const BGridBuffer& b_grid_buf_up,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
// CThread
|
||||
CThreadBuffer& c_thread_buf,
|
||||
CThreadBuffer& c_thread_buf_up,
|
||||
// A and B scales
|
||||
const AScaleGridDesc& a_scale_grid_desc,
|
||||
AScaleThreadTransfer& a_scale_thread_copy,
|
||||
const AScaleGridBuffer& a_scale_grid_buf,
|
||||
const BScaleGridDesc& b_scale_grid_desc,
|
||||
BScaleThreadTransfer& b_scale_thread_copy,
|
||||
BScaleThreadTransfer& b_scale_thread_copy_up,
|
||||
const BScaleGridBuffer& b_scale_grid_buf,
|
||||
const BScaleGridBuffer& b_scale_grid_buf_up,
|
||||
index_t num_loop) const
|
||||
{
|
||||
ignore = b_block_desc;
|
||||
ignore = b_block_buf;
|
||||
ignore = a_scale_grid_buf;
|
||||
ignore = b_scale_grid_buf;
|
||||
ignore = b_scale_grid_buf_up;
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs_up;
|
||||
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
|
||||
|
||||
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
|
||||
a_scale_thread_desc.GetElementSpaceSize());
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
|
||||
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
|
||||
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs_up;
|
||||
|
||||
// Global prefetch A1 B1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I0));
|
||||
b_blockwise_copy_up.Run(b_grid_desc,
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(I0));
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Prefetch a_scales to buf 0
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0, I0),
|
||||
a_scale_thread_bufs(I0));
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
make_multi_index(0, ScalesPerKBlockSize, 0));
|
||||
|
||||
// Prefetch b_scales to buf 0
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
|
||||
auto b_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy);
|
||||
|
||||
b_scale_thread_bufs(I0)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy[Number<0>{}];
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
|
||||
auto b_scale_thread_buf_copy_up =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy_up.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf_up,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy_up);
|
||||
|
||||
b_scale_thread_bufs_up(I0)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy_up[Number<0>{}];
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
|
||||
// restore col id and advance to the next set of scales
|
||||
// NWaves * NPerXDL * NRepeat == NPerBlock
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Local prefill A1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
|
||||
|
||||
// Global prefetch A2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
// Prefetch a_scales to buf 1
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0, I0),
|
||||
a_scale_thread_bufs(I1));
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
make_multi_index(0, ScalesPerKBlockSize, 0));
|
||||
|
||||
// Prefetch b_scales to buf 1
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
|
||||
auto b_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy);
|
||||
|
||||
b_scale_thread_bufs(I1)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy[Number<0>{}];
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
|
||||
auto b_scale_thread_buf_copy_up =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy_up.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf_up,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy_up);
|
||||
|
||||
b_scale_thread_bufs_up(I1)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy_up[Number<0>{}];
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
|
||||
// Local prefetch A1
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step = k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
c_thread_buf_up.Clear();
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
// loop over k with the step KPerBlock
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(local_read_buf));
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
b_blockwise_copy_up.Run(b_grid_desc,
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(local_read_buf));
|
||||
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[mfma_reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs_up[mfma_reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
|
||||
|
||||
static_assert(
|
||||
0 < ScalesPerXdlopsRunPerThread,
|
||||
"Must have at least one scale per Xdlops per Thread.");
|
||||
|
||||
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread>
|
||||
a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
|
||||
b_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
|
||||
b_scale_thread_vec_up;
|
||||
|
||||
// Pack scale_thread_buf into scale_thread_vec
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs[mfma_reg_buf]
|
||||
[Number<a_scale_offset + s>{}];
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs[mfma_reg_buf]
|
||||
[Number<b_scale_offset + s>{}];
|
||||
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs_up[mfma_reg_buf]
|
||||
[Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
APackedSize>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
BPackedSize>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec_up.template AsType<BScaleDataType>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// a thread copy
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step =
|
||||
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}(
|
||||
[&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step + chunk * KThreadChunk *
|
||||
xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Prefetch a_scales
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0, I0),
|
||||
a_scale_thread_bufs(mfma_reg_buf));
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(0, ScalesPerKBlockSize, 0));
|
||||
|
||||
// Prefetch b_scales
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
|
||||
auto b_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy);
|
||||
|
||||
b_scale_thread_bufs(mfma_reg_buf)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy[Number<0>{}];
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
|
||||
auto b_scale_thread_buf_copy_up =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy_up.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf_up,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy_up);
|
||||
|
||||
b_scale_thread_bufs_up(mfma_reg_buf)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy_up[Number<0>{}];
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
};
|
||||
|
||||
LoopFunc(I0, I1);
|
||||
LoopFunc(I1, I0);
|
||||
|
||||
i += 2;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I1));
|
||||
|
||||
b_blockwise_copy_up.Run(b_grid_desc,
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(I1));
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
|
||||
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
|
||||
|
||||
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
|
||||
b_scale_thread_vec_up;
|
||||
|
||||
// Pack b_scale_thread_buf into b_scale_thread_vec
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs[I0][Number<a_scale_offset + s>{}];
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs[I0][Number<b_scale_offset + s>{}];
|
||||
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs_up[I0][Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec_up.template AsType<BScaleDataType>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// a thread copy
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step =
|
||||
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
|
||||
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
|
||||
|
||||
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
|
||||
b_scale_thread_vec_up;
|
||||
|
||||
// Pack b_scale_thread_buf into b_scale_thread_vec
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs[I1][Number<a_scale_offset + s>{}];
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs[I1][Number<b_scale_offset + s>{}];
|
||||
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs_up[I1][Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec_up.template AsType<BScaleDataType>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
|
||||
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
|
||||
|
||||
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
|
||||
b_scale_thread_vec_up;
|
||||
|
||||
// Pack b_scale_thread_buf into b_scale_thread_vec
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs[I0][Number<a_scale_offset + s>{}];
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs[I0][Number<b_scale_offset + s>{}];
|
||||
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs_up[I0][Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec_up.template AsType<BScaleDataType>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: make this field protected when a_scale_thread_copy_ is moved
|
||||
// here
|
||||
static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
|
||||
|
||||
// Is used to copy data from a_scale_grid to a_scale_thread
|
||||
static constexpr auto a_scale_thread_desc_copy =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
|
||||
|
||||
// TODO: make this field protected when b_scale_thread_copy_ is moved
|
||||
// here
|
||||
static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<NRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
|
||||
|
||||
// Is used to copy data from b_scale_grid to b_scale_thread_buf
|
||||
static constexpr auto b_scale_thread_desc_copy =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
|
||||
|
||||
protected:
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
|
||||
using Base::a_thread_copy_;
|
||||
using Base::a_thread_desc_;
|
||||
using Base::b_thread_copy_;
|
||||
// using Base::b_thread_desc_;
|
||||
using Base::c_thread_desc_;
|
||||
|
||||
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,155 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp"
|
||||
|
||||
namespace ck {
|
||||
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSche,
|
||||
index_t ThreadBlockSize,
|
||||
index_t ScaleBlockSize,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename ComputeDataType, // TODO: remove this as in this pipeline ADataType and BDataType
|
||||
// must be used for compute
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
bool GUFusion = false>
|
||||
constexpr auto BlockGemmMXBPreshufflePipeline_Selector()
|
||||
{
|
||||
|
||||
// Hardware MX GEMM pipeline
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if constexpr(GUFusion)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v1<
|
||||
BlkGemmPipeSche,
|
||||
ThreadBlockSize,
|
||||
ScaleBlockSize,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
;
|
||||
}
|
||||
else
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<
|
||||
BlkGemmPipeSche,
|
||||
ThreadBlockSize,
|
||||
ScaleBlockSize,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if constexpr(GUFusion)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
|
||||
BlkGemmPipeSche,
|
||||
ThreadBlockSize,
|
||||
ScaleBlockSize,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3<
|
||||
BlkGemmPipeSche,
|
||||
ThreadBlockSize,
|
||||
ScaleBlockSize,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,813 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Naive pipeline with lowest resource request per WGP
|
||||
// GlobalPrefetchStages: 2
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t ThreadBlockSize,
|
||||
index_t ScaleBlockSize,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat, // MXdlPerWave
|
||||
index_t NRepeat, // NXdlPerWave
|
||||
index_t KPack>
|
||||
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t ThreadBlockSize,
|
||||
index_t ScaleBlockSize,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat, // MXdlPerWave
|
||||
index_t NRepeat, // NXdlPerWave
|
||||
index_t KPack>
|
||||
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
ThreadBlockSize,
|
||||
ScaleBlockSize,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
: BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
|
||||
{
|
||||
|
||||
using Base = BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KRepeat;
|
||||
using Base::MWaves;
|
||||
using Base::NWaves;
|
||||
using Base::WaveSize;
|
||||
using Base::xdlops_gemm;
|
||||
|
||||
using Base::CalculateCThreadOriginDataIndex;
|
||||
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetCThreadBuffer;
|
||||
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetWaveIdx;
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::b_block_desc_n0_n1_n2_k;
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
using Base::KThreadChunk;
|
||||
|
||||
using Base::APackedSize;
|
||||
using Base::BPackedSize;
|
||||
using Base::ComputePackedSize;
|
||||
|
||||
using AccType = typename Base::AccType;
|
||||
using Tuple4 = typename Base::Tuple4;
|
||||
using ComputeTypeA = typename Base::ComputeTypeA;
|
||||
using ComputeTypeB = typename Base::ComputeTypeB;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 2;
|
||||
|
||||
template <typename TileDesc_M0_M1_M2_K>
|
||||
__host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
|
||||
{
|
||||
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
TileDesc_M0_M1_M2_K{},
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<M0>{}),
|
||||
make_pass_through_transform(Number<M1>{}),
|
||||
make_pass_through_transform(Number<M2>{}),
|
||||
make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
|
||||
}
|
||||
|
||||
static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
|
||||
MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);
|
||||
|
||||
static constexpr auto ScalesPerKBlockSize =
|
||||
KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
|
||||
|
||||
//> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
|
||||
static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
|
||||
|
||||
//> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
|
||||
static constexpr auto ScalesPerXdlopsRunPerThread =
|
||||
ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
|
||||
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer,
|
||||
typename AScaleGridBuffer,
|
||||
typename AScaleGridDesc,
|
||||
typename AScaleThreadTransfer,
|
||||
typename BScaleGridBuffer,
|
||||
typename BScaleGridDesc,
|
||||
typename BScaleThreadTransfer>
|
||||
__device__ void Run(
|
||||
// ABlockCopy
|
||||
const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
// BBlockCopy
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
// CThread
|
||||
CThreadBuffer& c_thread_buf,
|
||||
// A and B scales
|
||||
const AScaleGridDesc& a_scale_grid_desc,
|
||||
AScaleThreadTransfer& a_scale_thread_copy,
|
||||
const AScaleGridBuffer& a_scale_grid_buf,
|
||||
const BScaleGridDesc& b_scale_grid_desc,
|
||||
BScaleThreadTransfer& b_scale_thread_copy,
|
||||
const BScaleGridBuffer& b_scale_grid_buf,
|
||||
index_t num_loop) const
|
||||
{
|
||||
ignore = b_block_desc;
|
||||
ignore = b_block_buf;
|
||||
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
|
||||
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
|
||||
|
||||
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
|
||||
a_scale_thread_desc.GetElementSpaceSize());
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
|
||||
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
|
||||
|
||||
// Global prefetch A1 B1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I0));
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Prefetch a_scales
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
|
||||
auto a_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
|
||||
a_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf_copy);
|
||||
|
||||
a_scale_thread_buf(I0)(Number<a_scale_offset>{}) =
|
||||
a_scale_thread_buf_copy[Number<0>{}];
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
make_multi_index(-MPerBlock, ScalesPerKBlockSize));
|
||||
|
||||
// Prefetch b_scales to buf 0
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
|
||||
auto b_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy);
|
||||
|
||||
b_scale_thread_bufs(I0)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy[Number<0>{}];
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
|
||||
// restore col id and advance to the next set of scales
|
||||
// NWaves * NPerXDL * NRepeat == NPerBlock
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Local prefill A1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
|
||||
|
||||
// Global prefetch A2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
// Prefetch a_scales to buf 1
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
|
||||
auto a_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
|
||||
a_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf_copy);
|
||||
|
||||
a_scale_thread_buf(I1)(Number<a_scale_offset>{}) =
|
||||
a_scale_thread_buf_copy[Number<0>{}];
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
make_multi_index(-MPerBlock, ScalesPerKBlockSize));
|
||||
|
||||
// Prefetch b_scales to buf 1
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
|
||||
auto b_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy);
|
||||
|
||||
b_scale_thread_bufs(I1)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy[Number<0>{}];
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
|
||||
// Local prefetch A1
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step = k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
// loop over k with the step KPerBlock
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(local_read_buf));
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[mfma_reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
|
||||
|
||||
static_assert(
|
||||
0 < ScalesPerXdlopsRunPerThread,
|
||||
"Must have at least one scale per Xdlops per Thread.");
|
||||
|
||||
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread>
|
||||
a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
|
||||
b_scale_thread_vec;
|
||||
|
||||
// Pack scale_thread_buf into scale_thread_vec
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs[mfma_reg_buf]
|
||||
[Number<a_scale_offset + s>{}];
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs[mfma_reg_buf]
|
||||
[Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
APackedSize>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
BPackedSize>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// a thread copy
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step =
|
||||
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}(
|
||||
[&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step + chunk * KThreadChunk *
|
||||
xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Prefetch a_scales
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0, I0),
|
||||
a_scale_thread_bufs(mfma_reg_buf));
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(0, ScalesPerKBlockSize, 0));
|
||||
|
||||
// Prefetch b_scales
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
|
||||
auto b_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy);
|
||||
|
||||
b_scale_thread_bufs(mfma_reg_buf)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy[Number<0>{}];
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
};
|
||||
|
||||
LoopFunc(I0, I1);
|
||||
LoopFunc(I1, I0);
|
||||
|
||||
i += 2;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I1));
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
|
||||
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
|
||||
|
||||
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
|
||||
|
||||
// Pack b_scale_thread_buf into b_scale_thread_vec
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs[I0][Number<a_scale_offset + s>{}];
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs[I0][Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// a thread copy
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step =
|
||||
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
|
||||
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
|
||||
|
||||
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
|
||||
|
||||
// Pack b_scale_thread_buf into b_scale_thread_vec
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs[I1][Number<a_scale_offset + s>{}];
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs[I1][Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
|
||||
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
|
||||
|
||||
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
|
||||
|
||||
// Pack b_scale_thread_buf into b_scale_thread_vec
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs[I0][Number<a_scale_offset + s>{}];
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs[I0][Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: make this field protected when a_scale_thread_copy_ is moved
|
||||
// here
|
||||
static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
|
||||
|
||||
// Is used to copy data from a_scale_grid to a_scale_thread
|
||||
static constexpr auto a_scale_thread_desc_copy =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
|
||||
|
||||
// TODO: make this field protected when b_scale_thread_copy_ is moved
|
||||
// here
|
||||
static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<NRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
|
||||
|
||||
// Is used to copy data from b_scale_grid to b_scale_thread_buf
|
||||
static constexpr auto b_scale_thread_desc_copy =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
|
||||
|
||||
protected:
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
|
||||
using Base::a_thread_copy_;
|
||||
using Base::a_thread_desc_;
|
||||
using Base::b_thread_copy_;
|
||||
// using Base::b_thread_desc_;
|
||||
using Base::c_thread_desc_;
|
||||
|
||||
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
File diff suppressed because it is too large
Load Diff
@@ -8,6 +8,7 @@
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp"
|
||||
@@ -171,26 +172,54 @@ constexpr auto BlockGemmBPreshufflePipeline_Selector()
|
||||
static_assert(MRepeat >= 4, "MRepeat should at least be 4 in BlockGemmPipelineVersion::v3");
|
||||
if constexpr(std::is_same<ADataType, BDataType>::value)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
if constexpr(GUFusion)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<
|
||||
BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
return BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -0,0 +1,123 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp"
|
||||
namespace ck {
|
||||
|
||||
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSche,
|
||||
index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MScaleBlock,
|
||||
index_t NScaleBlock,
|
||||
index_t KScaleBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack>
|
||||
constexpr auto BlockGemmBlockScaleBPreshufflePipeline_Selector()
|
||||
{
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1<
|
||||
BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MScaleBlock,
|
||||
NScaleBlock,
|
||||
KScaleBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
#if 0
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v2<
|
||||
BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
#endif
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
static_assert(MRepeat >= 4, "MRepeat should at least be 4 in BlockGemmPipelineVersion::v3");
|
||||
return BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<
|
||||
BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MScaleBlock,
|
||||
NScaleBlock,
|
||||
KScaleBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,864 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Compute optimized pipeline
|
||||
// GlobalPrefetchStages: 2
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MScaleBlock,
|
||||
index_t NScaleBlock,
|
||||
index_t KScaleBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPacks>
|
||||
struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MScaleBlock,
|
||||
index_t NScaleBlock,
|
||||
index_t KScaleBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack
|
||||
// ,bool TransposeC //disable transposec right now...
|
||||
>
|
||||
struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MScaleBlock,
|
||||
NScaleBlock,
|
||||
KScaleBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
true>
|
||||
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
true>;
|
||||
using Base::A_K1;
|
||||
using Base::B_K1;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KGroup;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
using typename Base::HotLoopInstList;
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::CalculateCThreadOriginDataIndex;
|
||||
using Base::CalculateCThreadOriginDataIndex8D;
|
||||
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetCThreadBuffer;
|
||||
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
|
||||
using Base::MWaves;
|
||||
using Base::NWaves;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 2;
|
||||
|
||||
template <typename TileDesc_M0_M1_M2_K>
|
||||
__host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
|
||||
{
|
||||
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack / KGroup;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat * KGroup;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
TileDesc_M0_M1_M2_K{},
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<M0>{}),
|
||||
make_pass_through_transform(Number<M1>{}),
|
||||
make_pass_through_transform(Number<M2>{}),
|
||||
make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
|
||||
}
|
||||
|
||||
static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
|
||||
MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);
|
||||
|
||||
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
|
||||
}
|
||||
|
||||
__device__ static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves;
|
||||
|
||||
// B global
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
});
|
||||
|
||||
// A global
|
||||
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
});
|
||||
|
||||
// A local
|
||||
static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
|
||||
});
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
int NumKBlockPerScale,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CScaleThreadDesc,
|
||||
typename CThreadBuffer,
|
||||
typename AScaleGridBuffer,
|
||||
typename AScaleGridDesc,
|
||||
typename AScaleThreadDesc,
|
||||
typename AScaleThreadTransfer,
|
||||
typename AScaleThreadTransferStep,
|
||||
typename BScaleGridBuffer,
|
||||
typename BScaleGridDesc,
|
||||
typename BScaleThreadDesc,
|
||||
typename BScaleThreadTransfer,
|
||||
typename BScaleThreadTransferStep>
|
||||
__device__ void Run(
|
||||
// ABlockCopy
|
||||
const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
// BBlockCopy
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
// CThread
|
||||
const CScaleThreadDesc& c_scale_thread_desc,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
// AScaleThreadCopy
|
||||
const AScaleGridDesc& a_scale_grid_desc,
|
||||
const AScaleThreadDesc& a_scale_thread_desc,
|
||||
AScaleThreadTransfer& a_scale_thread_copy,
|
||||
const AScaleGridBuffer& a_scale_grid_buf,
|
||||
const AScaleThreadTransferStep& a_scale_thread_copy_step,
|
||||
// BScaleThreadCopy
|
||||
const BScaleGridDesc& b_scale_grid_desc,
|
||||
const BScaleThreadDesc& b_scale_thread_desc,
|
||||
BScaleThreadTransfer& b_scale_thread_copy,
|
||||
const BScaleGridBuffer& b_scale_grid_buf,
|
||||
const BScaleThreadTransferStep& b_scale_thread_copy_step,
|
||||
// num_loop
|
||||
index_t num_loop) const
|
||||
{
|
||||
ignore = b_block_desc;
|
||||
ignore = b_block_buf;
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
|
||||
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
|
||||
|
||||
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
a_scale_thread_desc.GetElementSpaceSize());
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
b_scale_thread_desc.GetElementSpaceSize());
|
||||
auto c_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
c_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
// Global prefetch A1 B1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I0));
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{});
|
||||
constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{});
|
||||
constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{});
|
||||
|
||||
static_for<0, num_scale_m_block, 1>{}([&](auto m0) {
|
||||
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
|
||||
constexpr index_t c_offset =
|
||||
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
|
||||
constexpr index_t a_offset =
|
||||
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
|
||||
constexpr index_t b_offset =
|
||||
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
|
||||
|
||||
c_scale_thread_buf(Number<c_offset>{}) =
|
||||
a_scale_thread_buf[Number<a_offset>{}] *
|
||||
b_scale_thread_buf[Number<b_offset>{}];
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Local prefill A1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
|
||||
|
||||
// Global prefetch A2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
AccDataType,
|
||||
1,
|
||||
xdlops_gemm.GetRegSizePerXdlops(),
|
||||
true>
|
||||
c_thread_buf_per_scale;
|
||||
|
||||
// Local prefetch A1
|
||||
block_sync_lds();
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(local_read_buf));
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
vector_type<AccDataType, 2> c_scale_thread_vec;
|
||||
constexpr index_t cscale_offset =
|
||||
CScaleThreadDesc{}.CalculateOffset(
|
||||
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
|
||||
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
|
||||
c_scale_thread_buf[Number<cscale_offset>{}];
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
|
||||
c_scale_thread_buf[Number<cscale_offset>{}];
|
||||
|
||||
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0,
|
||||
I0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block +
|
||||
k0,
|
||||
I0,
|
||||
ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[mfma_reg_buf][Number<
|
||||
b_thread_desc_.CalculateOffset(make_tuple(
|
||||
n0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
|
||||
});
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}(
|
||||
[&](auto t) {
|
||||
using pk_fma_type =
|
||||
typename vector_type<AccDataType, 2>::type;
|
||||
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()(t) =
|
||||
__builtin_elementwise_fma(
|
||||
c_thread_buf_per_scale
|
||||
.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<pk_fma_type>()[t],
|
||||
c_scale_thread_vec
|
||||
.template AsType<pk_fma_type>()[Number<0>{}],
|
||||
c_thread_buf
|
||||
.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()[t]);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
|
||||
constexpr index_t c_offset =
|
||||
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
|
||||
constexpr index_t a_offset =
|
||||
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
|
||||
constexpr index_t b_offset =
|
||||
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
|
||||
|
||||
c_scale_thread_buf(Number<c_offset>{}) =
|
||||
a_scale_thread_buf[Number<a_offset>{}] *
|
||||
b_scale_thread_buf[Number<b_offset>{}];
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
b_scale_thread_copy_step);
|
||||
};
|
||||
|
||||
LoopFunc(I0, I1);
|
||||
LoopFunc(I1, I0);
|
||||
|
||||
i += 2;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I1));
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
vector_type<AccDataType, 2> c_scale_thread_vec;
|
||||
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
|
||||
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
|
||||
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
|
||||
c_scale_thread_buf[Number<cscale_offset>{}];
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
|
||||
c_scale_thread_buf[Number<cscale_offset>{}];
|
||||
|
||||
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0,
|
||||
I0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
I0,
|
||||
ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
|
||||
});
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
|
||||
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
|
||||
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<pk_fma_type>()[t],
|
||||
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()[t]);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
|
||||
constexpr index_t c_offset =
|
||||
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
|
||||
constexpr index_t a_offset =
|
||||
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
|
||||
constexpr index_t b_offset =
|
||||
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
|
||||
|
||||
c_scale_thread_buf(Number<c_offset>{}) =
|
||||
a_scale_thread_buf[Number<a_offset>{}] *
|
||||
b_scale_thread_buf[Number<b_offset>{}];
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
vector_type<AccDataType, 2> c_scale_thread_vec;
|
||||
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
|
||||
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
|
||||
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
|
||||
c_scale_thread_buf[Number<cscale_offset>{}];
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
|
||||
c_scale_thread_buf[Number<cscale_offset>{}];
|
||||
|
||||
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0,
|
||||
I0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
I0,
|
||||
ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
|
||||
});
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
|
||||
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
|
||||
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<pk_fma_type>()[t],
|
||||
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()[t]);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
vector_type<AccDataType, 2> c_scale_thread_vec;
|
||||
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
|
||||
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
|
||||
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
|
||||
c_scale_thread_buf[Number<cscale_offset>{}];
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
|
||||
c_scale_thread_buf[Number<cscale_offset>{}];
|
||||
|
||||
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0,
|
||||
I0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
I0,
|
||||
ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
|
||||
});
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
|
||||
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
|
||||
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<pk_fma_type>()[t],
|
||||
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()[t]);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
// MRepeat MWave MLane KRepeat KLane KPack
|
||||
// KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack
|
||||
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, I1, I1, Number<KRepeat>{}, I1, Number<KPack>{}));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
|
||||
ComputeDataType,
|
||||
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
|
||||
AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()};
|
||||
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
|
||||
|
||||
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
|
||||
|
||||
using Base::c_thread_desc_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,186 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp"
|
||||
namespace ck {
|
||||
|
||||
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSche,
|
||||
index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MScaleBlock,
|
||||
index_t NScaleBlock,
|
||||
index_t KScaleBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
bool GUFusion = false>
|
||||
constexpr auto BlockGemmBlockMoeScaleBPreshufflePipeline_Selector()
|
||||
{
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if constexpr(GUFusion)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1<
|
||||
BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MScaleBlock,
|
||||
NScaleBlock,
|
||||
KScaleBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
|
||||
BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MScaleBlock,
|
||||
NScaleBlock,
|
||||
KScaleBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
}
|
||||
#if 0
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v2<
|
||||
BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
#endif
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
static_assert(MRepeat >= 4, "MRepeat should at least be 4 in BlockGemmPipelineVersion::v3");
|
||||
if constexpr(GUFusion)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
|
||||
BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MScaleBlock,
|
||||
NScaleBlock,
|
||||
KScaleBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
|
||||
BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MScaleBlock,
|
||||
NScaleBlock,
|
||||
KScaleBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,854 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Compute optimized pipeline
|
||||
// GlobalPrefetchStages: 2
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MScaleBlock,
|
||||
index_t NScaleBlock,
|
||||
index_t KScaleBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPacks>
|
||||
struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MScaleBlock,
|
||||
index_t NScaleBlock,
|
||||
index_t KScaleBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack
|
||||
// ,bool TransposeC //disable transposec right now...
|
||||
>
|
||||
struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MScaleBlock,
|
||||
NScaleBlock,
|
||||
KScaleBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack> : BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
true>
|
||||
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
true>;
|
||||
using Base::A_K1;
|
||||
using Base::B_K1;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KGroup;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
using typename Base::HotLoopInstList;
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::CalculateCThreadOriginDataIndex;
|
||||
using Base::CalculateCThreadOriginDataIndex8D;
|
||||
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetCThreadBuffer;
|
||||
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
|
||||
using Base::MWaves;
|
||||
using Base::NWaves;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 2;
|
||||
|
||||
template <typename TileDesc_M0_M1_M2_K>
|
||||
__host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
|
||||
{
|
||||
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack / KGroup;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat * KGroup;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
TileDesc_M0_M1_M2_K{},
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<M0>{}),
|
||||
make_pass_through_transform(Number<M1>{}),
|
||||
make_pass_through_transform(Number<M2>{}),
|
||||
make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
|
||||
}
|
||||
|
||||
static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
|
||||
MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);
|
||||
|
||||
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
|
||||
}
|
||||
|
||||
__device__ static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves;
|
||||
|
||||
// B global
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
});
|
||||
|
||||
// A global
|
||||
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
});
|
||||
|
||||
// A local
|
||||
static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
|
||||
});
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
int NumKBlockPerScale,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CScaleThreadDesc,
|
||||
typename CThreadBuffer,
|
||||
typename AScaleGridBuffer,
|
||||
typename AScaleGridDesc,
|
||||
typename AScaleThreadDesc,
|
||||
typename AScaleThreadTransfer,
|
||||
typename AScaleThreadTransferStep,
|
||||
typename BScaleGridBuffer,
|
||||
typename BScaleGridDesc,
|
||||
typename BScaleThreadDesc,
|
||||
typename BScaleThreadTransfer,
|
||||
typename BScaleThreadTransferStep>
|
||||
__device__ void Run(
|
||||
// ABlockCopy
|
||||
const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
// BBlockCopy
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
// CThread
|
||||
const CScaleThreadDesc& c_scale_thread_desc,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
// AScaleThreadCopy
|
||||
const AScaleGridDesc& a_scale_grid_desc,
|
||||
const AScaleThreadDesc& a_scale_thread_desc,
|
||||
AScaleThreadTransfer& a_scale_thread_copy,
|
||||
const AScaleGridBuffer& a_scale_grid_buf,
|
||||
const AScaleThreadTransferStep& a_scale_thread_copy_step,
|
||||
// BScaleThreadCopy
|
||||
const BScaleGridDesc& b_scale_grid_desc,
|
||||
const BScaleThreadDesc& b_scale_thread_desc,
|
||||
BScaleThreadTransfer& b_scale_thread_copy,
|
||||
const BScaleGridBuffer& b_scale_grid_buf,
|
||||
const BScaleThreadTransferStep& b_scale_thread_copy_step,
|
||||
// num_loop
|
||||
index_t num_loop) const
|
||||
{
|
||||
ignore = b_block_desc;
|
||||
ignore = b_block_buf;
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
|
||||
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
|
||||
|
||||
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
a_scale_thread_desc.GetElementSpaceSize());
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
b_scale_thread_desc.GetElementSpaceSize());
|
||||
auto c_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
c_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
// Global prefetch A1 B1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I0));
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf);
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<0>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{});
|
||||
constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{});
|
||||
constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{});
|
||||
|
||||
static_for<0, num_scale_m_block, 1>{}([&](auto m0) {
|
||||
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
|
||||
constexpr index_t c_offset =
|
||||
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
|
||||
constexpr index_t a_offset =
|
||||
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
|
||||
constexpr index_t b_offset =
|
||||
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
|
||||
|
||||
c_scale_thread_buf(Number<c_offset>{}) =
|
||||
a_scale_thread_buf[Number<a_offset>{}] *
|
||||
b_scale_thread_buf[Number<b_offset>{}];
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Local prefill A1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
|
||||
|
||||
// Global prefetch A2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf);
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<0>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
AccDataType,
|
||||
1,
|
||||
xdlops_gemm.GetRegSizePerXdlops(),
|
||||
true>
|
||||
c_thread_buf_per_scale;
|
||||
|
||||
// Local prefetch A1
|
||||
block_sync_lds();
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(local_read_buf));
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
vector_type<AccDataType, 2> c_scale_thread_vec;
|
||||
constexpr index_t cscale_offset =
|
||||
CScaleThreadDesc{}.CalculateOffset(
|
||||
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
|
||||
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
|
||||
c_scale_thread_buf[Number<cscale_offset>{}];
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
|
||||
c_scale_thread_buf[Number<cscale_offset>{}];
|
||||
|
||||
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0,
|
||||
I0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block +
|
||||
k0,
|
||||
I0,
|
||||
ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[mfma_reg_buf][Number<
|
||||
b_thread_desc_.CalculateOffset(make_tuple(
|
||||
n0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
|
||||
});
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}(
|
||||
[&](auto t) {
|
||||
using pk_fma_type =
|
||||
typename vector_type<AccDataType, 2>::type;
|
||||
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()(t) =
|
||||
__builtin_elementwise_fma(
|
||||
c_thread_buf_per_scale
|
||||
.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<pk_fma_type>()[t],
|
||||
c_scale_thread_vec
|
||||
.template AsType<pk_fma_type>()[Number<0>{}],
|
||||
c_thread_buf
|
||||
.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()[t]);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
|
||||
constexpr index_t c_offset =
|
||||
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
|
||||
constexpr index_t a_offset =
|
||||
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
|
||||
constexpr index_t b_offset =
|
||||
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
|
||||
|
||||
c_scale_thread_buf(Number<c_offset>{}) =
|
||||
a_scale_thread_buf[Number<a_offset>{}] *
|
||||
b_scale_thread_buf[Number<b_offset>{}];
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf);
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
b_scale_thread_copy_step);
|
||||
};
|
||||
|
||||
LoopFunc(I0, I1);
|
||||
LoopFunc(I1, I0);
|
||||
|
||||
i += 2;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I1));
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
vector_type<AccDataType, 2> c_scale_thread_vec;
|
||||
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
|
||||
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
|
||||
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
|
||||
c_scale_thread_buf[Number<cscale_offset>{}];
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
|
||||
c_scale_thread_buf[Number<cscale_offset>{}];
|
||||
|
||||
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0,
|
||||
I0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
I0,
|
||||
ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
|
||||
});
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
|
||||
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
|
||||
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<pk_fma_type>()[t],
|
||||
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()[t]);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
|
||||
constexpr index_t c_offset =
|
||||
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
|
||||
constexpr index_t a_offset =
|
||||
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
|
||||
constexpr index_t b_offset =
|
||||
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
|
||||
|
||||
c_scale_thread_buf(Number<c_offset>{}) =
|
||||
a_scale_thread_buf[Number<a_offset>{}] *
|
||||
b_scale_thread_buf[Number<b_offset>{}];
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
vector_type<AccDataType, 2> c_scale_thread_vec;
|
||||
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
|
||||
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
|
||||
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
|
||||
c_scale_thread_buf[Number<cscale_offset>{}];
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
|
||||
c_scale_thread_buf[Number<cscale_offset>{}];
|
||||
|
||||
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0,
|
||||
I0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
I0,
|
||||
ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
|
||||
});
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
|
||||
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
|
||||
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<pk_fma_type>()[t],
|
||||
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()[t]);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
vector_type<AccDataType, 2> c_scale_thread_vec;
|
||||
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
|
||||
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
|
||||
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
|
||||
c_scale_thread_buf[Number<cscale_offset>{}];
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
|
||||
c_scale_thread_buf[Number<cscale_offset>{}];
|
||||
|
||||
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0,
|
||||
I0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
I0,
|
||||
ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
|
||||
});
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
|
||||
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
|
||||
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<pk_fma_type>()[t],
|
||||
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()[t]);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
// MRepeat MWave MLane KRepeat KLane KPack
|
||||
// KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack
|
||||
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, I1, I1, Number<KRepeat>{}, I1, Number<KPack>{}));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
|
||||
ComputeDataType,
|
||||
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
|
||||
AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()};
|
||||
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
|
||||
|
||||
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
|
||||
|
||||
using Base::c_thread_desc_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,130 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp"
|
||||
|
||||
namespace ck {
|
||||
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSche,
|
||||
index_t ThreadBlockSize,
|
||||
index_t ScaleBlockSize,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename ComputeDataType, // TODO: remove this as in this pipeline ADataType and BDataType
|
||||
// must be used for compute
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
bool GUFusion = false>
|
||||
constexpr auto BlockGemmMXNBSPipeline_Selector()
|
||||
{
|
||||
|
||||
// Hardware MX GEMM pipeline
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if constexpr(GUFusion)
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
else
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v1<BlkGemmPipeSche,
|
||||
ThreadBlockSize,
|
||||
ScaleBlockSize,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if constexpr(GUFusion)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3<
|
||||
BlkGemmPipeSche,
|
||||
ThreadBlockSize,
|
||||
ScaleBlockSize,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlkGemmPipeSche,
|
||||
ThreadBlockSize,
|
||||
ScaleBlockSize,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,664 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Naive pipeline with lowest resource request per WGP
|
||||
// GlobalPrefetchStages: 1
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 0
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t ThreadBlockSize,
|
||||
index_t ScaleBlockSize,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat, // MXdlPerWave
|
||||
index_t NRepeat, // NXdlPerWave
|
||||
index_t KPack>
|
||||
struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v1
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t ThreadBlockSize,
|
||||
index_t ScaleBlockSize,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat, // MXdlPerWave
|
||||
index_t NRepeat, // NXdlPerWave
|
||||
index_t KPack>
|
||||
struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
ThreadBlockSize,
|
||||
ScaleBlockSize,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
: BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
|
||||
{
|
||||
|
||||
using Base = BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KRepeat;
|
||||
using Base::MWaves;
|
||||
using Base::NWaves;
|
||||
using Base::WaveSize;
|
||||
using Base::xdlops_gemm;
|
||||
using typename Base::HotLoopInstList;
|
||||
|
||||
using Base::CalculateCThreadOriginDataIndex;
|
||||
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetCThreadBuffer;
|
||||
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetWaveIdx;
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_m3_k;
|
||||
using Base::b_block_desc_n0_n1_n2_n3_k;
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::APackedSize;
|
||||
using Base::BMmaKStride;
|
||||
using Base::BPackedSize;
|
||||
using Base::KThreadChunk;
|
||||
|
||||
using Base::KXdlPack;
|
||||
using Base::MXdlPack;
|
||||
using Base::NXdlPack;
|
||||
|
||||
using AccType = typename Base::AccType;
|
||||
using Tuple5 = typename Base::Tuple5;
|
||||
using ComputeTypeA = typename Base::ComputeTypeA;
|
||||
using ComputeTypeB = typename Base::ComputeTypeB;
|
||||
|
||||
static constexpr index_t PrefetchStages = 1;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
|
||||
static constexpr auto ScalesPerKBlockSize =
|
||||
KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
|
||||
|
||||
//> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
|
||||
static constexpr auto ScalesPerXdlopsRun =
|
||||
(APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
|
||||
|
||||
//> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
|
||||
static constexpr auto ScalesPerXdlopsRunPerThread =
|
||||
ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
|
||||
|
||||
using mx_scale_t = e8m0_bexp_t;
|
||||
static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
|
||||
static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
|
||||
static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
|
||||
"A scale pack data type too large!");
|
||||
static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
|
||||
"B scale pack data type too large!");
|
||||
static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a;
|
||||
static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b;
|
||||
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer,
|
||||
typename AScaleGridBuffer,
|
||||
typename AScaleGridDesc,
|
||||
typename AScaleThreadTransfer,
|
||||
typename BScaleGridBuffer,
|
||||
typename BScaleGridDesc,
|
||||
typename BScaleThreadTransfer>
|
||||
__device__ void Run(
|
||||
// ABlockCopy
|
||||
const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
// BBlockCopy
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
// CThread
|
||||
CThreadBuffer& c_thread_buf,
|
||||
// A and B scales
|
||||
const AScaleGridDesc& a_scale_grid_desc,
|
||||
AScaleThreadTransfer& a_scale_thread_copy,
|
||||
const AScaleGridBuffer& a_scale_grid_buf,
|
||||
const BScaleGridDesc& b_scale_grid_desc,
|
||||
BScaleThreadTransfer& b_scale_thread_copy,
|
||||
const BScaleGridBuffer& b_scale_grid_buf,
|
||||
index_t num_loop) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
|
||||
a_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
// Global prefetch 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Prefetch a_scales
|
||||
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, k0, I0),
|
||||
a_scale_thread_buf);
|
||||
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
make_multi_index(0, I1, 0));
|
||||
});
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
|
||||
});
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
|
||||
|
||||
// Prefetch b_scales
|
||||
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(n0, k0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(0, I1, 0));
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
|
||||
});
|
||||
|
||||
// restore col id and advance to the next set of scales
|
||||
// NWaves * NPerXDL * NRepeat == NPerBlock
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
|
||||
|
||||
// Local prefill 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
// loop over k with the step KPerBlock
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
// -------------------------------------------------------------------------------------------
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
|
||||
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
|
||||
[&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step +
|
||||
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
|
||||
make_tuple(Number<m0 / MXdlPack>{},
|
||||
I0,
|
||||
Number<m0 % MXdlPack>{},
|
||||
I0,
|
||||
Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(Number<m0 / MXdlPack>{},
|
||||
I0,
|
||||
Number<m0 % MXdlPack>{},
|
||||
k,
|
||||
Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// read block data in chunks to assemble correct thread vectors
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
|
||||
[&](auto chunk) {
|
||||
constexpr auto b_k_step_chunk =
|
||||
k_step +
|
||||
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
|
||||
make_tuple(Number<n0 / NXdlPack>{},
|
||||
I0,
|
||||
Number<n0 % NXdlPack>{},
|
||||
I0,
|
||||
Number<b_k_step_chunk>{}),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(Number<n0 / NXdlPack>{},
|
||||
I0,
|
||||
Number<n0 % NXdlPack>{},
|
||||
k,
|
||||
Number<chunk * KThreadChunk>{}),
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
|
||||
|
||||
static_assert(0 < ScalesPerXdlopsRunPerThread,
|
||||
"Must have at least one scale per Xdlops "
|
||||
"per Thread.");
|
||||
|
||||
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
|
||||
|
||||
// Pack scale_thread_buf into scale_thread_vec
|
||||
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_buf[Number<a_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_buf[Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
|
||||
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
APackedSize>::type;
|
||||
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
BPackedSize>::type;
|
||||
|
||||
using mfma_scale_input_type_a =
|
||||
typename vector_type<AScaleDataType,
|
||||
a_scale_thread_vec_size>::type;
|
||||
using mfma_scale_input_type_b =
|
||||
typename vector_type<BScaleDataType,
|
||||
b_scale_thread_vec_size>::type;
|
||||
|
||||
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, n0, imxdl, inxdl, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
|
||||
ikxdl * NXdlPack + inxdl>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec
|
||||
.template AsType<mfma_scale_input_type_a>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec
|
||||
.template AsType<mfma_scale_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(
|
||||
Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Prefetch a_scales
|
||||
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, k0, I0),
|
||||
a_scale_thread_buf);
|
||||
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
make_multi_index(0, I1, 0));
|
||||
});
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
|
||||
});
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
|
||||
|
||||
// Prefetch b_scales
|
||||
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(n0, k0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(0, I1, 0));
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
|
||||
});
|
||||
|
||||
// restore col id and advance to the next set of scales
|
||||
// NWaves * NPerXDL * NRepeat == NPerBlock
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
|
||||
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
i += 1;
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Full)
|
||||
{
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
|
||||
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
|
||||
[&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step +
|
||||
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
|
||||
make_tuple(Number<m0 / MXdlPack>{},
|
||||
I0,
|
||||
Number<m0 % MXdlPack>{},
|
||||
I0,
|
||||
Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(Number<m0 / MXdlPack>{},
|
||||
I0,
|
||||
Number<m0 % MXdlPack>{},
|
||||
k,
|
||||
Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// read block data in chunks to assemble correct thread vectors
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
|
||||
[&](auto chunk) {
|
||||
constexpr auto b_k_step_chunk =
|
||||
k_step +
|
||||
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
|
||||
make_tuple(Number<n0 / NXdlPack>{},
|
||||
I0,
|
||||
Number<n0 % NXdlPack>{},
|
||||
I0,
|
||||
Number<b_k_step_chunk>{}),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(Number<n0 / NXdlPack>{},
|
||||
I0,
|
||||
Number<n0 % NXdlPack>{},
|
||||
k,
|
||||
Number<chunk * KThreadChunk>{}),
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
|
||||
|
||||
static_assert(0 < ScalesPerXdlopsRunPerThread,
|
||||
"Must have at least one scale per Xdlops "
|
||||
"per Thread.");
|
||||
|
||||
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
|
||||
|
||||
// Pack scale_thread_buf into scale_thread_vec
|
||||
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_buf[Number<a_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_buf[Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
|
||||
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
APackedSize>::type;
|
||||
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
BPackedSize>::type;
|
||||
|
||||
using mfma_scale_input_type_a =
|
||||
typename vector_type<AScaleDataType,
|
||||
a_scale_thread_vec_size>::type;
|
||||
using mfma_scale_input_type_b =
|
||||
typename vector_type<BScaleDataType,
|
||||
b_scale_thread_vec_size>::type;
|
||||
|
||||
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, n0, imxdl, inxdl, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
|
||||
ikxdl * NXdlPack + inxdl>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec
|
||||
.template AsType<mfma_scale_input_type_a>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec
|
||||
.template AsType<mfma_scale_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: make this field protected when a_scale_thread_copy_ is moved
|
||||
// here
|
||||
static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat / MXdlPack>{},
|
||||
Number<KRepeat / KXdlPack>{},
|
||||
Number<ScalesPerXdlopsRunPerThread * a_scale_thread_vec_size>{}));
|
||||
|
||||
// TODO: make this field protected when b_scale_thread_copy_ is moved
|
||||
// here
|
||||
static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<NRepeat / NXdlPack>{},
|
||||
Number<KRepeat / KXdlPack>{},
|
||||
Number<ScalesPerXdlopsRunPerThread * b_scale_thread_vec_size>{}));
|
||||
|
||||
protected:
|
||||
using Base::a_thread_copy_;
|
||||
using Base::a_thread_desc_;
|
||||
using Base::b_thread_copy_;
|
||||
using Base::b_thread_desc_;
|
||||
using Base::c_thread_desc_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -532,6 +532,9 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
});
|
||||
});
|
||||
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
@@ -560,8 +563,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
b_scale_thread_buf);
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
i += 1;
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user