From b40267f7a67aca9adf7fe679cd846de2f32a7971 Mon Sep 17 00:00:00 2001
From: Yi DING
Date: Thu, 12 Jun 2025 09:25:59 +0800
Subject: [PATCH] Add MoE & FP8 Blockscale WP Kernels for GFX950 (#2297)

* [fix] align v3 gufusion pipeline
* fix device kernel selection.
* Add .co direct asm support by CK_USE_ASM_MOE_STAGE2_BLOCKSCALE
* experimental optimization for scale load in blkscale gemm
* Add asm for no-loop v3_128x128x128
* fix bugs
* tune fp8 example
* Update v1_128x128x128 to 2x2 instead of 4x1
* wip
* add warmup to asm launch
* wip2
* 16x16 function merged to moe
* temp save, a performant version.
* wip3
* Update .co binary to 16x16
* 16x16x128 correct; 64x64x128 failed
* update
* use mem_op::set when topk=1
* add mx fp8 b_preshuffle support, function not yet tested.
* Split the fp4 target. Fix the known bugs. 128x128x128 sanity checked; remove prints
* some fixes
* fix update
* remove some unnecessary hacks; enable 256x256x256 tilesize
* update for function debug
* Add pipeline v3. Has some runtime issues and register spill
* Fix pipe v3 correctness issue
* remove unnecessary hacks
* clang format
* fix a bug
* fix the bug, functional test passed
* tempsave; buggy when passing 4 e8m0 to scaled mfma
* added fp4_bpreshuffle example, build failures
* fixed some bugs
* implement shuffled scale mxfp4gemm, blocker: opsel has no effect
* hotfix
* fix bugs, build passed
* (M, N, K)=(128, 128, 128) function failed.
* temp save for gemm1. Function not ready
* fix compile error. Gemm2 pass. Gemm1 WIP
* fix bug for an lds read
* update moe
* Compile pass. Gemm1 function WIP
* update moe
* fix fp8; fix even/odd
* tempsave
* update moe
* Revert "update" This reverts commit c7d79dcb672616d9bc0fd9958f714fc80e7c84fd.
* Revert "use mem_op::set when topk=1" This reverts commit 8c7772860735001a51421e7b6d0a28f6676d6c40.
* Add v3 128x128x128_4x4_16x16.co for gfx950
* temp cmake flag suppression for aiter test
* add code for mxfp4 gemm, blockscale not supported yet
* gemm1 up-only pass. GU WIP
* function pass with inline asm hack
* revert unexpected file change
* updated and build passed
* update CE elementOP
* added code for debug
* Gemm1 GUFusion function pass. Perf WIP
* Fix fp8/bf8; remove duplicated code
* disable the scheduler in v3; bring it back when compiler feature ready.
* update moe v1 pipeline
* Add gemm1 v1 32x128x128
* remove schedule barrier
* updated
* Fix fp8/bf8 B-row
* mfma using asm, device result correct, host result needs checking
* gemm1 v3 64x128x128 debug
* fix cpu ref
* a/b thread_desc stride fix
* Use random scale for init1
* 16x16x128 input size blockscale function passed
* fix blockscale gemm bug
* tempsave. Almost all instances passed.
* v1 fix for mi350.
* temp save
* debug save
* update debug
* fix the bug, 128x128x256 tile function passed
* v3
* rename moe block selector and pipeline
* Add gemm1 v1
* Add gemm1 v1 to selector
* added mx moe block v3 support, function passed
* compile error fix
* Improve the pipeline
* Pack e8m0 as int32_t
* v1 compile pass. Function not ready
* debug synchronization issue over different GPU/ROCm
* minor fix
* Add profiler filter
* Add f4 ckProfiler
* Fix example compile error
* Add f4 profiler examples
* tempsave
* v1 function pass.
* v3 function pass
* align file and function name
* mx_moe_fp4 ready for aiter with clang-format.
* modify the way we represent fp4
* generalize the pipeline scheduling.
* init moe mx f4 scale shuffle
* CMakeLists: disable compiler-bound flags
* mx_fp4 default parameter change
* Moe blockscale gemm1&gemm2 asm support for aiter.
  Suppress cmake flag until the new compiler is available.
* update code
* tempsave; modify the way we represent fp4
* generalize the pipeline scheduling.
* Add gemm1 gfx942 .co support
* updated code, build passed.
* Update gemm2 asm with latest compiler flag
* Fix mx f4 ckProfiler
* Fix blockwise gemm mx v1
* lds conflict free + buffer load lds
* Add gemm2 v3 64x128x128
* fix a/b scale loading bugs; a/b scales now load correctly
* Add gemm2 v3 64x128x128
* commit with debug info
* fix fp4 profiler
* Add mx fp4 pipeline v1 instances
* Fix v2 topk_weight calculation. Add silu asm.
* v2 topk_weight WIP
* init mx fp4 B no preshuffle version
* tempsave. compile pass, function wrong
* enable fp4 moe no weight preshuffle, function pass
* update the TFlops calculation in the example
* Add gemm2 64x128x128 asm. Fix BF16 ref.
* fix 2 typos in fp4_preshuffle
* Better kernel selection in device classes
* correct preShuffleBuffer: we should use packed k to do the shuffle.
* lds conflict free + buffer load lds
* optimize offset math in dma
* Fix fp4 ckProfiler
* Fix MX MFMA tests
* fix f4 pipeline issues
* gemm1 func pass
* update mx moe gemm1_bns tile size to 64x128x256
* update mx moe gemm1 gemm2 TFlops and BW calculation
* fix typo
* temp save
* Fix example_gemm_mx build
* rename the block pipeline
* correct a typo in tail
* Add rotating to mx examples
* fix the correctness issue
* Fix v1; use M padding
* Add NT flag to B/BScale buffer
* Merge gemm_mx_common.hpp
* temp save, 4.4~4.5
* Fix 'Merge gemm_mx_common.hpp'
* refactor the pipeline
* Pad the M for scale buffer unconditionally
* update MX moe GEMM1 hot-loop scheduling
* change the gemm1 tile from 64x128x128 to 128x64x128
* Unconditional Ascale padding
* Pad shuffled a scale only
* pad ascale
* add vmcnt guard for async copy
* Profiler add f4 wp
* Merge preshuffle device
* Add more fp4 wp instances
* Fix do_weight in gemm1. Fix cshuffle_datatype. Clang-format
* Clang-format after 2 merges
* Remove rocm6.3 workaround flags and macro
* Fix fp8 config
* Fix bf8 config
* flag and barrier fix for compiler branch MainOpSelV3
* Add fp8 profiler instances
* Remove debug infos; Enable flags for blockscale f8
* No asm ver. for merging moe blockscale fp8 into mainline
* update the flag name for f8blockscale
* recover example
* fix performance bug of bpreshuffle f8 gemm
* clang format, remove single rate mfma restriction for f8
* remove single rate mfma restriction for f8 blockscale gemm
* Fix moe blockscale gemm1 barrier 0x800 for new compiler
* add pipeline v1 for MOE Gemm2
* Use v1 pipeline for example_moe_gemm2_xdl_mx_fp4_bns
* Fix OOB; add MB96 instances
* remove unnecessary files
* fix the cmake issue
* Enable splitk for mxfp4; clang format;
* Generate random tensor values with multiple threads
* Use packed_size_v for A/BPackedSize
* Fix warning
* Fix target_compile_options for disabled target on gfx942
* fix moe pki4 on gfx950
* document the kGroup definition
* Fix ThreadwiseTensorSliceTransfer_v4::Run (Fuse scale)
* Refactor thread_copy_lds_direct_load; fix gfx942 direct lds load example; fix f16_pki4 example
* Fix unknown compiler flag
* fix two failing examples.
* fix some failing tile sizes in gfx950 universal gemm.
fix test_gemm_fp16 * workaround fix for test_gemm_f32; * We have very limited support for lds direct load if input matrix is not K major * fix test_gemm_splitk; * Fix compile for mx_mfma_op * add mfma selection logic for multipled_v3 * Clean up * Fix device gemm mx link error * improve the global atomic pattern * Revert unnecessary copyright updates * restore minimum_occupancy logic * Avoid data race in moe gemm2 ref * Build fp8 gemm_multiply_multiply and moe only on gfx94/95 * update the instance in device_mx_gemm * Resolve comments * Copyright 2025 * Remove unused code * fix library linking issue --------- Co-authored-by: OscarXu Co-authored-by: lalala-sh Co-authored-by: mtgu0705 Co-authored-by: aska-0096 Co-authored-by: Your Name Co-authored-by: valarLip <340077269@qq.com> Co-authored-by: feifei14119 Co-authored-by: Lin, Qun Co-authored-by: Andriy Roshchenko Co-authored-by: joye Co-authored-by: asleepzzz [ROCm/composable_kernel commit: 37554c31e8e1cd3732bb6e51d3ea1c39cbe66b0e] --- cmake/EnableCompilerWarnings.cmake | 5 +- .../01_gemm/gemm_xdl_lds_direct_load_fp16.cpp | 4 +- .../65_gemm_multiply_multiply/CMakeLists.txt | 37 +- ...emm_multiply_multiply_xdl_fp8_ab_scale.cpp | 16 +- ...ultiply_xdl_fp8_blockscale_bpreshuffle.cpp | 372 +++ ..._multiply_multiply_xdl_fp8_bpreshuffle.cpp | 10 +- .../moe_gemm1_xdl_fp8.cpp | 50 +- .../moe_gemm1_xdl_fp8_blockscale.cpp | 548 ++++ .../moe_gemm2_xdl_fp8.cpp | 29 +- .../moe_gemm2_xdl_fp8_blockscale.cpp | 541 ++++ example/67_gemm_microscaling/CMakeLists.txt | 30 +- .../67_gemm_microscaling/gemm_mx_common.hpp | 34 +- example/67_gemm_microscaling/gemm_mx_fp4.cpp | 2 - .../gemm_mx_fp4_bpreshuffle.cpp | 8 +- .../moe_gemm1_xdl_mx_fp4_bns.cpp | 545 ++++ .../moe_gemm2_xdl_mx_fp4_bns.cpp | 526 +++ example/CMakeLists.txt | 10 +- include/ck/library/utility/host_tensor.hpp | 68 + .../library/utility/host_tensor_generator.hpp | 12 + include/ck/library/utility/thread.hpp | 25 + ...dlops_b_preshuffle_gufusion_dequant_v1.hpp | 50 +- ...peline_xdlops_b_preshuffle_gufusion_v1.hpp | 51 +- ...peline_xdlops_b_preshuffle_gufusion_v3.hpp | 952 ++++++ ...xdlops_b_preshuffle_mx_moe_gufusion_v1.hpp | 919 ++++++ ...xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp | 1020 ++++++ ...ne_xdlops_b_preshuffle_mx_moe_selector.hpp | 155 + ...pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp | 813 +++++ ...pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp | 1032 ++++++ ..._pipeline_xdlops_b_preshuffle_selector.hpp | 69 +- ...dlops_blockscale_b_preshuffle_selector.hpp | 123 + ...line_xdlops_blockscale_b_preshuffle_v1.hpp | 864 +++++ ...line_xdlops_blockscale_b_preshuffle_v3.hpp | 1090 +++++++ ...oe_blockscale_b_preshuffle_gufusion_v1.hpp | 1036 ++++++ ...oe_blockscale_b_preshuffle_gufusion_v3.hpp | 1203 +++++++ ...s_moe_blockscale_b_preshuffle_selector.hpp | 186 ++ ..._xdlops_moe_blockscale_b_preshuffle_v1.hpp | 854 +++++ ..._xdlops_moe_blockscale_b_preshuffle_v3.hpp | 1070 +++++++ ...pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp | 1361 ++++++++ ...mm_pipeline_xdlops_mx_moe_nbs_selector.hpp | 130 + ...ise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp | 664 ++++ ...ise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp | 1126 +++++++ ...kwise_gemm_pipeline_xdlops_v3_ab_scale.hpp | 8 +- .../gpu/device/device_gemm_multiple_d.hpp | 48 +- .../device_gemm_multiple_d_ab_scale.hpp | 45 +- ...xdl_cshuffle_v3_blockscale_bpreshuffle.hpp | 507 +++ .../impl/device_moe_gemm_blockscale.hpp | 584 ++++ .../gpu/device/impl/device_moe_mx_gemm.hpp | 571 ++++ .../device/impl/device_moe_mx_gemm_bns.hpp | 540 ++++ 
..._gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp | 6 +- ...m_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp | 10 +- ...fle_v3_multi_d_blockscale_b_preshuffle.hpp | 2080 ++++++++++++ ...se_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp | 62 +- .../gpu/grid/gridwise_moe_gemm.hpp | 354 +- .../gpu/grid/gridwise_moe_gemm_blockscale.hpp | 2668 +++++++++++++++ .../gpu/grid/gridwise_moe_mx_gemm.hpp | 2652 +++++++++++++++ .../gpu/grid/gridwise_moe_mx_gemm_bns.hpp | 2849 +++++++++++++++++ .../threadwise_tensor_slice_transfer.hpp | 5 - .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 9 - include/ck/utility/amd_xdlops.hpp | 100 +- include/ck/utility/data_type.hpp | 11 - include/ck/utility/debug.hpp | 13 + include/ck/utility/dtype_vector.hpp | 14 +- include/ck/utility/functional2.hpp | 3 +- .../cpu/reference_moe_gemm1_blockscale.hpp | 280 ++ .../cpu/reference_moe_gemm2.hpp | 11 +- .../cpu/reference_moe_gemm2_blockscale.hpp | 248 ++ .../cpu/reference_moe_mx_gemm1.hpp | 264 ++ .../cpu/reference_moe_mx_gemm2.hpp | 238 ++ .../add_device_operation_instance.hpp | 17 +- .../gpu/gemm_blockscale_wp.hpp | 172 + .../gpu/gemm_blockscale_wp/CMakeLists.txt | 16 + ...wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp | 89 + ...k_mn_128_128_128_comp_default_instance.cpp | 38 + ..._mn_128_128_128_comp_kpadding_instance.cpp | 38 + ...mn_128_128_128_mem_v1_default_instance.cpp | 39 + ...n_128_128_128_mem_v1_kpadding_instance.cpp | 39 + ...evice_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn.hpp | 33 +- .../device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn.hpp | 3 +- ...device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp | 3 +- .../device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp | 3 +- .../profile_gemm_blockscale_wp_impl.hpp | 415 +++ .../include/profiler/profile_gemm_mx_impl.hpp | 20 +- profiler/src/CMakeLists.txt | 2 + profiler/src/profile_gemm_blockscale_wp.cpp | 184 ++ test/mx_mfma_op/mx_mfma_op.hpp | 8 +- 85 files changed, 32508 insertions(+), 431 deletions(-) create mode 100644 example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle.cpp create mode 100644 example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp create mode 100644 example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp create mode 100644 example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp create mode 100644 example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp create mode 100644 include/ck/library/utility/thread.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v1.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_selector.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp create mode 100644 
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_selector.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_selector.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp create mode 100644 library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp create mode 100644 library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2_blockscale.hpp create mode 100644 library/include/ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp create mode 100644 library/include/ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/gemm_blockscale_wp.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp create mode 100644 
profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp create mode 100644 profiler/src/profile_gemm_blockscale_wp.cpp diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index fb2b38d688..0c81f8df98 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -66,7 +66,8 @@ else() -Wunreachable-code -Wunused -Wno-reserved-identifier - -Werror + # Werror set outside by BUILD_DEV + # -Werror -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt @@ -108,7 +109,7 @@ else() endif() list(APPEND CMAKE_COMPILER_WARNINGS -Wno-missing-field-initializers - -Wno-deprecated-declarations + -Wno-error=deprecated-declarations ) endif() add_definitions(${CMAKE_COMPILER_WARNINGS}) diff --git a/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp b/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp index 62037f7740..26ea31f20b 100644 --- a/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp +++ b/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -38,7 +38,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>; // clang-format on #else // clang-format off diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index a58612cb5b..36f1860e4f 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -1,11 +1,20 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp) +add_example_executable(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp16_bpreshuffle gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp) add_example_executable(example_gemm_add_add_xdl_fp16 
gemm_add_add_xdl_fp16.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp) +set(EXAMPLE_COMPILE_OPTIONS) +# Open it when SGBPack branch landed on mainline +# list(APPEND EXAMPLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --schedmodel=0 -mllvm -misched=gcn-iterative-max-occupancy-experimental") +example_compile_options(example_gemm_multiply_multiply_xdl_fp8_ab_scale PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) +example_compile_options(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) +example_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp) add_example_executable(example_moe_gemm2_xdl_fp8 moe_gemm2_xdl_fp8.cpp) +add_example_executable(example_moe_gemm2_xdl_fp8_blockscale moe_gemm2_xdl_fp8_blockscale.cpp) +add_example_executable(example_moe_gemm1_xdl_fp8_blockscale moe_gemm1_xdl_fp8_blockscale.cpp) list(APPEND gpu_list gfx942 gfx950) set(target 0) @@ -19,14 +28,32 @@ foreach(gpu IN LISTS GPU_TARGETS) if(HAS_MAX_ILP_SCHEDULING_STRATEGY) list(APPEND EXAMPLE_COMPILE_OPTIONS -mllvm --amdgpu-enable-max-ilp-scheduling-strategy=1) endif() - target_compile_options(example_moe_gemm1_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) - target_compile_options(example_moe_gemm2_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) + example_compile_options(example_moe_gemm1_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) + example_compile_options(example_moe_gemm2_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) endif() set(GEMM_OPTIONS) list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32") - target_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE ${GEMM_OPTIONS}) - target_compile_options(example_moe_gemm1_xdl_fp8 PRIVATE ${GEMM_OPTIONS}) - target_compile_options(example_moe_gemm2_xdl_fp8 PRIVATE ${GEMM_OPTIONS}) + example_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE ${GEMM_OPTIONS}) + example_compile_options(example_moe_gemm1_xdl_fp8 PRIVATE ${GEMM_OPTIONS}) + example_compile_options(example_moe_gemm2_xdl_fp8 PRIVATE ${GEMM_OPTIONS}) set(target 1) endif() endforeach() + +set(GEMM_OPTIONS) +list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32") +set(BLOCKSCALE_GEMM_OPTIONS) +list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --schedmodel=0 -mllvm --misched-bottomup=1") +check_cxx_compiler_flag("-mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental " HAS_MAX_OCCUPANCY_EXPERIMENTAL) +if(HAS_MAX_OCCUPANCY_EXPERIMENTAL) + list(APPEND BLOCKSCALE_GEMM_OPTIONS -mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental) +endif() +# list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --misched-bottomup=1") +example_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE ${GEMM_OPTIONS}) +example_compile_options(example_moe_gemm1_xdl_fp8 PRIVATE ${GEMM_OPTIONS}) +example_compile_options(example_moe_gemm2_xdl_fp8 PRIVATE ${GEMM_OPTIONS}) +example_compile_options(example_gemm_multiply_multiply_xdl_fp8_ab_scale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS}) +example_compile_options(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle PRIVATE 
${BLOCKSCALE_GEMM_OPTIONS}) + +example_compile_options(example_moe_gemm2_xdl_fp8_blockscale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS}) +example_compile_options(example_moe_gemm1_xdl_fp8_blockscale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS}) diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp index b54ba5ddfb..5aa978fbf0 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -65,14 +65,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_ A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, - 16, 128, - 256, 16, 16, + 128, 128, + 128, 16, 16, 16, 16, - 1, 2, - S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 1, 2, S<1, 16, 1, 16>, S<8>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>; + 4, 4, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 1, 2, S<1, 32, 1, 8>, S<8>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; // clang-format on int main(int argc, char* argv[]) diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle.cpp new file mode 100644 index 0000000000..d64266bccf --- /dev/null +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle.cpp @@ -0,0 +1,372 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
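+//
+// Example overview (descriptive comment; summarizes the code below):
+//   * A and B are FP8 (ck::f8_t) with per-block float scales: A uses
+//     Scale_Block_M x Scale_Block_K = 1 x 128 blocks, B uses
+//     Scale_Block_N x Scale_Block_K = 128 x 128 blocks.
+//   * B is column-major and reordered on the host by preShuffleBuffer() into the
+//     (N0, K0, KLane, NLane, KPack) order consumed by the
+//     DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle instance.
+//   * The output E is bf16; for verification the block scales are applied to the
+//     dequantized A/B on the host and a float reference GEMM is run.
+//   * When timing, flush_cache enables rotating buffers via StreamConfig so that
+//     I$ and L2$ are flushed between iterations.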
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using FP8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = FP8; +using A1DataType = F32; +using B0DataType = FP8; +using B1DataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using A0Layout = Row; +using A1Layout = Col; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +void preShuffleBuffer(const FP8* src, FP8* dst, int N, int K, int NXdl) +{ + int KPack = 16; + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K + k]; + } + } +} +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr ck::index_t Scale_Block_M = 1; +static constexpr ck::index_t Scale_Block_N = 128; +static constexpr ck::index_t Scale_Block_K = 128; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle + // clang-format off + , S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 2, 1, S<1, 32, 1, 8>, S<8>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + bool flush_cache = true; + + // GEMM shape + ck::index_t M = 128; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 8) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + 
flush_cache = std::stoi(argv[7]); + + StrideA = K; + StrideB = K; + StrideE = N; + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: M, N, K\n"); + printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n"); + exit(0); + } + + // Transpose the AScale tensor for better performance + ck::index_t Scale_Stride_AK = (M + Scale_Block_M - 1) / Scale_Block_M; + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor a1_m_k(f_host_tensor_descriptor((M + Scale_Block_M - 1) / Scale_Block_M, + (K + Scale_Block_K - 1) / Scale_Block_K, + Scale_Stride_AK, + A1Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor b0_preshuffled( + f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use laout only for size + Tensor b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N, + Scale_Stride_BN, + B0Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_k_n.GenerateTensorValue(GeneratorTensor_1{}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(A1DataType) * 
a1_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + a1_device_buf.ToDevice(a1_m_k.mData.data()); + b1_device_buf.ToDevice(b1_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + // do GEMM + auto device_op = DeviceOpInstance{}; + int NPerXdl = device_op.GetPreShuffleParameters(); + + preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerXdl); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{}, + StrideE, + a1_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float ave_time = 0.0f; + + if(flush_cache) + { + int rotating_buf = (512 * 1024 * 1024 + num_btype - 1) / num_btype; + + ave_time = invoker.Run(argument, + StreamConfig{nullptr, time_kernel, 0, 50, 100, true, rotating_buf}); + } + else + { + ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 100}); + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + Tensor c_m_n({M, N}); + Tensor a_m_k({M, K}); + Tensor b_k_n({K, N}); + + for(int m = 0; m < M; m++) + { + for(int k = 0; k < K; k++) + { + a_m_k(m, k) = ck::type_convert(a0_m_k(m, k)) * + a1_m_k(m / Scale_Block_M, k / Scale_Block_K); + } + } + + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + b_k_n(k, n) = ck::type_convert(b0_k_n(k, n)) * + b1_k_n(k / Scale_Block_K, n / Scale_Block_N); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + +#if 1 + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + e_m_n_host_result(m, n) = ck::type_convert(c_m_n(m, n)); + } + } +#endif + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err( + e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 5e-2, 5e-2) + ? 
0 + : 1; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp index 280697851b..fe1eca51b0 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -139,13 +139,13 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu // clang-format off < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, - 128, 128, 128, + 256, 256, 128, 16, 16, - 32, 32, - 4, 1, + 16, 16, + 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, + 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; // clang-format on diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index 3b31460953..9fe9fdde78 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -158,21 +158,22 @@ using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr ck::index_t MPerBlock = 128; -static constexpr ck::index_t MXDLPerWave = 4; -static constexpr ck::index_t NXDLPerWave = 2; -static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t NPerBlock = 64; -static constexpr ck::index_t MNPerXDL = 16; -static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); -static constexpr ck::index_t Nswizzle = false; -static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); -static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); -static constexpr ck::index_t EVec = 16 / sizeof(EDataType); -static constexpr ck::index_t D0Vec = 1; -static constexpr ck::index_t D1Vec = 1; -static constexpr ck::index_t ActOP = 1; // 0: gelu_and_mul, 1: silu_and_mul -static constexpr bool MulRoutedWeight = false; -using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm +static constexpr ck::index_t NPerBlock = 128; +static constexpr ck::index_t MNPerXDL = 16; +static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1); +static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * 4); + +static constexpr ck::index_t BLOCKSIZE = 256; +static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); +static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); +static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); +static constexpr ck::index_t EVec = 16 / sizeof(EDataType); +static constexpr ck::index_t D0Vec = 1; +static constexpr ck::index_t D1Vec = 1; +static constexpr ck::index_t ActOP = 1; // 0: gelu_and_mul, 1: silu_and_mul +static constexpr bool MulRoutedWeight = false; +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, 
AElementOp, BElementOp, CDEElementOp, GemmSpec, @@ -183,15 +184,15 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceM // mn_perxdl MNPerXDL, MNPerXDL, // mn_xdlperwave - MXDLPerWave, NXDLPerWave, + MXDLPerWave, NXDLPerWave, // a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - 2, 2, S<1, 32, 1, 8>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, true, int32_t, A0DataType>; + 2, 2, S<1, 32, 1, 8>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, true, int32_t, A0DataType>; // clang-format on @@ -205,9 +206,9 @@ int main(int argc, char* argv[]) ck::index_t N = 4096; ck::index_t K = 6144; ck::index_t experts = 8; - ck::index_t sorted_tile_num = 16; - ck::index_t valid_tile_num = 13; - ck::index_t tokens = 64; + ck::index_t sorted_tile_num = 256; + ck::index_t valid_tile_num = 256; + ck::index_t tokens = 16384; ck::index_t topk = 2; if(argc == 1) @@ -263,11 +264,12 @@ int main(int argc, char* argv[]) Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); Tensor max_token_id(HostTensorDescriptor({1 + sorted_tile_num})); max_token_id.mData = {valid_size}; - int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; + // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; for(int i = 0; i < sorted_tile_num; i++) { - expert_ids.mData[i] = eids[i]; + expert_ids.mData[i] = i / (valid_tile_num / experts); } + int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; int tokenid = 0; @@ -307,7 +309,7 @@ int main(int argc, char* argv[]) case 0: break; case 1: a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}); d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp new file mode 100644 index 0000000000..c5328226ff --- /dev/null +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp @@ -0,0 +1,548 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
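+//
+// Example overview (descriptive comment; summarizes the code below):
+//   * MoE GEMM1 with FP8 (ck::f8_t) A/B and per-block float scales
+//     (Scale_Block_M x Scale_Block_K = 1 x 128 for A,
+//     Scale_Block_N x Scale_Block_K = 128 x 128 for B).
+//   * Each expert's B packs the gate and up projections, hence the N * 2 extent;
+//     ActOP selects the fused activation (0: gelu_and_mul, 1: silu_and_mul) and
+//     MulRoutedWeight = true multiplies in the routed weight supplied via D (d2).
+//   * Routing inputs: sorted_token_ids encodes the token id in the low 24 bits and
+//     the topk slot in the top byte, expert_ids gives the expert per MPerBlock tile,
+//     and max_token_id[0] holds the number of valid sorted token slots.
+//   * B is reordered on the host by preShuffleBuffer() for the b-preshuffle pipeline;
+//     verification applies the scales on the host and runs ReferenceMoeGemm1BlockScale.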
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F8 = ck::f8_t; +using F32 = float; +using I64 = int64_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F8; +using A1DataType = F32; +using B0DataType = F8; +using B1DataType = F32; +// using EDataType = F16; +using EDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = EDataType; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void + operator()(EDataType& e, const float& c, const float& d2) const + { + // for real kernel use + (void)d2; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void + operator()(EDataType& e, const EDataType& c, const float& d2) const + { + (void)d2; + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void + operator()(float& e, const float& c, const float& d2) const + { + // for reference cpu + e = ck::type_convert(c * d2); + } +}; + +void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl) +{ + int KPack = 16 / sizeof(B0DataType); + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(I64 n = 0; n < N; ++n) + { + for(I64 k = 0; k < K; ++k) + { + I64 n0 = n / NLane; + I64 n1 = n % NLane; + + I64 k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + I64 k1 = tempk / KPack; + I64 k2 = tempk % KPack; + + I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * static_cast(K) + k]; + } + } +} +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr ck::index_t Scale_Block_M = 1; +static constexpr ck::index_t Scale_Block_N = 128; +static constexpr ck::index_t Scale_Block_K = 128; + +static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul +static constexpr bool MulRoutedWeight = true; + +#if 0 +static constexpr ck::index_t MPerBlock = 32; +static 
constexpr ck::index_t NPerBlock = 128; +static constexpr ck::index_t MNPerXDL = 16; +static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1); +static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * 4); +static constexpr ck::index_t CShuffleMXDLPerWave = MXDLPerWave; +static constexpr ck::index_t CShuffleNXDLPerWave = NXDLPerWave; +static constexpr ck::index_t BLOCKSIZE = 256; + +static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); +static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); +static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); +static constexpr ck::index_t EVec = 16 / sizeof(EDataType); +static constexpr ck::index_t D0Vec = 1; +static constexpr ck::index_t D1Vec = 1; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale + // clang-format off + < Row, Col, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + //threadnum, mblock, nblock, kblock + BLOCKSIZE, Scale_Block_M, Scale_Block_N, Scale_Block_K, + MPerBlock, NPerBlock, KPerBlock, + // ak1, bk1 + AK1, BK1, + // mn_perxdl + MNPerXDL, MNPerXDL, + // mn_xdlperwave + MXDLPerWave, NXDLPerWave, + // a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, + // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + CShuffleMXDLPerWave, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>; +#else +static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< + Row, Col, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, + MPerBlock, 128, 128, + 16, 16, + 16, 16, + 4, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>; +#endif +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; +#if 1 + // GEMM shape + ck::index_t N = 4096; + ck::index_t K = 6144; + ck::index_t experts = 8; + ck::index_t topk = 2; + // ck::index_t sorted_tile_num = 515; + // ck::index_t valid_tile_num = 512; + // ck::index_t tokens = 8192; + // ck::index_t sorted_tile_num = 15; + // ck::index_t valid_tile_num = 13; + ck::index_t sorted_tile_num = 259; + ck::index_t valid_tile_num = 256; + ck::index_t tokens = 4096; +#else + // deepseek + ck::index_t N = 2048; + ck::index_t K = 7168; + ck::index_t experts = 256; + ck::index_t topk = 8; + ck::index_t tokens = 4096; + ck::index_t sorted_tile_num = 261; + ck::index_t valid_tile_num = 256; +#endif + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = 
std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else if(argc == 9) + { + + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + sorted_tile_num = std::stoi(argv[7]); + valid_tile_num = std::stoi(argv[8]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0}; + ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K; + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + ck::index_t Scale_Stride_B = (N + Scale_Block_N - 1) / Scale_Block_N * 2; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({1 + sorted_tile_num})); + max_token_id.mData = {valid_size}; + // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts); + } + + int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; + int tokenid = 0; + + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile && tokenid < tokens * topk) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); + Tensor a1_t_k(HostTensorDescriptor( + {tokens, (K + Scale_Block_K - 1) / Scale_Block_K}, {Scale_Stride_AM, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b1_e_n_k( + HostTensorDescriptor({experts, + (K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N * 2}, + {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_n_device_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + e_t_n_device_result.SetZero(); + std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; + std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: 
+ a0_t_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + case 2: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + case 4: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + case 5: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + case 6: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + default: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * + sorted_token_ids.mDesc.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(A1DataType) * a1_t_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k.mData.data()); + a1_device_buf.ToDevice(a1_t_k.mData.data()); + b1_device_buf.ToDevice(b1_e_n_k.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + int NPerXdl = device_op.GetPreShuffleParameters(); + + 
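+    // Reorder B on the host into the layout the preshuffled pipeline consumes.
+    // With NLane = NPerXdl, KPack = 16 / sizeof(B0DataType) and KLane = 64 / NLane,
+    // element (n, k) of each row-major (N, K) slice is written to
+    //   n0*(K0*KLane*NLane*KPack) + k0*(KLane*NLane*KPack) + k1*(NLane*KPack) + n1*KPack + k2,
+    // i.e. an (N0, K0, KLane, NLane, KPack) ordering (see preShuffleBuffer above).
+    // Gate and up weights are packed together, so the shuffle covers N * 2 * experts rows.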
preShuffleBuffer( + b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * 2 * experts, K, NPerXdl); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + a1_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + if(time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K; + std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K + + sizeof(B0DataType) * K * N * 2 * experts + + sizeof(EDataType) * valid_tile_num * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s.\n" + << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor a_t_k({tokens, K}); + Tensor b_e_n_k({experts, K, N * 2}); + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + + // handle scale before ref. + for(int t = 0; t < tokens; ++t) + { + for(int k = 0; k < K; ++k) + { + a_t_k(t, k) = ck::type_convert(a0_t_k(t, k)) * a1_t_k(t, k / Scale_Block_K); + } + } + + for(int e = 0; e < experts; ++e) + { + for(int k = 0; k < K; ++k) + { + for(int n = 0; n < N * 2; ++n) + { + b_e_n_k(e, k, n) = ck::type_convert(b0_e_n_k(e, k, n)) * + b1_e_n_k(e, k / Scale_Block_K, n / Scale_Block_N); + } + } + } + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeGemm1BlockScale; + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a_t_k, + b_e_n_k, + d2_e_n, + c_t_k_n, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + for(int m = 0; m < valid_size; ++m) + { + + const int fuse_t = sorted_token_ids.mData[m]; + const int t = fuse_t & 0xffffff; + const int topk_id = (fuse_t & 0xff000000) >> 24; + + if(t >= tokens) + { + continue; + } + for(int n = 0; n < N; ++n) + { + e_t_n_host_result(t, topk_id, n) = + ck::type_convert(c_t_k_n(t, topk_id, n)); + } + } + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + auto status = + ck::utils::check_err( + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) + ? 
0 + : 1; + if(status == 0) + { + printf("Validation Pass.\n"); + } + return status; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp index 42d892fe26..3188ba142c 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp @@ -123,11 +123,11 @@ using BElementOp = PassThrough; using CDEElementOp = MulABScaleExpertWeight; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr ck::index_t MPerBlock = 128; +static constexpr ck::index_t MPerBlock = 256; static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t MXDLPerWave = 4; +static constexpr ck::index_t MXDLPerWave = 16; static constexpr ck::index_t NXDLPerWave = 4; -static constexpr ck::index_t NPerBlock = 128; +static constexpr ck::index_t NPerBlock = 256; static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); @@ -164,12 +164,12 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - 4, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, false, int32_t, A0DataType>; + 2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, false, int32_t, A0DataType>; // kernel 2: 128->32x128x128 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; @@ -186,11 +186,11 @@ int main(int argc, char* argv[]) ck::index_t N = 4096; ck::index_t K = 4096; ck::index_t experts = 8; - ck::index_t sorted_tile_num = 16; - ck::index_t valid_tile_num = 13; + ck::index_t sorted_tile_num = 133; + ck::index_t valid_tile_num = 128; ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock; - ck::index_t tokens = 128; + ck::index_t tokens = 16384; ck::index_t topk = 2; if(argc == 1) @@ -245,13 +245,14 @@ int main(int argc, char* argv[]) Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); Tensor max_token_id(HostTensorDescriptor({1})); - - max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; - int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3}; - + // max_token_id.mData[0] = valid_size; + // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; + // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3}; + max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8}; + // int 
eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} for(int i = 0; i < sorted_tile_num; i++) { - expert_ids.mData[i] = eids[i]; + expert_ids.mData[i] = i / ((valid_tile_num + experts - 1) / experts); } if(tokens * topk > valid_size) { diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp new file mode 100644 index 0000000000..354957c0d1 --- /dev/null +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp @@ -0,0 +1,541 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm2_blockscale.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F8 = ck::f8_t; +using F32 = float; +using I64 = int64_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F8; +using A1DataType = F32; +using B0DataType = F8; +using B1DataType = F32; +using EDataType = F16; +// using EDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = EDataType; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +// using DsLayoutGate = ck::Tuple; +using DsLayout = ck::Tuple; + +// d0: ascale, d1: bscale, d2:expert weight +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D2& d2) const; + // for real kernel use + + template <> + __host__ __device__ constexpr void + operator()(EDataType& e, const EDataType& c, const float& d2) const + { + // for real kernel use + (void)d2; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void + operator()(EDataType& e, const float& c, const float& d2) const + { + // for real kernel use + (void)d2; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void + operator()(float& e, const float& c, const float& d2) const + { + // for reference cpu + e = ck::type_convert(c * d2); + } +}; + +void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl) +{ + int KPack = 16 / sizeof(B0DataType); + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(I64 n = 0; n < N; ++n) + { + for(I64 k = 0; k < K; ++k) + { + I64 n0 = n / NLane; + I64 n1 = n % NLane; + + I64 k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + I64 k1 = tempk / KPack; + I64 k2 = tempk % KPack; + + I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + 
n1 * KPack + k2; + + dst[outputIndex] = src[n * static_cast(K) + k]; + } + } +} +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr ck::index_t Scale_Block_M = 1; +static constexpr ck::index_t Scale_Block_N = 128; +static constexpr ck::index_t Scale_Block_K = 128; +static constexpr bool MulRoutedWeight = true; + +#if 0 +static constexpr ck::index_t MPerBlock = 32; +static constexpr ck::index_t BLOCKSIZE = 256; +static constexpr ck::index_t MXDLPerWave = 2; +static constexpr ck::index_t NXDLPerWave = 2; +static constexpr ck::index_t NPerBlock = 128; +static constexpr ck::index_t MNPerXDL = 16; +static constexpr ck::index_t KPerBlock = 256 / sizeof(A0DataType); + +static constexpr ck::index_t CShuffleNLane = 16; +static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; +static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); +static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); +static constexpr ck::index_t EVec = 2; +static constexpr ck::index_t D0Vec = 1; +static constexpr ck::index_t D1Vec = 1; +static constexpr ck::index_t D2Vec = 1; + +// clang-format off + +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< + Row, Col, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + BLOCKSIZE, Scale_Block_M, Scale_Block_N, Scale_Block_K, + MPerBlock, NPerBlock, KPerBlock, + AK1, BK1, + MNPerXDL, MNPerXDL, + MXDLPerWave, NXDLPerWave, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + 2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, int32_t, A0DataType>; + +#else +static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< + Row, Col, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, + MPerBlock, 128, 128, + 16, 16, + 16, 16, + 4, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, int32_t, A0DataType>; +#endif +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // tokens = 1 + // topk = 1 + // experts = 8 + // per expert: + + constexpr ck::index_t valid_tile_num = + 26; // 13 for 128; 52 for 32; 4096 for ds // > token * topk / MPerBlock + constexpr ck::index_t sorted_tile_num = valid_tile_num + 3; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; +#if 1 + // GEMM shape + ck::index_t N = 6144; + ck::index_t K = 4096; + ck::index_t experts = 8; + ck::index_t tokens = 832; + ck::index_t topk = 2; +#else + // deepseek + ck::index_t N = 2048; + ck::index_t K = 7160; + ck::index_t experts = 256; + ck::index_t tokens = 1; + 
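+    // decode-style shape: a single token routed to topk of 256 experts
+    // (a hedged description of the constants in this #else branch)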
ck::index_t topk = 8; +#endif + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0}; + ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K; + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + ck::index_t Scale_Stride_B = (N + Scale_Block_N - 1) / Scale_Block_N; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({1})); + + max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8}; + // int eids[] = {0, 1, 3, 3, 3}; + // int eids[] = {0, 1, 2, 3, 4, 5, 6, 7}; //, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} + // int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3}; + // int eids[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + // 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + // 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + // 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, + // 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, + // 7, 7, + // 3, 3, 3}; + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts); + } + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + int token_per_tile = tokens * topk / valid_tile_num; + int tokenid = 0; + + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile && tokenid < tokens * topk) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a1_t_k_k( + HostTensorDescriptor({tokens, topk, (K + Scale_Block_K - 1) / Scale_Block_K}, + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); + + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b1_e_n_k(HostTensorDescriptor( + {experts, (K + Scale_Block_K - 1) / Scale_Block_K, (N + Scale_Block_N - 1) / Scale_Block_N}, + {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN})); + + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); + Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); + e_t_n_device_result.SetZero(); + std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl; + std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << 
"e_t_n: " << e_t_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 5: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 6: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{1.0, 1.0}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{1.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{1.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{1.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{1.0, 1.0}); + for(auto i = 0; i < N * K; i++) + { + b0_e_n_k.mData[i] = ck::type_convert(static_cast(0.1)); + b0_e_n_k.mData[i + N * K] = ck::type_convert(static_cast(0.2)); + } + break; + default: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * + sorted_token_ids.mDesc.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(A1DataType) * a1_t_k_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k_k.mData.data()); + a1_device_buf.ToDevice(a1_t_k_k.mData.data()); + b1_device_buf.ToDevice(b1_e_n_k.mData.data()); + 
d2_device_buf.ToDevice(d2_e_n.mData.data()); + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + int NPerXdl = device_op.GetPreShuffleParameters(); + + preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * experts, K, NPerXdl); + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + a1_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + if(time_kernel) + { + // not result correct here because output buf not setzero + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * tokens * topk * N * K; + std::size_t num_btype = sizeof(A0DataType) * tokens * K * topk + + sizeof(B0DataType) * K * N * experts + + sizeof(EDataType) * tokens * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s.\n" + << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + // gemm2 use atomic, so need to reinit outputs + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor a_t_k_k({tokens, topk, K}); + Tensor b_e_n_k({experts, K, N}); + Tensor c_t_n({tokens, N}); + + for(int t = 0; t < tokens; ++t) + { + for(int tk = 0; tk < topk; ++tk) + { + for(int k = 0; k < K; ++k) + { + a_t_k_k(t, tk, k) = ck::type_convert(a0_t_k_k(t, tk, k)) * + a1_t_k_k(t, tk, k / Scale_Block_K); + } + } + } + + for(int e = 0; e < experts; ++e) + { + for(int k = 0; k < K; ++k) + { + for(int n = 0; n < N; ++n) + { + b_e_n_k(e, k, n) = ck::type_convert(b0_e_n_k(e, k, n)) * + b1_e_n_k(e, k / Scale_Block_K, n / Scale_Block_N); + } + } + } + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeGemm2BlockScale; + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a_t_k_k, + b_e_n_k, + d2_e_n, + c_t_n, + PassThrough{}, + PassThrough{}, + cde_element_op); + + ref_invoker.Run(ref_argument); + for(int t = 0; t < tokens; ++t) + { + + for(int n = 0; n < N; ++n) + { + e_t_n_host_result(t, n) = ck::type_convert(c_t_n(t, n)); + } + } + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + auto status = + ck::utils::check_err( + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + ? 
0 + : 1; + if(status == 0) + { + printf("Validation Pass.\n"); + } + return status; + } + + return 0; +} diff --git a/example/67_gemm_microscaling/CMakeLists.txt b/example/67_gemm_microscaling/CMakeLists.txt index 86d90674e1..34c54a7e12 100644 --- a/example/67_gemm_microscaling/CMakeLists.txt +++ b/example/67_gemm_microscaling/CMakeLists.txt @@ -6,8 +6,9 @@ add_example_dependencies(example_gemm_mx example_gemm_mx_fp8) add_example_executable(example_gemm_mx_bf8 gemm_mx_bf8.cpp) add_example_dependencies(example_gemm_mx example_gemm_mx_bf8) -#add_example_executable(example_gemm_mx_fp8_bf8 gemm_mx_fp8_bf8.cpp) -# add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bf8) TOFO: Fix RRR +# TODO: Fix RRR +# add_example_executable(example_gemm_mx_fp8_bf8 gemm_mx_fp8_bf8.cpp) +# add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bf8) add_example_executable(example_gemm_mx_fp4 gemm_mx_fp4.cpp) add_example_dependencies(example_gemm_mx example_gemm_mx_fp4) @@ -15,30 +16,23 @@ add_example_dependencies(example_gemm_mx example_gemm_mx_fp4) add_example_executable(example_gemm_mx_fp4_bpreshuffle gemm_mx_fp4_bpreshuffle.cpp) add_example_dependencies(example_gemm_mx example_gemm_mx_fp4_bpreshuffle) -#add_example_executable(example_moe_gemm1_xdl_mx_fp4 moe_gemm1_xdl_mx_fp4.cpp) -# add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4) TODO: Fix +add_example_executable(example_moe_gemm1_xdl_mx_fp4_bns moe_gemm1_xdl_mx_fp4_bns.cpp) +add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4_bns) -#add_example_executable(example_moe_gemm1_xdl_mx_fp4_bns moe_gemm1_xdl_mx_fp4_bns.cpp) -#add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4_bns) - -#add_example_executable(example_moe_gemm2_xdl_mx_fp4 moe_gemm2_xdl_mx_fp4.cpp) -# add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4) TODO: Fix - -#add_example_executable(example_moe_gemm2_xdl_mx_fp4_bns moe_gemm2_xdl_mx_fp4_bns.cpp) -#add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4_bns) +add_example_executable(example_moe_gemm2_xdl_mx_fp4_bns moe_gemm2_xdl_mx_fp4_bns.cpp) +add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4_bns) set(FP4_MXGEMM_OPTIONS) list(APPEND FP4_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --amdgpu-use-amdgpu-trackers=1") -#list(APPEND FP4_MXGEMM_OPTIONS -v --save-temps -Wno-gnu-line-marker -ftemplate-backtrace-limit=0) example_compile_options(example_gemm_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS}) example_compile_options(example_gemm_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS}) -# example_compile_options(example_moe_gemm1_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS}) -# example_compile_options(example_moe_gemm2_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS}) -# example_compile_options(example_moe_gemm1_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS}) -# example_compile_options(example_moe_gemm2_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS}) + +example_compile_options(example_moe_gemm1_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS}) +example_compile_options(example_moe_gemm2_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS}) +example_compile_options(example_moe_gemm1_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS}) +example_compile_options(example_moe_gemm2_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS}) set(FP8_MXGEMM_OPTIONS) list(APPEND FP8_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32") -#list(APPEND FP8_MXGEMM_OPTIONS -v --save-temps -Wno-gnu-line-marker -ftemplate-backtrace-limit=0) 
example_compile_options(example_gemm_mx_fp8 PRIVATE ${FP8_MXGEMM_OPTIONS}) example_compile_options(example_gemm_mx_bf8 PRIVATE ${FP8_MXGEMM_OPTIONS}) diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp index 30df8ccd37..1f01e1c7be 100644 --- a/example/67_gemm_microscaling/gemm_mx_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -250,7 +250,7 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c using AScaleLayout = Row; using BScaleLayout = Col; - auto Scale_Padded_M = (M + ScaleBlockSize - 1) / ScaleBlockSize * ScaleBlockSize; + auto Scale_Padded_M = ck::math::integer_least_multiple(M, ScaleBlockSize); auto Scale_Stride_AM = f_get_default_stride(Scale_Padded_M, K / ScaleBlockSize, -1, AScaleLayout{}); auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{}); @@ -302,6 +302,8 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c return ck::type_convert(x); }; + using int_distr = std::uniform_int_distribution; + using float_distr = std::uniform_real_distribution; switch(config.init_method) { case 0: // Initializations for development and debugging @@ -320,22 +322,19 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c break; case 1: - - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] - b_k_n->GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] + a_m_k.GenerateTensorDistr(int_distr{-5, 6}); // Z[-5,5] + b_k_n->GenerateTensorDistr(int_distr{-5, 6}); // Z[-5,5] static_assert(ck::is_same_v); - a_m_k_scale.GenerateTensorValue( - GeneratorTensor_2{120, 129}); // scales: {0.25, 0.5, 1, 2} - b_k_n_scale.GenerateTensorValue( - GeneratorTensor_2{125, 129}); // scales: {0.25, 0.5, 1, 2} + a_m_k_scale.GenerateTensorDistr(int_distr{120, 129}); // scales: {0.25, 0.5, 1, 2} + b_k_n_scale.GenerateTensorDistr(int_distr{125, 129}); // scales: {0.25, 0.5, 1, 2} break; case 2: - a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); - a_m_k_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); + a_m_k.GenerateTensorDistr(float_distr{-2.0, 2.0}); + a_m_k_scale.GenerateTensorDistr(float_distr{powf(2.0f, -125.0f), 1.0f}); - b_k_n->GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); - b_k_n_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); + b_k_n->GenerateTensorDistr(float_distr{-2.0, 2.0}); + b_k_n_scale.GenerateTensorDistr(float_distr{powf(2.0f, -125.0f), 1.0f}); break; default: @@ -469,17 +468,6 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c std::cout << "Comparing results..." << std::endl; } - // if(config.init_method == 0) - // { - // auto expected = static_cast(K); - // auto computed = type_convert(c_m_n_device_result(1, 12)); - - // res_verified = res_verified && std::abs(expected - computed) <= 0.0f; - // std::cout << "\nExpected vs Computed: " << expected << " vs " << computed - // << ((res_verified) ? 
" (PASSED!)" : " (FAILED!)") << std::endl - // << std::endl; - // } - res_verified = res_verified && ck::utils::check_err( diff --git a/example/67_gemm_microscaling/gemm_mx_fp4.cpp b/example/67_gemm_microscaling/gemm_mx_fp4.cpp index cff5148fa7..65fbe3491a 100644 --- a/example/67_gemm_microscaling/gemm_mx_fp4.cpp +++ b/example/67_gemm_microscaling/gemm_mx_fp4.cpp @@ -5,8 +5,6 @@ using ADataType = ck::f4x2_pk_t; using BDataType = ck::f4x2_pk_t; -// using ADataType = ck::f4_t; -// using BDataType = ck::f4_t; using XDataType = ck::e8m0_bexp_t; using XPackedDataType = int32_t; diff --git a/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp index 562b2fdb17..6e1efd266b 100644 --- a/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp +++ b/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp @@ -5,8 +5,6 @@ using ADataType = ck::f4x2_pk_t; using BDataType = ck::f4x2_pk_t; -// using ADataType = ck::f4_t; -// using BDataType = ck::f4_t; using XDataType = ck::e8m0_bexp_t; using XPackedDataType = int32_t; @@ -74,9 +72,9 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle 16, // BBlockTransferDstScalarPerVector_BK1 true, // BBlockLdsExtraN 2, // CShuffleMXdlPerWavePerShuffle - 2, // CShuffleNXdlPerWavePerShuffle - S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + 4, // CShuffleNXdlPerWavePerShuffle + S<1, 8, 1, 32>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlockW BlkGemmPSched, // BlkGemmPipeSched BlkGemmPVer, // BlkGemmPipelineVer ADataType, // ComputeTypeA diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp new file mode 100644 index 0000000000..24ab326391 --- /dev/null +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp @@ -0,0 +1,545 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F4 = ck::f4x2_pk_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F4; +using A1DataType = XPackedDataType; +using B0DataType = F4; +using B1DataType = XPackedDataType; +using EDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +// d0: ascale, d1: bscale, d2:expert weight +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } +}; + +using CDEElementOp = MulABScaleExpertWeight; + +// A, B Scale preshuffle +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int tempk = k % (XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * 
XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + + // k2 * MNXdlPack))); + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 +static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul +static constexpr ck::index_t MPerBlock = 128; +static constexpr ck::index_t NPerBlock = 64; +static constexpr ck::index_t BlockSize = 256; +static constexpr bool MulRoutedWeight = true; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBNS< + A0Layout, B0Layout, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + ScaleBlockSize, BlockSize, + MPerBlock, NPerBlock, KPerBlock, + 16, 16, + 16, 16, + 4, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, + ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // per expert: + // GEMM shape + constexpr ck::index_t sorted_tile_num = 13; + constexpr ck::index_t valid_tile_num = sorted_tile_num; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + + ck::index_t N = 4096; + ck::index_t K = 6144; + ck::index_t experts = 8; + ck::index_t tokens = 832; + ck::index_t topk = 2; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! 
K must be multiple of ScaleBlockSize."); + }; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize; + ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0, 0, 0}; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({sorted_tile_num + 1})); + max_token_id.mData[0] = valid_size; + + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts); + } + int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; + int tokenid = 0; + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + + Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); + Tensor a1_t_k(HostTensorDescriptor( + {tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b1_e_n_k( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, + {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN})); + + // A, B Scale preshuffle + Tensor a_scale_sorted(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor a_scale_preshuffled(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b_scale_preshuffled( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, + {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_k_n_host_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_k_n_device_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + + e_t_k_n_device_result.SetZero(); + std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; + std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "e_t_k_n: " << e_t_k_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); + break; + case 3: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + 
a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 5.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 6: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{0.5f}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{1.5f}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); + break; + default: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_k_n_device_result.GetElementSpaceSize()); + + // A scale sorted + for(int i = 0; i < sorted_size; i++) + { + int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF; + + for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++) + { + if(token_id == tokens) + { + a_scale_sorted(i, k) = ck::type_convert(0); + } + else + { + a_scale_sorted(i, k) = a1_t_k(token_id, k); + } + } + } + + // A/B scale shuffle + preShuffleScaleBuffer>(a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer>(b1_e_n_k.mData.data(), + b_scale_preshuffled.mData.data(), + N * 2 * experts, + K / ScaleBlockSize); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k.mData.data()); + b0_device_buf.ToDevice(b0_e_n_k.mData.data()); + a1_device_buf.ToDevice(a_scale_preshuffled.mData.data()); + b1_device_buf.ToDevice(b_scale_preshuffled.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + e_device_buf.ToDevice(e_t_k_n_device_result.mData.data()); + + 
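+    // At this point the e8m0 scales have been preshuffled so that each lane's 2x2
+    // (MNXdlPack x KXdlPack) group of scales is contiguous and can be fetched as one packed
+    // int32 (XPackedDataType holds 4 e8m0 values). A hedged restatement of
+    // preShuffleScaleBuffer's mapping: (n, k) -> [n0][k0][k1][n1][k2][n2],
+    // with n1 in [0, 16), k1 in [0, 4), and n2, k2 in [0, 2).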
auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + std::array{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + + if(time_kernel) + { + // not result correct here because output buf not setzero + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = + // FMA * tokens * N * (Gate+Up) * topk * K + + // FMA * tokens * N * (Gate+Up) * topk * (K/BlockScale) + std::size_t(2) * tokens * N * 2 * topk * K + + std::size_t(2) * tokens * N * 2 * topk * K / ScaleBlockSize; + + std::size_t num_btype = sizeof(A0DataType) / 2 * tokens * topk * K + + sizeof(B0DataType) / 2 * K * N * 2 * experts + + sizeof(XDataType) * tokens * topk * K / ScaleBlockSize + + sizeof(XDataType) * K / ScaleBlockSize * N * 2 * experts + + sizeof(EDataType) * tokens * topk * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + // gemm2 use atomic, so need to reinit outputs + e_device_buf.ToDevice(e_t_k_n_device_result.mData.data()); + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeMXGemm1; + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a0_t_k, + a1_t_k, + b0_e_n_k, + b1_e_n_k, + d2_e_n, + c_t_k_n, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + for(int m = 0; m < valid_size; ++m) + { + const int fuse_t = sorted_token_ids.mData[m]; + const int t = fuse_t & 0xffffff; + const int topk_id = (fuse_t & 0xff000000) >> 24; + + if(t >= tokens) + { + continue; + } + for(int n = 0; n < N; ++n) + { + e_t_k_n_host_result(t, topk_id, n) = + ck::type_convert(c_t_k_n(t, topk_id, n)); + } + } + + e_device_buf.FromDevice(e_t_k_n_device_result.mData.data()); + + auto status = + ck::utils::check_err( + e_t_k_n_device_result, e_t_k_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) + ? 
0 + : 1; + if(status == 0) + { + printf("Validation Pass.\n"); + } + return status; + } + + return 0; +} diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp new file mode 100644 index 0000000000..6718581a50 --- /dev/null +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp @@ -0,0 +1,526 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F4 = ck::f4x2_pk_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F4; +using A1DataType = XPackedDataType; +using B0DataType = F4; +using B1DataType = XPackedDataType; +using EDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +// d0: ascale, d1: bscale, d2:expert weight +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + e = ck::type_convert(c * d0 * d1 * d2); + } +}; + +using CDEElementOp = MulABScaleExpertWeight; + +// A, B Scale preshuffle +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * 
MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int tempk = k % (XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + + // k2 * MNXdlPack))); + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 + +static constexpr ck::index_t MPerBlock = 128; +static constexpr bool MulRoutedWeight = true; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBNS< + A0Layout, B0Layout, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + ScaleBlockSize, 256, + MPerBlock, 128, KPerBlock, + 16, 16, + 16, 16, + 4, 4, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // per expert: + // GEMM shape + constexpr ck::index_t sorted_tile_num = 13; + constexpr ck::index_t valid_tile_num = sorted_tile_num; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + + ck::index_t N = 6144; + ck::index_t K = 4096; + ck::index_t experts = 8; + ck::index_t tokens = 832; + ck::index_t topk = 2; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! 
K must be multiple of ScaleBlockSize."); + }; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize; + ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0, 0, 0}; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({1})); + max_token_id.mData[0] = valid_size; + // int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3}; + int eids[sorted_tile_num]{}; + for(int i = 0; i < sorted_tile_num; i++) + { + if(i < valid_tile_num) + { + eids[i] = (i * experts) / valid_tile_num; + } + else + { + eids[i] = 3; + } + } + + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = eids[i]; + } + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + int token_per_tile = tokens * topk / valid_tile_num; + int tokenid = 0; + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a1_t_k_k( + HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize}, + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b1_e_n_k( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, + {(N * Scale_Stride_BN), 1, Scale_Stride_BN})); + // B preshuffle + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + + // A, B Scale preshuffle + Tensor a_scale_sorted(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor a_scale_preshuffled(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b_scale_preshuffled( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, + {N * Scale_Stride_BN, 1, Scale_Stride_BN})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); + Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); + + e_t_n_device_result.SetZero(); + std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl; + std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + 
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 5.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 6: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.GetElementSpaceSize()); + + // A scale sorted + for(int i = 0; i < sorted_size; i++) + { + int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF; + int topk_id = (sorted_token_ids.mData[i] >> 24) & 0x000000FF; + + for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++) + { + if(token_id == tokens) + { + a_scale_sorted(i, k) = ck::type_convert(0); + } + else + { + a_scale_sorted(i, k) = a1_t_k_k(token_id, topk_id, k); + } + } + } + + preShuffleScaleBuffer>(a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer>( + b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K / ScaleBlockSize); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k_k.mData.data()); + b0_device_buf.ToDevice(b0_e_n_k.mData.data()); + a1_device_buf.ToDevice(a_scale_preshuffled.mData.data()); + b1_device_buf.ToDevice(b_scale_preshuffled.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + 
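+ // Note on the encodings used just above (worked example with this test's defaults): each
+ // sorted_token_ids entry packs the token id in bits [23:0] and the topk slot in bits [31:24],
+ // so with tokens = 832 an entry of 0x01000005 means token 5 routed through its second topk
+ // slot, while the sentinel value `tokens` (832) marks a padded row whose A-scales were zeroed.
+ // a1_device_buf / b1_device_buf therefore receive the already sorted and preshuffled e8m0
+ // scales, not the raw per-token / per-expert scale tensors.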
e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + std::array{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + + if(time_kernel) + { + // not result correct here because output buf not setzero + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + // FMA * tokens * N * topk * K + + // FMA * tokens * N * topk * (K/BlockScale) + std::size_t flop = std::size_t(2) * tokens * topk * N * K + + std::size_t(2) * tokens * topk * N * K / ScaleBlockSize; + + std::size_t num_btype = + sizeof(A0DataType) / 2 * tokens * K * topk + sizeof(B0DataType) / 2 * K * N * experts + + sizeof(XDataType) * tokens * topk * K / ScaleBlockSize + + sizeof(XDataType) * K / ScaleBlockSize * N * experts + sizeof(EDataType) * tokens * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + // gemm2 use atomic, so need to reinit outputs + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor c_t_n({tokens, N}); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeMXGemm2; + + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a0_t_k_k, + a1_t_k_k, + b0_e_n_k, + b1_e_n_k, + d2_e_n, // topk weights + c_t_n, + PassThrough{}, + PassThrough{}, + cde_element_op); + + ref_invoker.Run(ref_argument); + for(int t = 0; t < tokens; ++t) + { + for(int n = 0; n < N; ++n) + { + e_t_n_host_result(t, n) = ck::type_convert(c_t_n(t, n)); + } + } + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + return ck::utils::check_err( + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + ? 
0 + : 1; + } + + return 0; +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 1cfe2789c2..56d709f41b 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -104,11 +104,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() - # Do not build gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94 + # Build fp8 gemm_multiply_multiply and moe only on gfx94/95 foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx94" AND NOT EX_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_multiply_multiply_xdl_fp8_bpreshuffle") - message(DEBUG "Skipping ${source} example for current target") - list(REMOVE_ITEM FILE_NAME "${source}") + if(NOT EX_TARGETS MATCHES "gfx94" AND NOT EX_TARGETS MATCHES "gfx95") + if (source MATCHES "fp8" AND source MATCHES "(gemm_multiply_multiply|moe)") + message(DEBUG "Skipping ${source} example for current target") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() endif() endforeach() #only continue if there are some source files left on the list diff --git a/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp index 257636d956..06e33afd20 100644 --- a/include/ck/library/utility/host_tensor.hpp +++ b/include/ck/library/utility/host_tensor.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -18,6 +19,7 @@ #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/ranges.hpp" +#include "ck/library/utility/thread.hpp" template std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim) @@ -512,6 +514,72 @@ struct Tensor } } + // Generate random values with multiple threads. Guaranteed to give the same sequence with any + // number of threads provided. 
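+ // The guarantee comes from working in fixed 64-byte blocks: every worker copies the generator
+ // and discard()s exactly the draws consumed by all blocks before its own range, so each element
+ // maps to the same draw regardless of the thread count.
+ // Illustrative usage sketch only (shape and distribution chosen for the example):
+ //   Tensor<float> t({1024, 1024});
+ //   t.GenerateTensorDistr(std::uniform_real_distribution<float>{-1.f, 1.f}); // default seed 0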
+ template , + typename Mapping = ck::identity, + typename Generator = std::minstd_rand> + void GenerateTensorDistr(Distribution dis = {0.f, 1.f}, + Mapping fn = {}, + const Generator g = Generator(0), // default seed 0 + std::size_t num_thread = -1) + { + using ck::math::integer_divide_ceil; + using ck::math::min; + if(num_thread == -1ULL) + num_thread = min(ck::get_available_cpu_cores(), 80U); // max 80 threads + // At least 2MB per thread + num_thread = min(num_thread, integer_divide_ceil(this->GetElementSpaceSize(), 0x200000)); + constexpr std::size_t BLOCK_BYTES = 64; + constexpr std::size_t BLOCK_SIZE = BLOCK_BYTES / sizeof(T); + + const std::size_t num_blocks = integer_divide_ceil(this->GetElementSpaceSize(), BLOCK_SIZE); + const std::size_t blocks_per_thread = integer_divide_ceil(num_blocks, num_thread); + + std::vector threads; + threads.reserve(num_thread - 1); + const auto dst = const_cast(this->mData.data()); + const auto element_space_size = this->GetElementSpaceSize(); + for(int it = num_thread - 1; it >= 0; --it) + { + std::size_t ib_begin = it * blocks_per_thread; + std::size_t ib_end = min(ib_begin + blocks_per_thread, num_blocks); + + auto job = [=]() { + auto g_ = g; // copy + auto dis_ = dis; // copy + g_.discard(ib_begin * BLOCK_SIZE * ck::packed_size_v); + auto t_fn = [&]() { + if constexpr(ck::packed_size_v == 1) + return ck::type_convert(fn(dis_(g_))); + else if constexpr(ck::is_same_v) + return ck::f4x2_pk_t{ck::type_convert( + ck::float2_t{ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_)))})}; + else + static_assert(false, "Unsupported packed size for T"); + }; + + std::size_t ib = ib_begin; + for(; ib < ib_end - 1; ++ib) + ck::static_for<0, BLOCK_SIZE, 1>{}([&](auto iw_) { + constexpr size_t iw = iw_.value; + dst[ib * BLOCK_SIZE + iw] = t_fn(); + }); + for(std::size_t iw = 0; iw < BLOCK_SIZE; ++iw) + if(ib * BLOCK_SIZE + iw < element_space_size) + dst[ib * BLOCK_SIZE + iw] = t_fn(); + }; + + if(it > 0) + threads.emplace_back(std::move(job)); + else + job(); // last job run in the main thread + } + for(auto& t : threads) + t.join(); + } + template std::size_t GetOffsetFromMultiIndex(Is... is) const { diff --git a/include/ck/library/utility/host_tensor_generator.hpp b/include/ck/library/utility/host_tensor_generator.hpp index f48ba49bbf..ab69412c15 100644 --- a/include/ck/library/utility/host_tensor_generator.hpp +++ b/include/ck/library/utility/host_tensor_generator.hpp @@ -163,6 +163,18 @@ struct GeneratorTensor_1 } }; +template <> +struct GeneratorTensor_1 +{ + float value = 1; + + template + ck::e8m0_bexp_t operator()(Is...) + { + return ck::type_convert(value); + } +}; + template struct GeneratorTensor_2 { diff --git a/include/ck/library/utility/thread.hpp b/include/ck/library/utility/thread.hpp new file mode 100644 index 0000000000..483c58c46f --- /dev/null +++ b/include/ck/library/utility/thread.hpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
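+// Helper for the multi-threaded tensor initialization above: on Linux it reports the size of the
+// calling thread's CPU affinity mask (sched_getaffinity); otherwise it falls back to
+// std::thread::hardware_concurrency().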
+ +#pragma once + +#ifdef __linux__ +#include +#endif +#include +namespace ck { +inline unsigned int get_available_cpu_cores() +{ +#if defined(__linux__) + cpu_set_t cpu_set; + if(sched_getaffinity(0, sizeof(cpu_set_t), &cpu_set) == 0) + { + unsigned int cpu_count = CPU_COUNT(&cpu_set); + if(cpu_count > 0) + return cpu_count; + } +#endif + // Fallback if sched_getaffinity unavailable or fails + return std::thread::hardware_concurrency(); +} +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp index 29750b8baa..4f7b8e768c 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp @@ -122,6 +122,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< using Base::B_K1; using Base::I0; using Base::I1; + using Base::KGroup; using Base::KRepeat; using Base::xdlops_gemm; using typename Base::HotLoopInstList; @@ -153,9 +154,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); - constexpr index_t K2 = KPack; + constexpr index_t K2 = KPack / KGroup; constexpr index_t K1 = 64 / NPerXDL; - constexpr index_t K0 = KRepeat; + constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( TileDesc_M0_M1_M2_K{}, @@ -290,12 +291,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< block_sync_lds(); static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); // B VGPR->VGPR dequant @@ -388,12 +391,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); // B VGPR->VGPR dequant @@ -477,12 +483,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, 
+ make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); // B VGPR->VGPR dequant @@ -588,7 +596,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< ComputeDataType, decltype(a_block_desc_m0_m1_m2_k0_k1_k2), decltype(a_thread_desc_), - Sequence<1, 1, 1, 1, 1, KPack>, + Sequence<1, 1, 1, 1, 1, KPack / KGroup>, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp index 73749c6309..fe89e700c4 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp @@ -122,6 +122,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1{}); constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); - constexpr index_t K2 = KPack; + constexpr index_t K2 = KPack / KGroup; constexpr index_t K1 = 64 / NPerXDL; - constexpr index_t K0 = KRepeat; + constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( TileDesc_M0_M1_M2_K{}, @@ -298,12 +299,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -382,12 +385,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -458,12 +464,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -556,7 +565,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1, + Sequence<1, 1, 1, 1, 1, KPack / KGroup>, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp new file mode 100644 index 0000000000..c76be74e52 --- /dev/null +++ 
b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp @@ -0,0 +1,952 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::KGroup; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + using Base::MWaves; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack / KGroup; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat * KGroup; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? 
HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * 2; + + static_assert(num_buffer_load_inst_a == num_ds_write_inst_a); + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * 2; + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle); + + // constexpr auto num_dsread_a_mfma = + // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + + constexpr auto num_total_stages = MRepeat; + + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages; + + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + + constexpr auto num_ds_read_a_prefetch_stages = 2; + + constexpr auto buffer_load_perstage_more = math::integer_divide_ceil( + (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_less = math::integer_divide_floor( + (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2)); + + constexpr auto buffer_load_stages_more = + (num_buffer_load_inst_a + num_buffer_load_inst_b) - + math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b), + (num_total_stages - 2)) * + ((num_total_stages - 2)); + + constexpr auto buffer_load_b_stages = + buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b + ? num_buffer_load_inst_b / buffer_load_perstage_more + : (buffer_load_stages_more + + (num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) / + buffer_load_perstage_less); + + constexpr auto buffer_load_a_stages = + num_total_stages - num_ds_read_a_prefetch_stages - buffer_load_b_stages; + + constexpr auto buffer_load_issue_point_b = 0; + constexpr auto buffer_load_issue_point_interval_more = + num_mfma_perstage / buffer_load_perstage_more; + constexpr auto buffer_load_issue_point_interval_less = + num_mfma_perstage / buffer_load_perstage_less; + constexpr auto ds_write_issue_point = 0; + constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 
1 : 0; + + // B global read + static_for<0, buffer_load_b_stages, 1>{}([&](auto i) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + + if constexpr(((i < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_b)) || + ((i >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_b))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + + // A global read + A local write + static_for<0, buffer_load_a_stages, 1>{}([&](auto i) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + ds_write_issue_point)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + ds_write_issue_point))) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_a)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_a))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + + // lds synchronization, prefetch next loop local A + static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + } + + template + __device__ static constexpr auto EpilogueScheduler_1(Stage stage) + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_buffer_load_inst_b = + MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num * 2; + + constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num * 2; + + constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; + constexpr auto staged_num_mfma = num_mfma / MRepeat; + + constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; + + if constexpr(stage.value == 0) + { + constexpr auto staged_num_buffer_load_b_per_ds_read_a = + num_buffer_load_inst_b / staged_num_ds_read_inst_a; + constexpr auto staged_num_mfma_per_buffer_load_b = + staged_num_mfma / num_buffer_load_inst_b; + // B global + static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { + ignore = i_inst; + + static_for<0, staged_num_buffer_load_b_per_ds_read_a, 1>{}([&](auto ibuf_inst) { + ignore = ibuf_inst; + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + 
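+ // The scheduler masks used throughout these hints are assumed to follow the LLVM AMDGPU
+ // sched_group_barrier convention, matching the inline annotations: 0x008 = MFMA, 0x020 = VMEM
+ // read, 0x100 = DS read, 0x200 = DS write; the second argument is how many instructions of
+ // that class are grouped before the next barrier.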
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(stage.value == 1) + { + constexpr auto staged_num_mfma_per_ds_write_a = + math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a); + + constexpr auto stage_more_mfma = + staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a; + + // A local write + static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) { + if constexpr(i_inst.value < stage_more_mfma) + { + if(i_inst.value < staged_num_ds_read_inst_a) + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + } + } + else + { + if(i_inst.value < staged_num_ds_read_inst_a) + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + } + } + }); + __builtin_amdgcn_sched_barrier(0); + } + else + { + // A local Read + static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { + ignore = i_inst; + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + + __builtin_amdgcn_sched_barrier(0); + } + } + + __device__ static constexpr auto EpilogueScheduler_2() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + + constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num * 2; + + constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; + constexpr auto staged_num_mfma = num_mfma / MRepeat; + + constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; + + // A local Read + static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { + ignore = i_inst; + __builtin_amdgcn_sched_group_barrier(0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + + __builtin_amdgcn_sched_barrier(0); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + BBlockTransfer& b_blockwise_copy, + BBlockTransfer& b_blockwise_copy_up, + const BGridBuffer& b_grid_buf, + const BGridBuffer& b_grid_buf_up, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + CThreadBuffer& c_thread_buf_up, + index_t num_loop) const + { + ignore = b_block_buf; + 
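+ // Rough shape of this pipeline (descriptive only): B is preshuffled and streamed straight from
+ // global memory into VGPRs twice per iteration, with the two streams (nominally the gate and
+ // "up" projections) accumulating into c_thread_buf and c_thread_buf_up, while A goes
+ // global -> LDS -> VGPR with a two-deep ping/pong prefetch; HotLoopScheduler() then interleaves
+ // the MFMAs with the next iteration's buffer loads, DS writes and DS reads.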
__builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs_up; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + // Global prefetch A1 B1 + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I0)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + // // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); + + // // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Local prefetch A1 + block_sync_lds(); + static_for<0, 2, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + c_thread_buf_up.Clear(); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(local_read_buf)); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[mfma_reg_buf] + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + + if constexpr(m0.value == MRepeat - 2) + { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 
1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + else + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + }); + HotLoopScheduler(); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1)); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + if constexpr(m0.value == (MRepeat - 2)) + { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == MRepeat - 1) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + 
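+ // m0 < MRepeat - 2: the A fragments for iteration m0 + 2 still come from the LDS buffer that
+ // feeds the current MFMAs (I0); only the last two m0 stages above switch to the freshly
+ // written buffer (I1).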
else + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + }); + + HotLoopScheduler(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + + if constexpr(m0.value < (MRepeat - 2)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + }); + + HotLoopScheduler(); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + + if constexpr(m0.value < (MRepeat - 2)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + // Reduce the vgpr usage here. 
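+ // The leading I2 below is what the "Reduce the vgpr usage" note above refers to: only two
+ // MRepeat slots of A fragments are live at any time, and the (m0 + 2) % 2 style indexing in the
+ // hot loop rotates through them instead of holding all MRepeat x KRepeat x KPack values.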
+ static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(I2, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v1.hpp new file mode 100644 index 0000000000..ac3b82f800 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v1.hpp @@ -0,0 +1,919 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v1< + BlockGemmPipelineScheduler::Intrawave, + ThreadBlockSize, + ScaleBlockSize, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack> : BlockwiseGemmXdlops_mx_pipeline_base + +{ + + using Base = BlockwiseGemmXdlops_mx_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::MWaves; + using Base::NWaves; + using Base::WaveSize; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetWaveIdx; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + using Base::KThreadChunk; + + using Base::APackedSize; + using Base::BPackedSize; + using Base::ComputePackedSize; + + using AccType = typename Base::AccType; + using Tuple4 = typename Base::Tuple4; + using ComputeTypeA = typename Base::ComputeTypeA; + using ComputeTypeB = typename Base::ComputeTypeB; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = 
TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + static constexpr auto ScalesPerKBlockSize = + KPerBlock / ScaleBlockSize; // How many mx-vectors per K block + + //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + + //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRunPerThread = + ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + BBlockTransfer& b_blockwise_copy_up, + const BGridBuffer& b_grid_buf, + const BGridBuffer& b_grid_buf_up, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + CThreadBuffer& c_thread_buf_up, + // A and B scales + const AScaleGridDesc& a_scale_grid_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const BScaleGridDesc& b_scale_grid_desc, + BScaleThreadTransfer& b_scale_thread_copy, + BScaleThreadTransfer& b_scale_thread_copy_up, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleGridBuffer& b_scale_grid_buf_up, + index_t num_loop) const + { + ignore = b_block_desc; + ignore = b_block_buf; + ignore = a_scale_grid_buf; + ignore = b_scale_grid_buf; + ignore = b_scale_grid_buf_up; + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs_up; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs_up; + + // Global prefetch A1 B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + 
b_block_origin_idx, + b_thread_bufs(I0)); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Prefetch a_scales to buf 0 + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0, I0), + a_scale_thread_bufs(I0)); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, ScalesPerKBlockSize, 0)); + + // Prefetch b_scales to buf 0 + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + constexpr auto b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); + auto b_scale_thread_buf_copy = + make_static_buffer( + b_scale_thread_desc_copy.GetElementSpaceSize()); + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc_copy, + make_tuple(I0, I0), + b_scale_thread_buf_copy); + + b_scale_thread_bufs(I0)(Number{}) = + b_scale_thread_buf_copy[Number<0>{}]; + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + + auto b_scale_thread_buf_copy_up = + make_static_buffer( + b_scale_thread_desc_copy.GetElementSpaceSize()); + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc_copy, + make_tuple(I0, I0), + b_scale_thread_buf_copy_up); + + b_scale_thread_bufs_up(I0)(Number{}) = + b_scale_thread_buf_copy_up[Number<0>{}]; + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + }); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(-NPerBlock, ScalesPerKBlockSize)); + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize)); + + __builtin_amdgcn_sched_barrier(0); + + // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + + // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Prefetch a_scales to buf 1 + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0, I0), + a_scale_thread_bufs(I1)); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, ScalesPerKBlockSize, 0)); + + // Prefetch b_scales to buf 1 + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + constexpr auto b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); + auto b_scale_thread_buf_copy = + make_static_buffer( + 
b_scale_thread_desc_copy.GetElementSpaceSize()); + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc_copy, + make_tuple(I0, I0), + b_scale_thread_buf_copy); + + b_scale_thread_bufs(I1)(Number{}) = + b_scale_thread_buf_copy[Number<0>{}]; + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + + auto b_scale_thread_buf_copy_up = + make_static_buffer( + b_scale_thread_desc_copy.GetElementSpaceSize()); + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc_copy, + make_tuple(I0, I0), + b_scale_thread_buf_copy_up); + + b_scale_thread_bufs_up(I1)(Number{}) = + b_scale_thread_buf_copy_up[Number<0>{}]; + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + }); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); + }); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(-NPerBlock, ScalesPerKBlockSize)); + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize)); + + // Local prefetch A1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, Number{}), + a_thread_buf); + }); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + c_thread_buf_up.Clear(); + + // main body + if constexpr(HasMainLoop) + { + // loop over k with the step KPerBlock + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(local_read_buf)); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[mfma_reg_buf] + [Number{}]; + }); + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = 
+ b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert( + 0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops per Thread."); + + vector_type + a_scale_thread_vec; + vector_type + b_scale_thread_vec; + vector_type + b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs[mfma_reg_buf] + [Number{}]; + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs[mfma_reg_buf] + [Number{}]; + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up[mfma_reg_buf] + [Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // MFMA accumulation + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + // a thread copy + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * + xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, Number{}), + a_thread_buf); + }); + }); + }); + + // Prefetch a_scales + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0, I0), + a_scale_thread_bufs(mfma_reg_buf)); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(0, ScalesPerKBlockSize, 0)); + + // Prefetch b_scales + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + constexpr auto b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); + auto b_scale_thread_buf_copy = + make_static_buffer( + b_scale_thread_desc_copy.GetElementSpaceSize()); + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc_copy, + make_tuple(I0, I0), + b_scale_thread_buf_copy); + + b_scale_thread_bufs(mfma_reg_buf)(Number{}) = + b_scale_thread_buf_copy[Number<0>{}]; + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + + auto b_scale_thread_buf_copy_up = + make_static_buffer( + b_scale_thread_desc_copy.GetElementSpaceSize()); + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc_copy, + make_tuple(I0, I0), + b_scale_thread_buf_copy_up); + + b_scale_thread_bufs_up(mfma_reg_buf)(Number{}) = + b_scale_thread_buf_copy_up[Number<0>{}]; + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(0, 
xdlops_gemm.KPerXdlops / ScaleBlockSize)); + }); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); + }); + + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize)); + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize)); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1)); + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type + b_scale_thread_vec_up; + + // Pack b_scale_thread_buf into b_scale_thread_vec + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs[I0][Number{}]; + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs[I0][Number{}]; + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // MFMA accumulation + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + // a thread copy + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, Number{}), + a_thread_buf); + }); + }); + }); + + 
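// The chunked LDS read above gathers A in KThreadChunk-wide pieces:
//   a_k_step_chunk = k * KPerXdlops * (KPack / K1PerXdlops)
//                  + chunk * KThreadChunk * num_input_blks
// Example with hypothetical tile numbers (KPerXdlops = 32, K1PerXdlops = 8,
// KPack = 8, KThreadChunk = 2, num_input_blks = 4): chunk = 0..3 yields
// offsets 0, 8, 16, 24 inside each k_step of 32.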
static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I1][Number{}]; + }); + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type + b_scale_thread_vec_up; + + // Pack b_scale_thread_buf into b_scale_thread_vec + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs[I1][Number{}]; + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs[I1][Number{}]; + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up[I1][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // MFMA accumulation + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type + b_scale_thread_vec_up; + + // Pack b_scale_thread_buf into b_scale_thread_vec + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs[I0][Number{}]; + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs[I0][Number{}]; + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // MFMA accumulation + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template 
AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + // TODO: make this field protected when a_scale_thread_copy_ is moved + // here + static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, Number{})); + + // Is used to copy data from a_scale_grid to a_scale_thread + static constexpr auto a_scale_thread_desc_copy = + make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{})); + + // TODO: make this field protected when b_scale_thread_copy_ is moved + // here + static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, Number{})); + + // Is used to copy data from b_scale_grid to b_scale_thread_buf + static constexpr auto b_scale_thread_desc_copy = + make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{})); + + protected: + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + // using Base::b_thread_desc_; + using Base::c_thread_desc_; + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp new file mode 100644 index 0000000000..f899c223b9 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp @@ -0,0 +1,1020 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
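The header added below is the GU-fusion ("gate/up") variant of the v3 pipeline: one A tile and one set of A scales feed two B streams (the `_up` copies of the B data, B scales, and their thread copies) and two accumulators (`c_thread_buf` and `c_thread_buf_up`). A minimal scalar sketch of that accumulation pattern, with hypothetical types standing in for the real thread buffers and MFMA calls, is:

#include <array>
#include <cstddef>

// Hypothetical scalar model of the fused gate/up accumulation; the real kernel
// issues two xdlops_gemm.Run calls per (m0, n0, k0) with shared A operands.
struct Frag { float data; float scale; };

void gu_fusion_step(const std::array<Frag, 4>& a,      // shared A fragments + scales
                    const std::array<Frag, 4>& b_gate, // gate-projection B
                    const std::array<Frag, 4>& b_up,   // up-projection B
                    float& c_gate, float& c_up)
{
    for(std::size_t k = 0; k < a.size(); ++k)
    {
        const float a_val = a[k].data * a[k].scale;
        c_gate += a_val * b_gate[k].data * b_gate[k].scale; // -> c_thread_buf
        c_up   += a_val * b_up[k].data * b_up[k].scale;     // -> c_thread_buf_up
    }
}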
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< + BlockGemmPipelineScheduler::Intrawave, + ThreadBlockSize, + ScaleBlockSize, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack> : BlockwiseGemmXdlops_mx_pipeline_base + +{ + + using Base = BlockwiseGemmXdlops_mx_pipeline_base; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::KRepeat; + using Base::MWaves; + using Base::NWaves; + using Base::WaveSize; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetWaveIdx; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + using Base::KThreadChunk; + + using Base::APackedSize; + using Base::BPackedSize; + using Base::ComputePackedSize; + + using AccType = typename Base::AccType; + using Tuple4 = typename Base::Tuple4; + using ComputeTypeA = typename Base::ComputeTypeA; + using ComputeTypeB = typename Base::ComputeTypeB; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 
0 : 1; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + static constexpr auto ScalesPerKBlockSize = + KPerBlock / ScaleBlockSize; // How many mx-vectors per K block + + //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + + //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRunPerThread = + ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + // Separate this part? + // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > + // sizeof(ComputeDataType) / sizeof(BDataType) + // ? 
sizeof(ComputeDataType) / sizeof(ADataType) + // : sizeof(ComputeDataType) / sizeof(BDataType); + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = + num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * + ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? 
TailNumber::Even : TailNumber::Odd; + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + BBlockTransfer& b_blockwise_copy_up, + const BGridBuffer& b_grid_buf, + const BGridBuffer& b_grid_buf_up, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + CThreadBuffer& c_thread_buf_up, + // A and B scales + const AScaleGridDesc& a_scale_grid_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const BScaleGridDesc& b_scale_grid_desc, + BScaleThreadTransfer& b_scale_thread_copy, + BScaleThreadTransfer& b_scale_thread_copy_up, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleGridBuffer& b_scale_grid_buf_up, + index_t num_loop) const + { + ignore = b_block_desc; + ignore = b_block_buf; + + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs_up; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs_up; + + // Global prefetch B1 + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I0)); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Global prefetch A1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Prefetch a_scales to buf 0 + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0, I0), + a_scale_thread_bufs(I0)); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, ScalesPerKBlockSize, 0)); + + // Prefetch b_scales 1 + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + constexpr auto b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); + auto b_scale_thread_buf_copy = + make_static_buffer( + b_scale_thread_desc_copy.GetElementSpaceSize()); + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc_copy, + make_tuple(I0, I0), + b_scale_thread_buf_copy); + + b_scale_thread_bufs(I0)(Number{}) = + b_scale_thread_buf_copy[Number<0>{}]; + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + + auto b_scale_thread_buf_copy_up = + make_static_buffer( + 
b_scale_thread_desc_copy.GetElementSpaceSize()); + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc_copy, + make_tuple(I0, I0), + b_scale_thread_buf_copy_up); + + b_scale_thread_bufs_up(I0)(Number{}) = + b_scale_thread_buf_copy_up[Number<0>{}]; + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + }); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); + }); + // restore col id and advance to the next set of scales + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(-NPerBlock, ScalesPerKBlockSize)); + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize)); + + // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); // vmem->vgpr-> lds0 + + // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + c_thread_buf_up.Clear(); + + // Local prefetch A1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, k, Number{}), + a_thread_buf); + }); + }); + }); + + // main body + if constexpr(HasMainLoop) + { + // loop over k with the step KPerBlock + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + // Prefetch a_scales to buf 1 + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0, I0), + a_scale_thread_bufs(local_read_buf)); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(0, ScalesPerKBlockSize, 0)); + + // Prefetch b_scales 2 + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + constexpr auto b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); + auto b_scale_thread_buf_copy = + make_static_buffer( + b_scale_thread_desc_copy.GetElementSpaceSize()); + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc_copy, + make_tuple(I0, I0), + b_scale_thread_buf_copy); + + b_scale_thread_bufs(local_read_buf)(Number{}) = + b_scale_thread_buf_copy[Number<0>{}]; + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + + auto b_scale_thread_buf_copy_up = + make_static_buffer( + b_scale_thread_desc_copy.GetElementSpaceSize()); + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc_copy, + make_tuple(I0, I0), + b_scale_thread_buf_copy_up); + + b_scale_thread_bufs_up(local_read_buf)(Number{}) = + b_scale_thread_buf_copy_up[Number<0>{}]; + 
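// Advance the up-projection B-scale window along K by one xdlops pass,
// keeping it in lockstep with the gate-projection window moved above.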
b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + }); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); + }); + // restore col id and advance to the next set of scales + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize)); + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize)); + + // Local prefill A2 + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); + + // Global prefetch A1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Global prefetch B2 + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(local_read_buf)); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // A1 * B1 + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[mfma_reg_buf] + [Number{}]; + }); + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + vector_type + a_scale_thread_vec; + vector_type + b_scale_thread_vec; + vector_type + b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs[mfma_reg_buf] + [Number{}]; + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs[mfma_reg_buf] + [Number{}]; + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up[mfma_reg_buf] + [Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // MFMA accumulation + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); // KRepeat + }); // NRepeat + }); // MRepeat + + // Local prefetch A2 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = + k * 
xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * + xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple(m0, I0, k, Number{}), + a_thread_buf); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; // LoopFunc + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + // Prefetch a_scales 2 + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0, I0), + a_scale_thread_bufs(I1)); + + // Prefetch b_scales 2 + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + constexpr auto b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); + auto b_scale_thread_buf_copy = + make_static_buffer( + b_scale_thread_desc_copy.GetElementSpaceSize()); + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc_copy, + make_tuple(I0, I0), + b_scale_thread_buf_copy); + + b_scale_thread_bufs(I1)(Number{}) = + b_scale_thread_buf_copy[Number<0>{}]; + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + + auto b_scale_thread_buf_copy_up = + make_static_buffer( + b_scale_thread_desc_copy.GetElementSpaceSize()); + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc_copy, + make_tuple(I0, I0), + b_scale_thread_buf_copy_up); + + b_scale_thread_bufs_up(I1)(Number{}) = + b_scale_thread_buf_copy_up[Number<0>{}]; + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + }); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); + }); + + // Local prefill A2 + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + + // Global prefetch B2 + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1)); + + // A1 * B1 + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + vector_type a_scale_thread_vec; + vector_type 
b_scale_thread_vec; + vector_type + b_scale_thread_vec_up; + + // Pack b_scale_thread_buf into b_scale_thread_vec + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs[I0][Number{}]; + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs[I0][Number{}]; + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // MFMA accumulation + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); // KRepeat + }); // NRepeat + }); // MRepeat + + // Local prefetch A2 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(m0, I0, k, Number{}), + a_thread_buf); + }); + }); + }); + + // A2 * B2 + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I1][Number{}]; + }); + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type + b_scale_thread_vec_up; + + // Pack b_scale_thread_buf into b_scale_thread_vec + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs[I1][Number{}]; + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs[I1][Number{}]; + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up[I1][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // MFMA accumulation + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + a_scale_thread_vec.template 
AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); // KRepeat + }); // NRepeat + }); // MRepeat + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type + b_scale_thread_vec_up; + + // Pack b_scale_thread_buf into b_scale_thread_vec + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs[I0][Number{}]; + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs[I0][Number{}]; + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // MFMA accumulation + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); // KRepeat + }); // NRepeat + }); // MRepeat + } + } + + // TODO: make this field protected when a_scale_thread_copy_ is moved + // here + static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, Number{})); + + // Is used to copy data from a_scale_grid to a_scale_thread + static constexpr auto a_scale_thread_desc_copy = + make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{})); + + // TODO: make this field protected when b_scale_thread_copy_ is moved + // here + static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, Number{})); + + // Is used to copy data from b_scale_grid to b_scale_thread_buf + static constexpr auto b_scale_thread_desc_copy = + make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{})); + + protected: + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + // using Base::b_thread_desc_; + using Base::c_thread_desc_; + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp 
b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp new file mode 100644 index 0000000000..59b2619416 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp @@ -0,0 +1,155 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp" + +namespace ck { +template +constexpr auto BlockGemmMXBPreshufflePipeline_Selector() +{ + + // Hardware MX GEMM pipeline + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if constexpr(GUFusion) + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v1< + BlkGemmPipeSche, + ThreadBlockSize, + ScaleBlockSize, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + ; + } + else + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1< + BlkGemmPipeSche, + ThreadBlockSize, + ScaleBlockSize, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if constexpr(GUFusion) + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< + BlkGemmPipeSche, + ThreadBlockSize, + ScaleBlockSize, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + else + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3< + BlkGemmPipeSche, + ThreadBlockSize, + ScaleBlockSize, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + } + else + { + std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl; + } +} + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp new file mode 100644 index 0000000000..c3b54df7c8 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp @@ -0,0 +1,813 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
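The v1 pipeline below prefetches A/B scales with nested NRepeat x KRepeat x ScalesPerXdlopsRunPerThread loops, advancing the B-scale source window column-wise per n0 and restoring it once per K block. A small stand-alone sketch of that N-direction walk, using hypothetical tile numbers and relying on the invariant noted in the code (NWaves * NPerXDL * NRepeat == NPerBlock), is:

#include <cstdio>

int main()
{
    // Hypothetical tile shape; only the walk pattern matters here.
    const int NRepeat = 4, NWaves = 2, NPerXDL = 16;
    const int NPerBlock = NRepeat * NWaves * NPerXDL; // 128

    int n_offset = 0;
    for(int n0 = 0; n0 < NRepeat; ++n0)
    {
        // ... inner k0 / s copies move the window along K here ...
        n_offset += NWaves * NPerXDL; // advance to the next wave's columns
    }
    n_offset -= NPerBlock; // restore the column id for the next K block
    std::printf("n_offset after one block: %d\n", n_offset); // prints 0
    return 0;
}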
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1 + : BlockwiseGemmXdlops_mx_pipeline_base + +{ + + using Base = BlockwiseGemmXdlops_mx_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::MWaves; + using Base::NWaves; + using Base::WaveSize; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetWaveIdx; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + using Base::KThreadChunk; + + using Base::APackedSize; + using Base::BPackedSize; + using Base::ComputePackedSize; + + using AccType = typename Base::AccType; + using Tuple4 = typename Base::Tuple4; + using ComputeTypeA = typename Base::ComputeTypeA; + using ComputeTypeB = typename Base::ComputeTypeB; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + static constexpr auto ScalesPerKBlockSize = + KPerBlock / ScaleBlockSize; // How many mx-vectors per K block + + //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + + //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRunPerThread = + ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return 
num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + // A and B scales + const AScaleGridDesc& a_scale_grid_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const BScaleGridDesc& b_scale_grid_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + index_t num_loop) const + { + ignore = b_block_desc; + ignore = b_block_buf; + + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + + // Global prefetch A1 B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Prefetch a_scales + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + constexpr auto a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s)); + auto a_scale_thread_buf_copy = + make_static_buffer( + a_scale_thread_desc_copy.GetElementSpaceSize()); + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc_copy, + make_tuple(I0, I0), + a_scale_thread_buf_copy); + + a_scale_thread_buf(I0)(Number{}) = + a_scale_thread_buf_copy[Number<0>{}]; + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + }); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(-MPerBlock, ScalesPerKBlockSize)); + + // Prefetch b_scales to buf 0 + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + constexpr auto b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); + auto b_scale_thread_buf_copy = + make_static_buffer( + b_scale_thread_desc_copy.GetElementSpaceSize()); + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc_copy, + make_tuple(I0, I0), + b_scale_thread_buf_copy); + + b_scale_thread_bufs(I0)(Number{}) = + b_scale_thread_buf_copy[Number<0>{}]; + 
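// Step the B-scale source window forward along K by one xdlops pass
// (KPerXdlops / ScaleBlockSize mx-scale entries).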
b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + }); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(-NPerBlock, ScalesPerKBlockSize)); + + __builtin_amdgcn_sched_barrier(0); + + // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + + // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Prefetch a_scales to buf 1 + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + constexpr auto a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s)); + auto a_scale_thread_buf_copy = + make_static_buffer( + a_scale_thread_desc_copy.GetElementSpaceSize()); + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc_copy, + make_tuple(I0, I0), + a_scale_thread_buf_copy); + + a_scale_thread_buf(I1)(Number{}) = + a_scale_thread_buf_copy[Number<0>{}]; + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + }); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(-MPerBlock, ScalesPerKBlockSize)); + + // Prefetch b_scales to buf 1 + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + constexpr auto b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); + auto b_scale_thread_buf_copy = + make_static_buffer( + b_scale_thread_desc_copy.GetElementSpaceSize()); + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc_copy, + make_tuple(I0, I0), + b_scale_thread_buf_copy); + + b_scale_thread_bufs(I1)(Number{}) = + b_scale_thread_buf_copy[Number<0>{}]; + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + }); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); + }); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(-NPerBlock, ScalesPerKBlockSize)); + + // Local prefetch A1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, Number{}), + a_thread_buf); + }); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + // loop over k with the step KPerBlock + index_t 
i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + }); + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert( + 0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops per Thread."); + + vector_type + a_scale_thread_vec; + vector_type + b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs[mfma_reg_buf] + [Number{}]; + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs[mfma_reg_buf] + [Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // MFMA accumulation + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + // a thread copy + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * + xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, Number{}), + a_thread_buf); + }); + }); + }); + + // Prefetch a_scales + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0, I0), + a_scale_thread_bufs(mfma_reg_buf)); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(0, ScalesPerKBlockSize, 0)); + + // Prefetch b_scales + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + constexpr auto b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); + auto b_scale_thread_buf_copy = + make_static_buffer( + b_scale_thread_desc_copy.GetElementSpaceSize()); + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc_copy, + make_tuple(I0, I0), 
+ b_scale_thread_buf_copy); + + b_scale_thread_bufs(mfma_reg_buf)(Number{}) = + b_scale_thread_buf_copy[Number<0>{}]; + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + }); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); + }); + + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize)); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack b_scale_thread_buf into b_scale_thread_vec + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs[I0][Number{}]; + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // MFMA accumulation + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + // a thread copy + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, Number{}), + a_thread_buf); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + }); + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + vector_type a_scale_thread_vec; + vector_type 
b_scale_thread_vec; + + // Pack b_scale_thread_buf into b_scale_thread_vec + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs[I1][Number{}]; + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs[I1][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // MFMA accumulation + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack b_scale_thread_buf into b_scale_thread_vec + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs[I0][Number{}]; + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // MFMA accumulation + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + // TODO: make this field protected when a_scale_thread_copy_ is moved + // here + static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, Number{})); + + // Is used to copy data from a_scale_grid to a_scale_thread + static constexpr auto a_scale_thread_desc_copy = + make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{})); + + // TODO: make this field protected when b_scale_thread_copy_ is moved + // here + static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, Number{})); + + // Is used to copy data from b_scale_grid to b_scale_thread_buf + static constexpr auto b_scale_thread_desc_copy = + make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{})); + + protected: + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + // using Base::b_thread_desc_; + using Base::c_thread_desc_; + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; +}; + +} // namespace ck diff --git 
a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp new file mode 100644 index 0000000000..ec0628ca20 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp @@ -0,0 +1,1032 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3 + : BlockwiseGemmXdlops_mx_pipeline_base + +{ + + using Base = BlockwiseGemmXdlops_mx_pipeline_base; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::KRepeat; + using Base::MWaves; + using Base::NWaves; + using Base::WaveSize; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetWaveIdx; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_m3_k; + using Base::b_block_desc_n0_n1_n2_n3_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + using Base::KThreadChunk; + + using Base::KXdlPack; + using Base::MXdlPack; + using Base::NXdlPack; + + using Base::APackedSize; + using Base::BPackedSize; + + using AccType = typename Base::AccType; + using Tuple5 = typename Base::Tuple5; + using ComputeTypeA = typename Base::ComputeTypeA; + using ComputeTypeB = typename Base::ComputeTypeB; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + template + __host__ __device__ static constexpr auto + MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_M3_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_M3_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_M3_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_M3_K{}.GetLength(Number<2>{}); + constexpr index_t M3 = TileDesc_M0_M1_M2_M3_K{}.GetLength(Number<3>{}); + constexpr index_t K2 = KPack; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_M3_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4, 5, 6>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_m3_k0_k1_k2 = + 
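+    // Note on the descriptor built above: MakeAGemmMmaTileDescriptor keeps the four M
+    // dimensions as-is and unmerges the flat K axis into (K0, K1, K2) =
+    // (KRepeat, 64 / NPerXDL, KPack) -- the K-repeat index, the K-lane index within one
+    // wave row, and the contiguous per-lane KPack chunk. Illustrative numbers only
+    // (assumed, not fixed by this header): with NPerXDL = 16 a wave row spans
+    // 64 / 16 = 4 K-lanes, so the K extent factors as KRepeat * 4 * KPack.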
MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_m3_k); + + static constexpr auto ScalesPerKBlockSize = + KPerBlock / ScaleBlockSize; // How many mx-vectors per K block + + //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRun = + (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + + //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRunPerThread = + ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + + using mx_scale_t = e8m0_bexp_t; + static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a; + static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + // Separate this part? + // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > + // sizeof(ComputeDataType) / sizeof(BDataType) + // ? 
sizeof(ComputeDataType) / sizeof(ADataType) + // : sizeof(ComputeDataType) / sizeof(BDataType); + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = + num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * + ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? 
TailNumber::Even : TailNumber::Odd; + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + // A and B scales + const AScaleGridDesc& a_scale_grid_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const BScaleGridDesc& b_scale_grid_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + index_t num_loop) const + { + ignore = b_block_desc; + ignore = b_block_buf; + + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + + // Global prefetch B1 + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_n2_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Global prefetch A1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Prefetch a_scales to buf 0 + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I0)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales 1 + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); // vmem->vgpr-> lds0 + + // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + 
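+        // Overall flow from this point on (summary only; exact buffer indices are in the
+        // code below):
+        //   - local prefetch of A tile 1 from LDS buffer 0 into a_thread_buf
+        //   - hot loop: LoopFunc(I0, I1); LoopFunc(I1, I0); i += 2;
+        //     each LoopFunc prefetches the next A/B tiles and A/B scales into the "mem"
+        //     buffer while issuing scaled MFMAs from the "comp" buffer
+        //   - tail: TailNumber::Even runs two more scaled-MFMA passes (buffer 0, then 1),
+        //     TailNumber::Odd runs one final pass from buffer 0.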
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Local prefetch A1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + // loop over k with the step KPerBlock + index_t i = 0; + do + { + auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) { + // Prefetch a_scales to buf 1 + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(scale_mem_buf)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales 1 + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(scale_mem_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Local prefill A2 + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(scale_mem_buf)); + + // Global prefetch A1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Global prefetch B2 + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_n2_k0_k1, + b_block_origin_idx, + b_thread_bufs(scale_mem_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // A1 * B1 + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type + a_scale_thread_vec; + vector_type + b_scale_thread_vec; + + // Pack 
scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()( + ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()( + ik) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference( + Number{})); + }); + }); + }); + }); + }); + }); + + // Local prefetch A2 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, + xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf.At(scale_mem_buf), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; // LoopFunc + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I1)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I1)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // Local prefill A2 + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + + // Global prefetch B2 + 
b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_n2_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + + // A1 * B1 + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + + // Local prefetch A2 + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + }); + + // A2 * B2 + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + 
a_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + // b_thread_vec.template AsType()(ik) = + // b_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + type_convert(ck::float2_t(1.0)); + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + } + } + + // TODO: make this field protected when a_scale_thread_copy_ is moved + // here + static constexpr auto a_scale_thread_desc = 
make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + // TODO: make this field protected when b_scale_thread_copy_ is moved + // here + static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; + + static constexpr BTileDesc b_block_desc_n0_n1_n2_k0_k1; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp index 074b5873ee..c6966011b4 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp @@ -8,6 +8,7 @@ #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp" @@ -171,26 +172,54 @@ constexpr auto BlockGemmBPreshufflePipeline_Selector() static_assert(MRepeat >= 4, "MRepeat should at least be 4 in BlockGemmPipelineVersion::v3"); if constexpr(std::is_same::value) { - return BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}; + if constexpr(GUFusion) + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + else + { + + return BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}; + } } else { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_selector.hpp new file mode 100644 index 0000000000..818439fddf --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_selector.hpp @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
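+// Summary of this header (see the selector body below):
+// BlockGemmBlockScaleBPreshufflePipeline_Selector dispatches on BlockGemmPipelineVersion.
+// v1 returns the two-stage-prefetch BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1,
+// v3 returns the _v3 variant with the explicit hot-loop scheduler (and additionally
+// requires MRepeat >= 4), and the v2 branch is currently compiled out behind `#if 0`.
+// Any other version falls through to the std::cerr diagnostic.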
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp" +namespace ck { + +template +constexpr auto BlockGemmBlockScaleBPreshufflePipeline_Selector() +{ + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + return BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MScaleBlock, + NScaleBlock, + KScaleBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } +#if 0 + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + return BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v2< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } +#endif + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + static_assert(MRepeat >= 4, "MRepeat should at least be 4 in BlockGemmPipelineVersion::v3"); + return BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MScaleBlock, + NScaleBlock, + KScaleBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + else + { + std::cerr << "BlockGemmPipeline configuration is not available" << std::endl; + } +} + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp new file mode 100644 index 0000000000..8e2922e2ce --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp @@ -0,0 +1,864 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
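+// Reading note for this pipeline: "b_preshuffle" means the B tile never touches LDS --
+// b_block_desc / b_block_buf are ignored in Run() and B is streamed from global memory
+// straight into per-thread register buffers (b_thread_bufs), while A is staged through
+// LDS via a_blockwise_copy.RunWrite/RunRead and re-read by a_thread_copy_.
+// Per-block scales for A and B are fetched with dedicated thread copies and combined
+// into a C scale (a_scale * b_scale) before being applied to the accumulators.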
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KGroup; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::MWaves; + using Base::NWaves; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack / KGroup; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat * KGroup; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? 
TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves; + + // B global + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A local + static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + const CScaleThreadDesc& c_scale_thread_desc, + CThreadBuffer& c_thread_buf, + // AScaleThreadCopy + const AScaleGridDesc& a_scale_grid_desc, + const AScaleThreadDesc& a_scale_thread_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const AScaleThreadTransferStep& a_scale_thread_copy_step, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num_loop + index_t num_loop) const + { + ignore = b_block_desc; + ignore = b_block_buf; + // __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + auto c_scale_thread_buf = make_static_buffer( + c_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch A1 B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + 
a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + __builtin_amdgcn_sched_barrier(0); + + constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{}); + constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{}); + constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{}); + + static_for<0, num_scale_m_block, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + }); + }); + }); + + // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + + // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + StaticBufferTupleOfVector + c_thread_buf_per_scale; + + // Local prefetch A1 + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + + // __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 
1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = + CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number< + b_thread_desc_.CalculateOffset(make_tuple( + n0, + I0, + kscale0 * KRepeat / num_scale_k_block + k0, + ik))>{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}( + [&](auto t) { + using pk_fma_type = + typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = + __builtin_elementwise_fma( + c_thread_buf_per_scale + .GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec + .template AsType()[Number<0>{}], + c_thread_buf + .GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + 
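+                    // Net effect of one LoopFunc iteration, written as pseudo-math
+                    // (a summary of the code above, not additional work):
+                    //   for each (m0, n0, kscale0):
+                    //     partial    = sum over k of A[k] * B[k]      // raw xdlops accumulation
+                    //     c_scale    = a_scale[m0, kscale0] * b_scale[n0, kscale0]
+                    //     C[m0, n0] += partial * c_scale              // packed f32x2 FMA
+                    // i.e. block-scale dequantization is folded into the accumulator
+                    // update once per K-scale block instead of being applied to the inputs.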
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + 
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + }); + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy 
a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp new file mode 100644 index 0000000000..cc4c5a2c36 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp @@ -0,0 +1,1090 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::KGroup; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MWaves; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t LocalPrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 
0 : 1; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack / KGroup; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat * KGroup; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves; + + static_assert(num_buffer_load_inst_a == num_ds_write_inst_a); + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 
8 : 4; + constexpr auto ds_read_a_mfma_rate = + math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle); + + // constexpr auto num_dsread_a_mfma = + // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + + constexpr auto num_total_stages = MRepeat; + + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages; + + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + + constexpr auto buffer_load_perstage_more = + math::integer_divide_ceil((num_buffer_load_inst_a + num_buffer_load_inst_b), + (num_total_stages - (LocalPrefetchStages - 1))); + constexpr auto buffer_load_perstage_less = + math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b), + (num_total_stages - (LocalPrefetchStages - 1))); + + constexpr auto buffer_load_stages_more = + (num_buffer_load_inst_a + num_buffer_load_inst_b) - + math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b), + (num_total_stages - (LocalPrefetchStages - 1))) * + ((num_total_stages - (LocalPrefetchStages - 1))); + + constexpr auto buffer_load_b_stages = + buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b + ? num_buffer_load_inst_b / buffer_load_perstage_more + : (buffer_load_stages_more + + (num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) / + buffer_load_perstage_less); + + constexpr auto buffer_load_a_stages = + num_total_stages - (LocalPrefetchStages - 1) - buffer_load_b_stages; + + constexpr auto buffer_load_issue_point_b = 0; + constexpr auto buffer_load_issue_point_interval_more = + num_mfma_perstage / buffer_load_perstage_more + ? num_mfma_perstage / buffer_load_perstage_more + : 1; + constexpr auto buffer_load_issue_point_interval_less = + num_mfma_perstage / buffer_load_perstage_less + ? num_mfma_perstage / buffer_load_perstage_less + : 1; + constexpr auto ds_write_issue_point = 0; + constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 
1 : 0; + + // B global read + static_for<0, buffer_load_b_stages, 1>{}([&](auto i) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + + if constexpr(((i < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_b)) || + ((i >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_b))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + + if constexpr((imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage - 1)) && + (imfma < (num_mfma_perstage - 1))) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma + // __builtin_amdgcn_sched_group_barrier(0x1000, 4, 0); // v_fmac + }); + // Scale load, 1B + if constexpr(i.value == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + // Scale load, 1A + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + // __builtin_amdgcn_sched_barrier(0); + }); + + // A global read + A local write + static_for<0, buffer_load_a_stages, 1>{}([&](auto i) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + ds_write_issue_point)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + ds_write_issue_point))) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_a)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_a))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr((imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage - 1)) && + (imfma < (num_mfma_perstage - 1))) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma + // __builtin_amdgcn_sched_group_barrier(0x1000, 4, 0); // v_fmac + }); + // Scale load, 1A + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + // __builtin_amdgcn_sched_barrier(0); + }); + + // lds synchronization, prefetch next loop local A + static_for<0, (LocalPrefetchStages - 1), 1>{}([&](auto i) { + ignore = i; + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + + if constexpr((imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage - 1)) && + (imfma < (num_mfma_perstage - 1))) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma + // __builtin_amdgcn_sched_group_barrier(0x1000, 4, 0); // v_fmac + }); + // Scale load, 1A + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + // __builtin_amdgcn_sched_barrier(0); + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const 
ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + const CScaleThreadDesc& c_scale_thread_desc, + CThreadBuffer& c_thread_buf, + // AScaleThreadCopy + const AScaleGridDesc& a_scale_grid_desc, + const AScaleThreadDesc& a_scale_thread_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const AScaleThreadTransferStep& a_scale_thread_copy_step, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num_loop + index_t num_loop) const + { + ignore = b_block_desc; + ignore = b_block_buf; + __builtin_amdgcn_sched_barrier(0); + static_assert(CScaleThreadDesc{}.GetLength(Number<0>{}) == 1, + "Pipeline v3 only support scaleblocksliceK=1"); + static_assert(CScaleThreadDesc{}.GetLength(Number<2>{}) == 1, + "Pipeline v3 only support scaleblocksliceN=1"); + // assume kperblock = scaleblockk + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + auto c_scale_thread_buf = make_static_buffer( + c_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + // StaticallyIndexedArray{}> c_scale_thread_bufs; + + // Global prefetch A1 B1, AScale1 BScale1 + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_bufs(I0)); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0]; + }); + + // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); + + // Global prefetch A2, AScale2 BScale2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + 
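+        // At this point B for iteration 0 (and its A/B scales) sits in registers,
+        // A for iteration 0 has been prefilled into LDS, and the RunRead above
+        // fetches A for iteration 1, so the hot loop can overlap global loads,
+        // LDS traffic and MFMA. The same two-deep pattern is repeated for the
+        // scale buffers in the #if 1 block below.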
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + +#if 1 + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_bufs(I0)); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); +#endif + // Initialize C + c_thread_buf.Clear(); + + // Double register buffer for non-scaled gemm computation + // 1. Reduce register pressure + // 2. Decouple the dependency between mfma instruction and scale-fma instruction following. + StaticBufferTupleOfVector + c_thread_buf_per_scale; + + // Local prefetch A1 + block_sync_lds(); + static_for<0, LocalPrefetchStages, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + +#if 1 + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + + // Fill first mfma buffer + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_bufs + [I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); +#endif + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs(local_read_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_bufs(local_read_buf)); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, 
a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + } + + static_for<0, MRepeat, 1>{}([&](auto m0) { + vector_type c_scale_thread_vec; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[m0]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[m0]; + + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto mfma_buf_offset = + ((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops(); + constexpr auto scale_buf_offset = + ((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops(); + + constexpr auto a_local_buf_offset = + ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) / NRepeat; + constexpr auto b_local_buf_offset = + ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) % NRepeat; + constexpr auto b_local_buf_id = + Number{}; + + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale + .GetVectorTypeReference(Number{}) + .template AsType()(Number{}) = 0; + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs + [b_local_buf_id][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference( + Number{})); + }); + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale + .GetVectorTypeReference(Number{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + + // We have to 1 stage early sync the lds for workaround the compiler + // limitation + if constexpr(m0.value == (MRepeat - LocalPrefetchStages - 1)) + { + block_sync_lds(); + } + + constexpr auto lds_buf = m0.value >= (MRepeat - LocalPrefetchStages) + ? 
local_read_buf + : mfma_reg_buf; + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(Number{}), + a_thread_desc_, + make_tuple(Number<(m0 + LocalPrefetchStages + + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * + b_scale_thread_bufs[mfma_reg_buf][I0]; + }); + + // We need new compiler to enable this feature + // HotLoopScheduler(); + // __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + vector_type c_scale_thread_vec; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[m0]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[m0]; + + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto mfma_buf_offset = + ((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops(); + constexpr auto scale_buf_offset = + ((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops(); + + constexpr auto a_local_buf_offset = + ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) / NRepeat; + constexpr auto b_local_buf_offset = + ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) % NRepeat; + + constexpr auto b_local_buf_id = + Number<0 ^ ((m0 * NRepeat + n0 + 1) / (MRepeat * NRepeat))>{}; + + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[b_local_buf_id][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference( + Number{})); + }); + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale + .GetVectorTypeReference(Number{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + + if constexpr(m0.value == (MRepeat - LocalPrefetchStages)) + { + block_sync_lds(); + } + + constexpr auto lds_buf = m0.value >= (MRepeat - LocalPrefetchStages) ? 
I1 : I0; + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + LocalPrefetchStages) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(Number{}), + a_thread_desc_, + make_tuple(Number<(m0 + LocalPrefetchStages) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + }); + + // HotLoopScheduler(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0]; + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + vector_type c_scale_thread_vec; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[m0]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[m0]; + + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto mfma_buf_offset = + ((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops(); + constexpr auto scale_buf_offset = + ((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops(); + + constexpr auto a_local_buf_offset = + ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) / NRepeat; + constexpr auto b_local_buf_offset = + ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) % NRepeat; + + if constexpr(!((m0 == (MRepeat - 1)) && (n0 == (NRepeat - 1)))) + { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference( + Number{})); + }); + } + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale + .GetVectorTypeReference(Number{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + + if constexpr(m0.value < (MRepeat - LocalPrefetchStages)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + LocalPrefetchStages + HotloopLocalBufSwitch) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + }); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + else + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + vector_type c_scale_thread_vec; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[m0]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[m0]; + + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto mfma_buf_offset = + ((m0 * NRepeat + 
n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops(); + constexpr auto scale_buf_offset = + ((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops(); + + constexpr auto a_local_buf_offset = + ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) / NRepeat; + constexpr auto b_local_buf_offset = + ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) % NRepeat; + + if constexpr(!((m0 == (MRepeat - 1)) && (n0 == (NRepeat - 1)))) + { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference( + Number{})); + }); + } + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale + .GetVectorTypeReference(Number{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + + if constexpr(m0.value < (MRepeat - LocalPrefetchStages)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(Number<(m0 + LocalPrefetchStages) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + // Reduce the vgpr usage here. + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(I2, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp new file mode 100644 index 0000000000..1608506b40 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp @@ -0,0 +1,1036 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
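+//
+// Gate/Up (GU) fused MoE pipeline, v1 scheduling. One A tile is multiplied
+// against two pre-shuffled B operands (the gate and up projections), so most
+// buffers below exist twice: b_thread_bufs / b_thread_bufs_up,
+// b_scale_thread_buf / b_scale_thread_buf_up, c_thread_buf / c_thread_buf_up.
+//
+// Reference semantics per thread tile (a schematic host-side sketch with
+// illustrative names; kb indexes one K scale block):
+//
+//   for (m, n, kb) {
+//       acc = 0; acc_up = 0;
+//       for (k in kb) {                      // unscaled MFMA accumulation
+//           acc    += A[m][k] * B[k][n];
+//           acc_up += A[m][k] * B_up[k][n];
+//       }
+//       C[m][n]    += a_scale[m][kb] * b_scale[n][kb]    * acc;
+//       C_up[m][n] += a_scale[m][kb] * b_scale_up[n][kb] * acc_up;
+//   }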
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1< + BlockGemmPipelineScheduler::Intrawave, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MScaleBlock, + NScaleBlock, + KScaleBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack> : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KGroup; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::MWaves; + using Base::NWaves; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack / KGroup; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat * KGroup; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? 
TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = + HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves * 2; + constexpr auto mfma_interleave = MPerXDL == 32 ? 1 : 2; + // B global + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + if constexpr(MPerBlock >= 128 && NPerBlock >= 64) + { + __builtin_amdgcn_sched_group_barrier(0x008, 2 * mfma_interleave, 0); + } + else + { + __builtin_amdgcn_sched_group_barrier(0x008, mfma_interleave, 0); + } + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A local + static_for<0, MPerXDL == 32 ? num_ds_read_inst_a / 2 : num_ds_read_inst_a, 1>{}( + [&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, MPerXDL == 32 ? 2 : 1, 0); // DS read + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + BBlockTransfer& b_blockwise_copy_up, + const BGridBuffer& b_grid_buf, + const BGridBuffer& b_grid_buf_up, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + const CScaleThreadDesc& c_scale_thread_desc, + CThreadBuffer& c_thread_buf, + CThreadBuffer& c_thread_buf_up, + // AScaleThreadCopy + const AScaleGridDesc& a_scale_grid_desc, + const AScaleThreadDesc& a_scale_thread_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const AScaleThreadTransferStep& a_scale_thread_copy_step, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + BScaleThreadTransfer& b_scale_thread_copy_up, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleGridBuffer& b_scale_grid_buf_up, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num_loop + index_t num_loop) const + { + ignore = b_block_desc; + ignore = b_block_buf; + // __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs_up; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf_up = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + auto c_scale_thread_buf = make_static_buffer( + 
c_scale_thread_desc.GetElementSpaceSize()); + auto c_scale_thread_buf_up = make_static_buffer( + c_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch A1 B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_buf); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf_up); + + b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + // __builtin_amdgcn_sched_barrier(0); + + constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{}); + constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{}); + constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{}); + static_for<0, num_scale_m_block, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + c_scale_thread_buf_up(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf_up[Number{}]; + }); + }); + }); + + // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + + // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_buf); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf_up); + + 
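+        // Keep the up-projection scale window advancing in lockstep with the
+        // gate-projection copy so both point at the same K scale block.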
b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + StaticBufferTupleOfVector + c_thread_buf_per_scale; + StaticBufferTupleOfVector + c_thread_buf_per_scale_up; + + // Local prefetch A1 + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + c_thread_buf_up.Clear(); + + // __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(local_read_buf)); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + vector_type c_scale_thread_vec_up; + constexpr index_t cscale_offset = + CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<0>{}) = + c_scale_thread_buf_up[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<1>{}) = + c_scale_thread_buf_up[Number{}]; + + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number< + b_thread_desc_.CalculateOffset(make_tuple( + n0, + I0, + kscale0 * KRepeat / num_scale_k_block + k0, + ik))>{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[mfma_reg_buf][Number< + b_thread_desc_.CalculateOffset(make_tuple( + n0, + I0, + kscale0 * KRepeat / num_scale_k_block + k0, + ik))>{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_per_scale_up.GetVectorTypeReference( + Number<0>{})); + }); + + 
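+                        // Fold the block scales into the accumulators: the
+                        // unscaled partial sums gathered above are combined as
+                        //   c_thread_buf    += per_scale_acc    * (a_scale * b_scale)
+                        //   c_thread_buf_up += per_scale_acc_up * (a_scale * b_scale_up)
+                        // using packed FMAs over pairs of accumulator registers.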
constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}( + [&](auto t) { + using pk_fma_type = + typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = + __builtin_elementwise_fma( + c_thread_buf_per_scale + .GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec + .template AsType()[Number<0>{}], + c_thread_buf + .GetVectorTypeReference(Number{}) + .template AsType()[t]); + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()(t) = + __builtin_elementwise_fma( + c_thread_buf_per_scale_up + .GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec_up + .template AsType()[Number<0>{}], + c_thread_buf_up + .GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + c_scale_thread_buf_up(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf_up[Number{}]; + }); + }); + }); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_buf); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step); + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf_up); + + b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1)); + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 
1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + vector_type c_scale_thread_vec_up; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<0>{}) = + c_scale_thread_buf_up[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<1>{}) = + c_scale_thread_buf_up[Number{}]; + + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec_up.template AsType()[Number<0>{}], + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + c_scale_thread_buf_up(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf_up[Number{}]; + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + // __builtin_amdgcn_sched_barrier(0); + + 
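+            // Final iteration of the even tail: consume the A fragments just
+            // read from LDS together with the B tiles staged in b_thread_bufs[I1]
+            // and b_thread_bufs_up[I1]; no further global or LDS traffic follows.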
static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + vector_type c_scale_thread_vec_up; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<0>{}) = + c_scale_thread_buf_up[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<1>{}) = + c_scale_thread_buf_up[Number{}]; + + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec_up.template AsType()[Number<0>{}], + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + vector_type c_scale_thread_vec_up; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + 
c_scale_thread_vec_up.template AsType()(Number<0>{}) = + c_scale_thread_buf_up[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<1>{}) = + c_scale_thread_buf_up[Number{}]; + + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec_up.template AsType()[Number<0>{}], + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + }); + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp new file mode 100644 index 0000000000..30d6d4f812 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp @@ -0,0 +1,1203 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
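+//
+// Same gate/up fused, block-scaled computation as the *_gufusion_v1 pipeline,
+// but with explicit intrawave scheduling: HotLoopScheduler() interleaves the
+// instruction streams via __builtin_amdgcn_sched_group_barrier(mask, n, 0).
+// The masks used below are (matching the inline comments):
+//   0x008 MFMA   0x020 VMEM read   0x100 DS read   0x200 DS write
+// e.g. a scheduling slot that issues an MFMA, then a buffer_load, then a
+// group of ds_read_a_mfma_rate ds_reads is expressed as three consecutive
+// sched_group_barrier calls with those masks.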
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< + BlockGemmPipelineScheduler::Intrawave, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MScaleBlock, + NScaleBlock, + KScaleBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack> : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::KGroup; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MWaves; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack / KGroup; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat * KGroup; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? 
TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * 2; + + static_assert(num_buffer_load_inst_a == num_ds_write_inst_a); + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * 2; + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle); + + // constexpr auto num_dsread_a_mfma = + // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + + constexpr auto num_total_stages = MRepeat; + + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages; + + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + + constexpr auto num_ds_read_a_prefetch_stages = 2; + + constexpr auto buffer_load_perstage_more = math::integer_divide_ceil( + (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_less = math::integer_divide_floor( + (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2)); + + constexpr auto buffer_load_stages_more = + (num_buffer_load_inst_a + num_buffer_load_inst_b) - + math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b), + (num_total_stages - 2)) * + ((num_total_stages - 2)); + + constexpr auto buffer_load_b_stages = + buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b + ? num_buffer_load_inst_b / buffer_load_perstage_more + : (buffer_load_stages_more + + (num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) / + buffer_load_perstage_less); + + constexpr auto buffer_load_a_stages = + num_total_stages - num_ds_read_a_prefetch_stages - buffer_load_b_stages; + + constexpr auto buffer_load_issue_point_b = 0; + constexpr auto buffer_load_issue_point_interval_more = + num_mfma_perstage / buffer_load_perstage_more + ? num_mfma_perstage / buffer_load_perstage_more + : 1; + constexpr auto buffer_load_issue_point_interval_less = + num_mfma_perstage / buffer_load_perstage_less + ? num_mfma_perstage / buffer_load_perstage_less + : 1; + constexpr auto ds_write_issue_point = 0; + constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 
1 : 0; + + // B global read + static_for<0, buffer_load_b_stages, 1>{}([&](auto i) { + // Scale load, 1B + if constexpr(i.value == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + // Scale load, 1A + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + + if constexpr(((i < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_b)) || + ((i >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_b))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + // __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma + }); + // __builtin_amdgcn_sched_barrier(0); + }); + + // A global read + A local write + static_for<0, buffer_load_a_stages, 1>{}([&](auto i) { + // Scale load, 1A + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + ds_write_issue_point)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + ds_write_issue_point))) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_a)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_a))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + // __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma + }); + // __builtin_amdgcn_sched_barrier(0); + }); + + // lds synchronization, prefetch next loop local A + static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + // Scale load, 1A + if constexpr(imfma == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + // __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma + }); + // __builtin_amdgcn_sched_barrier(0); + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + BBlockTransfer& b_blockwise_copy_up, + const BGridBuffer& b_grid_buf, + const BGridBuffer& b_grid_buf_up, + BBlockBuffer& b_block_buf, + const 
BBlockTransferStep& b_block_copy_step, + // CThread + const CScaleThreadDesc& c_scale_thread_desc, + CThreadBuffer& c_thread_buf, + CThreadBuffer& c_thread_buf_up, + // AScaleThreadCopy + const AScaleGridDesc& a_scale_grid_desc, + const AScaleThreadDesc& a_scale_thread_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const AScaleThreadTransferStep& a_scale_thread_copy_step, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + BScaleThreadTransfer& b_scale_thread_copy_up, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleGridBuffer& b_scale_grid_buf_up, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num_loop + index_t num_loop) const + { + ignore = b_block_desc; + ignore = b_block_buf; + __builtin_amdgcn_sched_barrier(0); + static_assert(CScaleThreadDesc{}.GetLength(Number<0>{}) == 1, + "Pipeline v3 only support scaleblocksliceK=1"); + static_assert(CScaleThreadDesc{}.GetLength(Number<2>{}) == 1, + "Pipeline v3 only support scaleblocksliceN=1"); + // assume kperblock = scaleblockk + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs_up; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + auto c_scale_thread_buf = make_static_buffer( + c_scale_thread_desc.GetElementSpaceSize()); + auto c_scale_thread_buf_up = make_static_buffer( + c_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs_up; + // StaticallyIndexedArray{}> c_scale_thread_bufs; + + // Global prefetch A1 B1, AScale1 BScale1 + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I0)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_bufs(I0)); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs_up(I0)); + + b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, 
b_scale_thread_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0]; + }); + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf_up(m0) = + a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs_up[I0][I0]; + }); + + // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); + + // Global prefetch A2, AScale2 BScale2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_bufs(I0)); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs_up(I0)); + + b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + // Initialize C + c_thread_buf.Clear(); + c_thread_buf_up.Clear(); + + // Double register buffer for non-scaled gemm computation + // 1. Reduce register pressure + // 2. Decouple the dependency between mfma instruction and scale-fma instruction following. + StaticBufferTupleOfVector + c_thread_buf_per_scale; + StaticBufferTupleOfVector + c_thread_buf_per_scale_up; + + // Local prefetch A1 + block_sync_lds(); + static_for<0, 2, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(local_read_buf)); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_bufs(local_read_buf)); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); + } + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + 
b_scale_thread_bufs(local_read_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step); + + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs_up(local_read_buf)); + + b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + vector_type c_scale_thread_vec; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[m0]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[m0]; + vector_type c_scale_thread_vec_up; + c_scale_thread_vec_up.template AsType()(Number<0>{}) = + c_scale_thread_buf_up[m0]; + c_scale_thread_vec_up.template AsType()(Number<1>{}) = + c_scale_thread_buf_up[m0]; + + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[mfma_reg_buf] + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); + }); + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec_up + .template AsType()[Number<0>{}], + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + + if constexpr(m0.value == (MRepeat - 2)) + { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + 
make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + else + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * + b_scale_thread_bufs[mfma_reg_buf][I0]; + c_scale_thread_buf_up(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * + b_scale_thread_bufs_up[mfma_reg_buf][I0]; + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1)); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + vector_type c_scale_thread_vec; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[m0]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[m0]; + vector_type c_scale_thread_vec_up; + c_scale_thread_vec_up.template AsType()(Number<0>{}) = + c_scale_thread_buf_up[m0]; + c_scale_thread_vec_up.template AsType()(Number<1>{}) = + c_scale_thread_buf_up[m0]; + + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); + }); + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + 
c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec_up.template AsType()[Number<0>{}], + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + + if constexpr(m0.value == (MRepeat - 2)) + { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + else + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + }); + + HotLoopScheduler(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0]; + c_scale_thread_buf_up(m0) = + a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs_up[I0][I0]; + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + vector_type c_scale_thread_vec; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[m0]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[m0]; + vector_type c_scale_thread_vec_up; + c_scale_thread_vec_up.template AsType()(Number<0>{}) = + c_scale_thread_buf_up[m0]; + c_scale_thread_vec_up.template AsType()(Number<1>{}) = + c_scale_thread_buf_up[m0]; + + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + 
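+                        // The static_for below applies the blockscale to this (m0, n0)
+                        // accumulator pair as a packed float2 FMA; per lane t this is,
+                        // in effect,
+                        //   c[c_offset + t]    += acc[t]    * c_scale
+                        //   c_up[c_offset + t] += acc_up[t] * c_scale_up
+                        // where acc/acc_up abbreviate c_thread_buf_per_scale(_up) and
+                        // c_scale/c_scale_up abbreviate c_scale_thread_vec(_up).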
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec_up.template AsType()[Number<0>{}], + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + + if constexpr(m0.value < (MRepeat - 2)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + }); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + // // __builtin_amdgcn_sched_barrier(0); + } + else + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + vector_type c_scale_thread_vec; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[m0]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[m0]; + vector_type c_scale_thread_vec_up; + c_scale_thread_vec_up.template AsType()(Number<0>{}) = + c_scale_thread_buf_up[m0]; + c_scale_thread_vec_up.template AsType()(Number<1>{}) = + c_scale_thread_buf_up[m0]; + + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + 
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec_up.template AsType()[Number<0>{}], + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + + if constexpr(m0.value < (MRepeat - 2)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + // Reduce the vgpr usage here. + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(I2, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_selector.hpp new file mode 100644 index 0000000000..e7c061bd97 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_selector.hpp @@ -0,0 +1,186 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
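+//
+// Summary of the dispatch below: BlockGemmPipelineVersion::v1 and ::v3 are supported,
+// each with an optional gate-up fused ("GUFusion") variant; v2 is stubbed out under
+// #if 0, and v3 additionally asserts MRepeat >= 4. Unsupported configurations fall
+// through to the std::cerr diagnostic. A caller would obtain a pipeline type roughly as
+//   constexpr auto pipeline = BlockGemmBlockMoeScaleBPreshufflePipeline_Selector<...>();
+// (the "..." stands for the full template argument list, which is not spelled out here).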
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp" +namespace ck { + +template +constexpr auto BlockGemmBlockMoeScaleBPreshufflePipeline_Selector() +{ + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if constexpr(GUFusion) + { + return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MScaleBlock, + NScaleBlock, + KScaleBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + else + { + return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MScaleBlock, + NScaleBlock, + KScaleBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + } +#if 0 + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v2< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } +#endif + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + static_assert(MRepeat >= 4, "MRepeat should at least be 4 in BlockGemmPipelineVersion::v3"); + if constexpr(GUFusion) + { + return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MScaleBlock, + NScaleBlock, + KScaleBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + else + { + return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MScaleBlock, + NScaleBlock, + KScaleBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + } + else + { + std::cerr << "BlockGemmPipeline configuration is not available" << std::endl; + } +} + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp new file mode 100644 index 0000000000..598b69cd61 --- /dev/null +++ 
b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp @@ -0,0 +1,854 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1< + BlockGemmPipelineScheduler::Intrawave, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MScaleBlock, + NScaleBlock, + KScaleBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack> : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KGroup; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::MWaves; + using Base::NWaves; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack / KGroup; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat * KGroup; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? 
TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves; + + // B global + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A local + static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + const CScaleThreadDesc& c_scale_thread_desc, + CThreadBuffer& c_thread_buf, + // AScaleThreadCopy + const AScaleGridDesc& a_scale_grid_desc, + const AScaleThreadDesc& a_scale_thread_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const AScaleThreadTransferStep& a_scale_thread_copy_step, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num_loop + index_t num_loop) const + { + ignore = b_block_desc; + ignore = b_block_buf; + // __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + auto c_scale_thread_buf = make_static_buffer( + c_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch A1 B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_buf); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + 
a_scale_thread_copy_step.At(Number<1>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + __builtin_amdgcn_sched_barrier(0); + + constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{}); + constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{}); + constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{}); + + static_for<0, num_scale_m_block, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + }); + }); + }); + + // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + + // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_buf); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + StaticBufferTupleOfVector + c_thread_buf_per_scale; + + // Local prefetch A1 + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + + // __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type 
c_scale_thread_vec; + constexpr index_t cscale_offset = + CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number< + b_thread_desc_.CalculateOffset(make_tuple( + n0, + I0, + kscale0 * KRepeat / num_scale_k_block + k0, + ik))>{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}( + [&](auto t) { + using pk_fma_type = + typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = + __builtin_elementwise_fma( + c_thread_buf_per_scale + .GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec + .template AsType()[Number<0>{}], + c_thread_buf + .GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + }); + }); + }); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_buf); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + static_for<0, MRepeat, 
1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + // __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template 
AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + }); + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp 
b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp new file mode 100644 index 0000000000..6db02d1dd7 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp @@ -0,0 +1,1070 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< + BlockGemmPipelineScheduler::Intrawave, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MScaleBlock, + NScaleBlock, + KScaleBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack> : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::KGroup; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MWaves; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 
0 : 1; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack / KGroup; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat * KGroup; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + static_assert(num_buffer_load_inst_a == num_ds_write_inst_a); + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 
8 : 4; + constexpr auto ds_read_a_mfma_rate = + math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle); + + // constexpr auto num_dsread_a_mfma = + // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + + constexpr auto num_total_stages = MRepeat; + + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages; + + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + + constexpr auto num_ds_read_a_prefetch_stages = 2; + + constexpr auto buffer_load_perstage_more = math::integer_divide_ceil( + (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_less = math::integer_divide_floor( + (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2)); + + constexpr auto buffer_load_stages_more = + (num_buffer_load_inst_a + num_buffer_load_inst_b) - + math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b), + (num_total_stages - 2)) * + ((num_total_stages - 2)); + + constexpr auto buffer_load_b_stages = + buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b + ? num_buffer_load_inst_b / buffer_load_perstage_more + : (buffer_load_stages_more + + (num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) / + buffer_load_perstage_less); + + constexpr auto buffer_load_a_stages = + num_total_stages - num_ds_read_a_prefetch_stages - buffer_load_b_stages; + + constexpr auto buffer_load_issue_point_b = 0; + constexpr auto buffer_load_issue_point_interval_more = + num_mfma_perstage / buffer_load_perstage_more + ? num_mfma_perstage / buffer_load_perstage_more + : 1; + constexpr auto buffer_load_issue_point_interval_less = + num_mfma_perstage / buffer_load_perstage_less + ? num_mfma_perstage / buffer_load_perstage_less + : 1; + constexpr auto ds_write_issue_point = 0; + constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 
1 : 0; + + // B global read + static_for<0, buffer_load_b_stages, 1>{}([&](auto i) { + // Scale load, 1B + if constexpr(i.value == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + // Scale load, 1A + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + + if constexpr(((i < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_b)) || + ((i >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_b))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + // __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma + }); + // __builtin_amdgcn_sched_barrier(0); + }); + + // A global read + A local write + static_for<0, buffer_load_a_stages, 1>{}([&](auto i) { + // Scale load, 1A + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + ds_write_issue_point)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + ds_write_issue_point))) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_a)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_a))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + // __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma + }); + // __builtin_amdgcn_sched_barrier(0); + }); + + // lds synchronization, prefetch next loop local A + static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + // Scale load, 1A + if constexpr(imfma == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + // __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma + }); + // __builtin_amdgcn_sched_barrier(0); + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + const 
CScaleThreadDesc& c_scale_thread_desc, + CThreadBuffer& c_thread_buf, + // AScaleThreadCopy + const AScaleGridDesc& a_scale_grid_desc, + const AScaleThreadDesc& a_scale_thread_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const AScaleThreadTransferStep& a_scale_thread_copy_step, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num_loop + index_t num_loop) const + { + ignore = b_block_desc; + ignore = b_block_buf; + __builtin_amdgcn_sched_barrier(0); + static_assert(CScaleThreadDesc{}.GetLength(Number<0>{}) == 1, + "Pipeline v3 only support scaleblocksliceK=1"); + static_assert(CScaleThreadDesc{}.GetLength(Number<2>{}) == 1, + "Pipeline v3 only support scaleblocksliceN=1"); + // assume kperblock = scaleblockk + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + auto c_scale_thread_buf = make_static_buffer( + c_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + // StaticallyIndexedArray{}> c_scale_thread_bufs; + + // Global prefetch A1 B1, AScale1 BScale1 + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_bufs(I0)); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0]; + }); + + // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); + + // Global prefetch A2, AScale2 BScale2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_bufs(I0)); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + } + + 
b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // Double register buffer for non-scaled gemm computation + // 1. Reduce register pressure + // 2. Decouple the dependency between mfma instruction and scale-fma instruction following. + StaticBufferTupleOfVector + c_thread_buf_per_scale; + + // Local prefetch A1 + block_sync_lds(); + static_for<0, 2, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + +#if 0 + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + + // Fill first mfma buffer + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_bufs + [I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); +#endif + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_bufs(local_read_buf)); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); + } + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs(local_read_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + vector_type c_scale_thread_vec; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[m0]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[m0]; + + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + 
a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + + if constexpr(m0.value == (MRepeat - 2)) + { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + else + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * + b_scale_thread_bufs[mfma_reg_buf][I0]; + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + vector_type c_scale_thread_vec; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[m0]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[m0]; + + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + 
b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + + if constexpr(m0.value == (MRepeat - 2)) + { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + else + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + }); + + HotLoopScheduler(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0]; + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + vector_type c_scale_thread_vec; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[m0]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[m0]; + + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + 
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + + if constexpr(m0.value < (MRepeat - 2)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + }); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + // // __builtin_amdgcn_sched_barrier(0); + } + else + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + vector_type c_scale_thread_vec; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[m0]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[m0]; + + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + + if constexpr(m0.value < (MRepeat - 2)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + // Reduce the vgpr usage here. 
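The "reduce the vgpr usage" note above refers to the descriptor that follows: only two MRepeat slices of the A tile are kept live in registers (hence the I2 in the first dimension), and the hot loop rotates them with the (m0 + 2 + HotloopLocalBufSwitch * buf) % 2 arithmetic seen earlier. A minimal standalone sketch of that rotation, assuming an example MRepeat of 4 (the real value is a tile-shape template parameter):

// Illustrative only: shows which A slice feeds the MFMAs vs. which is refilled from LDS.
#include <cstdio>

int main()
{
    constexpr int MRepeat               = 4; // assumed example value
    constexpr int HotloopLocalBufSwitch = (MRepeat % 2 == 0) ? 0 : 1;

    for(int buf = 0; buf < 2; ++buf)        // ping-pong LDS buffer within one hot-loop pass
    {
        for(int m0 = 0; m0 < MRepeat; ++m0) // slice currently consumed by the MFMAs
        {
            const int next_slice = (m0 + 2) % MRepeat;                         // slice read from LDS
            const int reg_slot   = (m0 + 2 + HotloopLocalBufSwitch * buf) % 2; // register slot it lands in
            std::printf("buf=%d m0=%d -> prefetch slice %d into slot %d\n",
                        buf, m0, next_slice, reg_slot);
        }
    }
    return 0;
}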
+ static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(I2, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp new file mode 100644 index 0000000000..f2508d9cfa --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp @@ -0,0 +1,1361 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3 + : BlockwiseGemmXdlops_mx_pipeline_base + +{ + + using Base = BlockwiseGemmXdlops_mx_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::MWaves; + using Base::NWaves; + using Base::WaveSize; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetWaveIdx; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_m3_k; + using Base::b_block_desc_n0_n1_n2_n3_k; + + using Base::AMmaKStride; + using Base::APackedSize; + using Base::BMmaKStride; + using Base::BPackedSize; + using Base::KThreadChunk; + + using Base::KXdlPack; + using Base::MXdlPack; + using Base::NXdlPack; + + using AccType = typename Base::AccType; + using Tuple5 = typename Base::Tuple5; + using ComputeTypeA = typename Base::ComputeTypeA; + using ComputeTypeB = typename Base::ComputeTypeB; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + static constexpr auto ScalesPerKBlockSize = + KPerBlock / ScaleBlockSize; // How many mx-vectors per K block + + //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRun = + (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + + //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRunPerThread = + ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + + using mx_scale_t = 
e8m0_bexp_t; + static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a; + static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2 * 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num * 2; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * 2; + + constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack; + constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack * 2; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * APackedSize * 2; + + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 
8 : 4; + + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale + num_buffer_load_b_scale; + + constexpr auto mfma_perstage_more = + math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total); + constexpr auto mfma_perstage_less = + math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total); + + constexpr auto mfma_stages_more = + num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total; + + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + if constexpr(i < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma < num_dswrite_per_issue_a) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma < num_dswrite_per_issue_a) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma < num_dswrite_per_issue_a) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma < num_dswrite_per_issue_b) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_a_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_b_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 
1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * + ds_read_a_mfma_rate, + 0); // DS read + } + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + }); + } + + template + __device__ void Run( + // A + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // B0/B1 + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + BBlockTransfer& b_blockwise_copy_up, + const BGridBuffer& b_grid_buf, + const BGridBuffer& b_grid_buf_up, + BBlockBuffer& b_block_buf, + BBlockBuffer& b_block_buf_up, + const BBlockTransferStep& b_block_copy_step, + // C + CThreadBuffer& c_thread_buf, + CThreadBuffer& c_thread_buf_up, + // A scale + const AScaleGridDesc& a_scale_grid_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + // B0/B1 scale + const BScaleGridDesc& b_scale_grid_desc, + BScaleThreadTransfer& b_scale_thread_copy, + BScaleThreadTransfer& b_scale_thread_copy_up, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleGridBuffer& b_scale_grid_buf_up, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf_up = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf_up = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs_up; + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + b_blockwise_copy_up.RunRead(b_grid_desc, b_grid_buf_up); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 
1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I0)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs_up(I0)); + + b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + b_blockwise_copy_up.RunWrite(b_block_desc, b_block_buf_up); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + b_blockwise_copy_up.RunRead(b_grid_desc, b_grid_buf_up); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Local prefetch 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / 
(BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf_up, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf_up); + }); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + c_thread_buf_up.Clear(); + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + // loop over k with the step KPerBlock + index_t i = 0; + do + { + auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + b_blockwise_copy_up.RunWrite(b_block_desc, b_block_buf_up); + b_blockwise_copy_up.RunRead(b_grid_desc, b_grid_buf_up); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(scale_mem_buf)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(scale_mem_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales_up + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs_up(scale_mem_buf)); + + b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, 
make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type + a_scale_thread_vec; + vector_type + b_scale_thread_vec; + vector_type + b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up( + scale_comp_buf)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()( + ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()( + ik) = b_thread_buf + [Number{}]; + b_thread_vec_up.template AsType()( + ik) = b_thread_buf_up + [Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference( + Number{})); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec_up + .template AsType(), + b_scale_thread_vec_up + .template AsType(), + c_thread_buf_up.GetVectorTypeReference( + Number{})); + }); + }); + }); + }); + }); + }); + + // k indexes mapping to threads for 32x32x64: + // t0 : |0 --> 15 32 --> 47 | 64 --> 79 96 --> 111 | etc. + // t32: |16 --> 31 48 --> 63 | 80 --> 95 112 --> 127 | etc. + // k = 0 k = 1 + + // k indexes mapping to threads for 16x16x128: + // t0 : |0 --> 15 64 --> 79 | 128 --> 143 192 --> 207| etc. + // t16: |16 --> 31 80 --> 95 | 144 --> 159 208 --> 223| etc. + // t32: |32 --> 47 96 --> 111| 160 --> 175 224 --> 239| etc. 
+ // t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc. + // k = 0 k = 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, + xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, + xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, + xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf_up, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf_up); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I1)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I1)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // Prefetch b_scales_up + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs_up(I1)); + + b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, 
-KRepeat / KXdlPack, 0)); + }); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + b_blockwise_copy_up.RunWrite(b_block_desc, b_block_buf_up); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I0)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_buf_up[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up + .template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, 
xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf_up, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf_up); + }); + }); + }); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I1)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_buf_up[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up + .template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Odd) 
+ { + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I0)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_buf_up[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up + .template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + } + } + + // TODO: make this field protected when a_scale_thread_copy_ is moved + // here + static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + // TODO: make this field protected when b_scale_thread_copy_ is moved + // here + static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_selector.hpp new file mode 100644 index 0000000000..84b0eebb31 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_selector.hpp @@ -0,0 +1,130 @@ +// 
SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp" + +namespace ck { +template +constexpr auto BlockGemmMXNBSPipeline_Selector() +{ + + // Hardware MX GEMM pipeline + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if constexpr(GUFusion) + { + return nullptr; + } + else + { + return BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v1{}; + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if constexpr(GUFusion) + { + return BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3< + BlkGemmPipeSche, + ThreadBlockSize, + ScaleBlockSize, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + else + { + return BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}; + } + } + else + { + std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl; + } +} + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp new file mode 100644 index 0000000000..32f6248543 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp @@ -0,0 +1,664 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
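// A minimal sketch (hypothetical names) of the compile-time dispatch pattern used by
// BlockGemmMXNBSPipeline_Selector above: the pipeline version and GUFusion flag are template
// parameters, so only the *type* of the returned object matters to callers.
//
//   template <BlockGemmPipelineVersion Ver, bool Fused, typename... Args>
//   constexpr auto select_pipeline()
//   {
//       if constexpr(Ver == BlockGemmPipelineVersion::v1)
//           return PipelineV1<Args...>{};            // no fused v1 variant (selector returns nullptr)
//       else if constexpr(Ver == BlockGemmPipelineVersion::v3)
//       {
//           if constexpr(Fused) return GuFusedPipelineV3<Args...>{};
//           else                return PipelineV3<Args...>{};
//       }
//   }
//
//   // Typical use: using Pipeline = remove_cvref_t<decltype(select_pipeline<Ver, Fused, ...>())>;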
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 1 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 0 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v1 + : BlockwiseGemmXdlops_mx_pipeline_base + +{ + + using Base = BlockwiseGemmXdlops_mx_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::MWaves; + using Base::NWaves; + using Base::WaveSize; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetWaveIdx; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_m3_k; + using Base::b_block_desc_n0_n1_n2_n3_k; + + using Base::AMmaKStride; + using Base::APackedSize; + using Base::BMmaKStride; + using Base::BPackedSize; + using Base::KThreadChunk; + + using Base::KXdlPack; + using Base::MXdlPack; + using Base::NXdlPack; + + using AccType = typename Base::AccType; + using Tuple5 = typename Base::Tuple5; + using ComputeTypeA = typename Base::ComputeTypeA; + using ComputeTypeB = typename Base::ComputeTypeB; + + static constexpr index_t PrefetchStages = 1; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + static constexpr auto ScalesPerKBlockSize = + KPerBlock / ScaleBlockSize; // How many mx-vectors per K block + + //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRun = + (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + + //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRunPerThread = + ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + + using mx_scale_t = e8m0_bexp_t; + static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a; + static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? 
TailNumber::Even : TailNumber::Odd; + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + // A and B scales + const AScaleGridDesc& a_scale_grid_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const BScaleGridDesc& b_scale_grid_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_buf); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + // loop over k with the step KPerBlock + index_t i = 0; + do + { + // ------------------------------------------------------------------------------------------- + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + 
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + }); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_buf[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_buf[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference( + Number{})); + }); + }); + }); + }); + }); + }); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_buf); + + 
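// Scale-window walk (descriptive note): each k0 step below advances the A-scale window one
// KXdlPack group in K (make_multi_index(0, I1, 0)); after a row of K groups the window rewinds
// K and steps M by MWaves, and the final move after the m0 loop restores the row index and
// advances to the next K block of scales.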
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + i += 1; + } while(i < (num_loop - 1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + }); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_buf[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_buf[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + 
static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + } + } + + // TODO: make this field protected when a_scale_thread_copy_ is moved + // here + static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + // TODO: make this field protected when b_scale_thread_copy_ is moved + // here + static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp new file mode 100644 index 0000000000..b48e464fee --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp @@ -0,0 +1,1126 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
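// Compared with the v1 pipeline above (one global prefetch stage, one LDS buffer), the v3
// pipeline below double-buffers the A/B scale registers and software-pipelines two loop
// iterations per trip. A sketch of the hot-loop structure defined further down, using the
// same names as the code:
//
//   // LoopFunc(scale_comp_buf, scale_mem_buf):
//   //   block_sync_lds(); write the prefetched A/B tile to LDS; issue the next global A/B reads;
//   //   prefetch the next A/B scales into *_scale_thread_bufs(scale_mem_buf);
//   //   run the scaled MFMAs with *_scale_thread_bufs(scale_comp_buf); read the next tile from LDS;
//   //   HotLoopScheduler();
//   LoopFunc(I0, I1);   // compute with scale buffer 0, fetch into buffer 1
//   LoopFunc(I1, I0);   // ping-pong
//
// Instruction interleaving and tail handling follow in the struct below.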
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3 + : BlockwiseGemmXdlops_mx_pipeline_base + +{ + + using Base = BlockwiseGemmXdlops_mx_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::MWaves; + using Base::NWaves; + using Base::WaveSize; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetWaveIdx; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_m3_k; + using Base::b_block_desc_n0_n1_n2_n3_k; + + using Base::AMmaKStride; + using Base::APackedSize; + using Base::BMmaKStride; + using Base::BPackedSize; + using Base::KThreadChunk; + + using Base::KXdlPack; + using Base::MXdlPack; + using Base::NXdlPack; + + using AccType = typename Base::AccType; + using Tuple5 = typename Base::Tuple5; + using ComputeTypeA = typename Base::ComputeTypeA; + using ComputeTypeB = typename Base::ComputeTypeB; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + static constexpr auto ScalesPerKBlockSize = + KPerBlock / ScaleBlockSize; // How many mx-vectors per K block + + //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRun = + (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + + //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRunPerThread = + ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + + using mx_scale_t = e8m0_bexp_t; + static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a; + static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? 
TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack; + constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * APackedSize; + + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; + + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale + num_buffer_load_b_scale; + + constexpr auto mfma_perstage_more = + math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total); + constexpr auto mfma_perstage_less = + math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total); + + constexpr auto mfma_stages_more = + num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total; + + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + if constexpr(i < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma < num_dswrite_per_issue_a) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma < num_dswrite_per_issue_a) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more) + { + 
static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma < num_dswrite_per_issue_a) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma < num_dswrite_per_issue_b) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_a_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_b_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * + ds_read_a_mfma_rate, + 0); // DS read + } + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + // A and B scales + const AScaleGridDesc& a_scale_grid_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const BScaleGridDesc& b_scale_grid_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& 
b_scale_grid_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I0)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Local prefetch 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk 
= + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + // loop over k with the step KPerBlock + index_t i = 0; + do + { + auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(scale_mem_buf)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(scale_mem_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type + a_scale_thread_vec; + vector_type + b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + 
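// Each (ikxdl, imxdl, inxdl) iteration below issues one scaled MFMA; the e8m0 scales packed
// above (KXdlPack x MXdlPack for A, KXdlPack x NXdlPack for B) cover the whole pack, and
// kxdl selects the K position within the current k0 group.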
vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()( + ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()( + ik) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference( + Number{})); + }); + }); + }); + }); + }); + }); + + // k indexes mapping to threads for 32x32x64: + // t0 : |0 --> 15 32 --> 47 | 64 --> 79 96 --> 111 | etc. + // t32: |16 --> 31 48 --> 63 | 80 --> 95 112 --> 127 | etc. + // k = 0 k = 1 + + // k indexes mapping to threads for 16x16x128: + // t0 : |0 --> 15 64 --> 79 | 128 --> 143 192 --> 207| etc. + // t16: |16 --> 31 80 --> 95 | 144 --> 159 208 --> 223| etc. + // t32: |32 --> 47 96 --> 111| 160 --> 175 224 --> 239| etc. + // t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc. + // k = 0 k = 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, + xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, + xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I1)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + 
b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I1)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + 
make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + }); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b 
= + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + } + } + + // TODO: make this field protected when a_scale_thread_copy_ is moved + // here + static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + // TODO: make this field protected when b_scale_thread_copy_ is moved + // here + static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp index a4038e9543..a7d22066ac 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -532,6 +532,9 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}([&](auto m0) { a_scale_thread_copy.Run(a_scale_grid_desc, a_scale_grid_buf, @@ -560,8 +563,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale +struct DeviceMoEGemmMXBPreShuffle : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + +#ifndef CK_CODE_GEN_RTC + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_a_scale, + const void* p_b, + const void* p_b_scale, + std::array p_ds, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideAScale, + ck::index_t StrideB, + ck::index_t StrideBScale, + std::array StrideDs, + ck::index_t StrideE, + ck::index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual int GetPreShuffleParameters() = 0; +#endif +}; + } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp index 7171715250..abf49bdab2 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
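// A minimal host-side usage sketch for the BaseOperator interfaces added in this patch
// (DeviceMoEGemmMXBPreShuffle above, DeviceGemmMultipleD_BlockScale_BPreshuffle below).
// Pointers, sizes, the element-wise ops, and the concrete instance type are placeholders:
//
//   auto op  = DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle</* config */>{};
//   auto arg = op.MakeArgumentPointer(p_a, p_b, /*p_ds=*/{}, p_e,
//                                     M, N, K, StrideA, StrideB, /*StrideDs=*/{}, StrideE,
//                                     p_a_scale, p_b_scale,
//                                     PassThrough{}, PassThrough{}, PassThrough{});
//   auto invoker = op.MakeInvokerPointer();
//   if(op.IsSupportedArgument(arg.get()))
//       float ms = invoker->Run(arg.get(), StreamConfig{});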
#pragma once @@ -60,6 +60,49 @@ struct DeviceGemmMultipleD_ABScale : public BaseOperator virtual std::unique_ptr MakeInvokerPointer() = 0; }; +template +struct DeviceGemmMultipleD_BlockScale_BPreshuffle : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const ck::index_t StrideA, + const ck::index_t StrideB, + const std::array StrideDs, + const ck::index_t StrideE, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual int GetPreShuffleParameters() = 0; +}; + } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp new file mode 100644 index 0000000000..c446ca59ea --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp @@ -0,0 +1,507 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle + : public DeviceGemmMultipleD_BlockScale_BPreshuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle< + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockM, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + 
CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + LDSTypeA, + LDSTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + int GetPreShuffleParameters() override { return NPerXDL; } + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + + ck::utility::RotatingMemWrapper rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + // unconditional 2 to remove agpr usage + constexpr index_t minimum_occupancy = 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = + 
kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle< + GridwiseGemm, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle< + GridwiseGemm, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + // if(ScaleBlockM % MPerBlock != 0 || ScaleBlockN % NPerBlock != 0 || ScaleBlockK != + // KPerBlock) + // { + // return false; + // } + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + // Padding to release this restriction + if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideC, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + static_cast(p_a_scale), + static_cast(p_b_scale), + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideC, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + static_cast(p_a_scale), + static_cast(p_b_scale), + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr 
MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceMoeGemmBlockScale + : public DeviceGemmMultipleD_BlockScale_BPreshuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + using GridwiseGemm = GridwiseMoeGemmBlockScale< + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockM, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ActivationOP, + NSwizzle, + IsInputGemm, + MulRoutedWeight, + IndexType, + ComputeTypeA, + ComputeTypeB, + LDSTypeA, + LDSTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + int GetPreShuffleParameters() override { return NPerXDL; } + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + 
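+            // Overall flow: validate the argument (throws if the gridwise GEMM rejects it),
+            // derive the launch grid from (M, N), compute the per-split K to decide whether the
+            // main K loop and which tail variant are needed, then dispatch the kernel by
+            // pipeline version (v1 single-LDS kernel_moe_gemm, v2/v3 kernel_moe_gemm_2lds) and
+            // by the odd/even tail count. With stream_config.flush_cache the kernel is timed
+            // behind an icache flush and rotating A/B/Ds buffers; C is cleared via
+            // hipMemsetAsync whenever KBatch > 1.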
if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + const auto RunKernel = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + + std::array DsSize; + + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * + sizeof(ADataType) / APackedSize; + auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * + sizeof(BDataType) / BPackedSize; + + const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N( + arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); + }); + ck::utility::RotatingMemWrapperMultiD rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize / + 4 * (1 + GridwiseGemm::NWave); + constexpr auto estimated_reg_b = NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize / + 4 * (2) * (IsInputGemm ? 2 : 1); + constexpr auto estimated_reg_c = MPerBlock * NPerBlock * sizeof(GemmAccDataType) / + BlockSize / 4 * (IsInputGemm ? 2 : 1); + constexpr auto estimated_reg_total = + estimated_reg_a + estimated_reg_b + estimated_reg_c; + + constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2; + + constexpr auto MemoryDataOp = + IsInputGemm ? 
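+            // IsInputGemm distinguishes the two MoE stages: the first (gate/up) GEMM can simply
+            // Set its output tile, while the second (down-projection) GEMM combines partial
+            // results per token across its routed experts, so it writes with AtomicAdd.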
InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + } + else + { + throw std::runtime_error("todo: only v1 & v2 support now"); + } + } +#if 1 + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + } + } +#endif + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // only impl kbatch 1 now + if(arg.KBatch > 1) + { + return false; + } + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_sorted_token_ids, + const void* p_sorted_expert_ids, + const void* p_max_token_id, + const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + index_t NumTokens, + index_t TopK, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideC, + const void* p_a_scale, + const void* p_b_scale, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{static_cast(p_sorted_token_ids), + static_cast(p_sorted_expert_ids), + static_cast(p_max_token_id), + static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + NumTokens, + TopK, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + 
static_cast(p_a_scale), + static_cast(p_b_scale), + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideC, + const void* p_a_scale, + const void* p_b_scale, + // index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(nullptr, + nullptr, + nullptr, + static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, // randoms set, no use + 0, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + static_cast(p_a_scale), + static_cast(p_b_scale), + 1, // KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}}; + + // clang-format off + str << "DeviceMoeGEmm" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + using GridwiseGemm = + GridwiseMoeGemmMX; + + using Argument = typename GridwiseGemm::Argument; + + static constexpr index_t APackedSize = packed_size_v; + static constexpr index_t BPackedSize = packed_size_v; + + int GetPreShuffleParameters() override { return NPerXDL; } + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto RunKernel = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + + std::array DsSize; + + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * + sizeof(ADataType) / APackedSize; + auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * + sizeof(BDataType) / BPackedSize; + + const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N( + arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); + }); + ck::utility::RotatingMemWrapperMultiD rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / + APackedSize / BlockSize / 4 * + (1 + GridwiseGemm::NWave); + constexpr auto estimated_reg_b = NPerBlock * KPerBlock * sizeof(BDataType) / + BPackedSize / BlockSize / 4 * (2) * + (IsInputGemm ? 2 : 1); + constexpr auto estimated_reg_c = MPerBlock * NPerBlock * sizeof(GemmAccDataType) / + BlockSize / 4 * (IsInputGemm ? 2 : 1); + constexpr auto estimated_reg_total = + estimated_reg_a + estimated_reg_b + estimated_reg_c; + + constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2; + + constexpr auto MemoryDataOp = + IsInputGemm ? 
InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd; + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_mxgemm; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_mxgemm; + RunKernel(kernel); + } + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_mxgemm_2lds; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_mxgemm_2lds; + RunKernel(kernel); + } + } + else + { + throw std::runtime_error("todo: only v1 & v3 support now"); + } + } + else + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_mxgemm; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_mxgemm; + RunKernel(kernel); + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_mxgemm_2lds; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_mxgemm_2lds; + RunKernel(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // only impl kbatch 1 now + if(arg.KBatch > 1) + { + return false; + } + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_sorted_token_ids, + const void* p_sorted_expert_ids, + const void* p_max_token_id, + const void* p_a, + const void* p_a_scale, + const void* p_b, + const void* p_b_scale, + std::array p_ds, + void* p_c, + index_t NumTokens, + index_t TopK, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideScaleA, + index_t StrideB, + index_t StrideScaleB, + std::array StrideDs, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{static_cast(p_sorted_token_ids), + static_cast(p_sorted_expert_ids), + static_cast(p_max_token_id), + static_cast(p_a), + static_cast(p_a_scale), + static_cast(p_b), + static_cast(p_b_scale), + p_ds, + static_cast(p_c), + NumTokens, + TopK, + M, + N, + K, + StrideA, + StrideScaleA, + StrideB, + 
StrideScaleB, + StrideDs, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_a_scale, + const void* p_b, + const void* p_b_scale, + std::array p_ds, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideScaleA, + index_t StrideB, + index_t StrideScaleB, + std::array StrideDs, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(nullptr, + nullptr, + nullptr, + static_cast(p_a), + static_cast(p_a_scale), + static_cast(p_b), + static_cast(p_b_scale), + p_ds, + static_cast(p_c), + M, // randoms set, no use + 0, + M, + N, + K, + StrideA, + StrideScaleA, + StrideB, + StrideScaleB, + StrideDs, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceMoeGEmmMx" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + using GridwiseGemm = + GridwiseMoeGemmMXBNS; + + using Argument = typename GridwiseGemm::Argument; + + static constexpr index_t APackedSize = packed_size_v; + static constexpr index_t BPackedSize = packed_size_v; + + int GetPreShuffleParameters() override { return NPerXDL; } + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto RunKernel = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + + std::array DsSize; + + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + + const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N( + arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); + }); + ck::utility::RotatingMemWrapperMultiD rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + // TODO: Check if this is the right algorithm for minimum_occupancy + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave + ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 && + MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2) + ? 2 + : 1 + : 2; + + constexpr auto MemoryDataOp = + IsInputGemm ? 
InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd; + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + const auto kernel = kernel_moe_mxgemm; + RunKernel(kernel); + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_mxgemm; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_mxgemm; + RunKernel(kernel); + } + } + else + { + throw std::runtime_error("todo: only v1 & v3 support now"); + } + } + else + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + const auto kernel = kernel_moe_mxgemm; + RunKernel(kernel); + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_mxgemm; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_mxgemm; + RunKernel(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // only impl kbatch 1 now + if(arg.KBatch > 1) + { + return false; + } + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_sorted_token_ids, + const void* p_sorted_expert_ids, + const void* p_max_token_id, + const void* p_a, + const void* p_a_scale, + const void* p_b, + const void* p_b_scale, + std::array p_ds, + void* p_c, + index_t NumTokens, + index_t TopK, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideScaleA, + index_t StrideB, + index_t StrideScaleB, + std::array StrideDs, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{static_cast(p_sorted_token_ids), + static_cast(p_sorted_expert_ids), + static_cast(p_max_token_id), + static_cast(p_a), + static_cast(p_a_scale), + static_cast(p_b), + static_cast(p_b_scale), + p_ds, + static_cast(p_c), + NumTokens, + TopK, + M, + N, + K, + StrideA, + StrideScaleA, + StrideB, + StrideScaleB, + StrideDs, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_a_scale, + const void* p_b, + const void* p_b_scale, + std::array p_ds, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t 
StrideScaleA, + index_t StrideB, + index_t StrideScaleB, + std::array StrideDs, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(nullptr, + nullptr, + nullptr, + static_cast(p_a), + static_cast(p_a_scale), + static_cast(p_b), + static_cast(p_b_scale), + p_ds, + static_cast(p_c), + M, // randoms set, no use + 0, + M, + N, + K, + StrideA, + StrideScaleA, + StrideB, + StrideScaleB, + StrideDs, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceMoeGEmmMx" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"<::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8) || - ((is_same::value || is_same::value) && - lcm_AK1_BK1 < 32)) + (is_same::value && lcm_AK1_BK1 <= 8)) ? true : false; static constexpr auto is_scale_mfma = false; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index cb22f99fc2..3eb0f986b3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -168,9 +168,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8) || - ((is_same::value || is_same::value) && - lcm_AK1_BK1 < 32)) + (is_same::value && lcm_AK1_BK1 <= 8)) ? 
true : false; static constexpr auto is_scale_mfma = false; @@ -1192,7 +1190,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); - // N0, K0, Blocksize*KPack const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); @@ -1200,7 +1197,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); // B matrix in LDS memory, dst of blockwise copy - // dummy constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); // A matrix blockwise copy @@ -1629,7 +1625,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); - // N0, K0, Blocksize*KPack const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); @@ -1637,7 +1632,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); // B matrix in LDS memory, dst of blockwise copy - // dummy constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); // A matrix blockwise copy diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp new file mode 100644 index 0000000000..322cd3d162 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp @@ -0,0 +1,2080 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" + +#define DEBUG_LOG 0 + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. 
Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle( + typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + karg.p_a_grid, + karg.p_b_grid, + karg.p_ds_grid, + karg.p_c_grid, + karg.p_a_scale_grid, + karg.p_b_scale_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle_2lds( + typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid, + karg.p_b_grid, + karg.p_ds_grid, + karg.p_c_grid, + karg.p_a_scale_grid, + karg.p_b_scale_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle +{ + using AScaleType = float; + using BScaleType = float; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = + CDEShuffleBlockTransferScalarPerVectors{}[I0]; + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + static constexpr auto BlockSizeNumber = Number{}; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + using mfma_selector = MfmaSelector; + static constexpr index_t KPack = + math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); + static constexpr index_t KGroup = []() { + if constexpr(is_same_v, f8_t>) + // On gfx950, we have a mfma that required 32 f8 elements as input, + // splited into 2 groups of 16 f8 elements. + // the 2 groups is not contiguous in the B preshuffed layout. + // and we do not want it to be contiguous in the B preshuffled layout + // because a memory instruction can only read 16 f8 elements at a time. + return mfma_selector::selected_mfma.k_per_blk == 32 ? 
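+        // i.e. KGroup == 2 only for the 32-element f8 MFMA on gfx950; all other data types and
+        // MFMA variants keep KGroup == 1 (KRepeat below divides KPack by KGroup accordingly).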
2 : 1; + else + return 1; + }(); + static constexpr index_t KLane = + mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); + static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup); + static constexpr index_t NLane = NPerXdl; + static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateBN0Shuffled(index_t N) + { + return math::integer_divide_ceil(N, NLane); + } + __host__ __device__ static auto CalculateBK0Shuffled(index_t K) + { + return math::integer_divide_ceil(K, KLane * KPack / KGroup); + } + + __host__ __device__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ __device__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == 
GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) + { + constexpr index_t NkSwizzleNumber = Number{}; + return make_naive_tensor_descriptor( + make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), + make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + 
b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + __host__ __device__ static auto MakeDsGridDescriptor_M_N( + index_t M, index_t MPad, index_t N, index_t NPad, std::array StrideDs) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + return MakeCGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); + }, + Number{}); + } + + template + __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + return generate_tuple( + [&](auto i) { + return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[i], MBlock, 
NBlock); + }, + Number{}); + } + + using DsGridDesc_M_N = remove_cvref_t; + + struct Problem + { + __host__ __device__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideDs{StrideDs_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)}, + BN0Shuffled{CalculateBN0Shuffled(N_)}, + BK0Shuffled{CalculateBK0Shuffled(K_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + // FOR PRESHUFFLE ONLY + index_t BN0Shuffled; + index_t BK0Shuffled; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + std::array p_ds_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + const AScaleType* p_a_scale_grid_, + const BScaleType* p_b_scale_grid_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_ds_grid{}, + p_c_grid{p_c_grid_}, + p_a_scale_grid{p_a_scale_grid_}, + p_b_scale_grid{p_b_scale_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_} + { + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType_ = remove_cvref_t>; + + // D pointer + p_ds_grid(i) = static_cast(p_ds_grid_[i]); + }); + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + DsGridPointer p_ds_grid; + CDataType* p_c_grid; + + const AScaleType* p_a_scale_grid; + const BScaleType* p_b_scale_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + }; + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(Argument& karg) + { + if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead * karg.M; + } + + if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead * karg.N; + } + else if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead; + } + + if(blockIdx.z < static_cast(karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + 
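+                // The last split-K block (blockIdx.z == KBatch - 1) takes whatever K remains
+                // after the first (KBatch - 1) blocks each consumed KRead.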
karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto a_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return a_lds_block_desc_permuted; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0 + ? 
M0 + : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max(a_block_space_size_aligned * sizeof(LDSTypeA), + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by 
{M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { +#if DEBUG_LOG + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + const AScaleType* p_a_scale_grid, + const BScaleType* p_b_scale_grid, + void* p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + ignore = b_element_op; + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + + const auto b_grid_desc_bpreshuffled = + MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + const auto c_grid_desc_m_n = 
MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.M, ScaleBlockM), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(1, math::integer_divide_ceil(problem.M, ScaleBlockM))); + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // Thread-wise copy + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + auto b_block_buf = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + 
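// Note (assumed MFMA shape, illustrative only): B is consumed straight from the
// pre-shuffled layout K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack, so each lane
// issues one contiguous read of KPack / KGroup elements per step. For a gfx950 f8 MFMA
// with KPack = 32 and KGroup = 2 that is a single 16-byte load, and lane L of the wave
// starts at element 16 * L, which matches the starting offset passed in below.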
decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + + // LDS allocation for A and B: be careful of alignment + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + constexpr index_t ScaleSliceSizeM = MXdlPerWave; + constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN); + constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK); + + // ScaleSliceSizeK is last dimension in A/B scale for vector memory access + // ScaleSliceSizeK is first dimension in C scale for packed math + constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + auto a_thread_offset = + get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl; + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, Number{})); + + auto a_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<1, 0>, + 0, + 1, + 1, + true>( + a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset, 0)); + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + true>( + b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0)); + + // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1); + constexpr auto a_scale_thread_slice_copy_step = + make_tuple(make_multi_index(MWaves * MPerXdl, 0), + make_multi_index(-MPerBlock, 0), + make_multi_index(-MPerBlock, ScaleSliceSizeK)); + constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK); + + constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock); + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + + c_scale_thread_desc, + c_thread_buf, + + a_scale_grid_desc_am_ak, + a_scale_thread_desc, + a_scale_thread_copy, + a_scale_grid_buf, + a_scale_thread_slice_copy_step, + + b_scale_grid_desc_bn_ak, + b_scale_thread_desc, + b_scale_thread_copy, + b_scale_grid_buf, + 
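// Hedged reading of the three A-scale copy steps defined above (the exact use is up to
// the pipeline): the tuple encodes (1) the hop between the M-repeats owned by one thread,
// MWaves * MPerXdl in the scale grid's M dimension, (2) a rewind of -MPerBlock back to
// the first repeat, and (3) the same rewind combined with an advance of ScaleSliceSizeK
// so the next pass reads the following K scale column. With MWaves = 2, MPerXdl = 32 and
// MPerBlock = 128 (assumed) that is (+64, 0), (-128, 0) and (-128, +ScaleSliceSizeK).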
b_scale_thread_slice_copy_step, + + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + // transposed XDL + // // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + // // TODO: hacky, fix it! + // only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); + constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2)), // M2 = MPerXdl + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2, // N2 * N3 * N4 = NPerXdl + N3, + N4))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, 
+ 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I3], + n_thread_data_on_block_idx[I4]), + tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + c_element_op}; + + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space 
filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + const AScaleType* p_a_scale_grid, + const BScaleType* p_b_scale_grid, + void* p_shared, + void* p_shared1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + ignore = b_element_op; + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bpreshuffled = + MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.M, ScaleBlockM), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(1, math::integer_divide_ceil(problem.M, ScaleBlockM))); + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + // divide 
block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // Thread-wise copy + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + auto b_block_buf_ping = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_buf_pong = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + + // LDS allocation for A and B: be careful of alignment + // Cast after lds + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = 
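// Note (illustrative): Run_2Lds keeps two LDS images of the A tile (p_shared and
// p_shared1) so the pipeline can prefetch tile i+1 into one buffer while the MFMAs
// consume tile i from the other. The loop count computed below is simply the number of
// KPerBlock tiles in this split, e.g. with K = 8192 and KPerBlock = 256 (assumed):
//   num_k_block_main_loop = AK0 * AK1 / KPerBlock = 8192 / 256 = 32 iterations.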
__builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + constexpr index_t ScaleSliceSizeM = MXdlPerWave; + constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN); + constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK); + + // ScaleSliceSizeK is last dimension in A/B scale for vector memory access + // ScaleSliceSizeK is first dimension in C scale for packed math + constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + auto a_thread_offset = + get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl; + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, Number{})); + + auto a_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<1, 0>, + 0, + 1, + 1, + true>( + a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset, 0)); + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + true>( + b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0)); + + // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1); + constexpr auto a_scale_thread_slice_copy_step = + make_tuple(make_multi_index(MWaves * MPerXdl, 0), + make_multi_index(-MPerBlock, 0), + make_multi_index(-MPerBlock, ScaleSliceSizeK)); + constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK); + + constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock); + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + + c_scale_thread_desc, + c_thread_buf, + + a_scale_grid_desc_am_ak, + a_scale_thread_desc, + a_scale_thread_copy, + a_scale_grid_buf, + a_scale_thread_slice_copy_step, + + b_scale_grid_desc_bn_ak, + b_scale_thread_desc, + b_scale_thread_copy, + b_scale_grid_buf, + b_scale_thread_slice_copy_step, + + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + // transposed XDL + // // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + // // TODO: hacky, fix it! 
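// Hedged reading of the 8-dimensional C descriptor queried below: for the transposed XDL
// output the M extent factors as M0 * M1 * M2 = MXdlPerWave * MWave * MPerXdl and the N
// extent as N0 * N1 * (N2 * N3 * N4) = NXdlPerWave * NWave * NPerXdl. The lengths are
// only read here to build the shuffle LDS descriptor and each thread's output
// coordinates.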
+ // only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); + constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2)), // M2 = MPerXdl + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2, // N2 * N3 * N4 = NPerXdl + N3, + N4))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I3], + n_thread_data_on_block_idx[I4]), + tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, 
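// Illustrative summary: the epilogue treats C and the NumDTensor auxiliary D tensors as
// a single fused source set. The shuffled C tile is read back from LDS while every D is
// read from global memory at the same (MBlock, NBlock) window, and CElementwiseOperation
// combines them on the way to the single store into E, so the D inputs never need a
// separate elementwise pass.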
problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + c_element_op}; + + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + 
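// Note on the two block_sync_lds() calls per iteration (hedged): the first barrier,
// issued above, keeps the LDS tile from being overwritten while the previous iteration's
// LDS-to-global reads may still be in flight (WAR); the second barrier, issued after the
// copy below, makes every thread's VGPR-to-LDS store visible before any thread reads the
// shuffled tile back out (RAW).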
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp index a0e716ba8e..223670e3bc 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp @@ -1221,7 +1221,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle } } } -#if 0 // check gridwise gemm pipeline const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); @@ -1232,7 +1231,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle return false; } } -#endif // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } @@ -2123,6 +2121,58 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle n_thread_data_on_block_idx[I3]), ck::tensor_operation::element_wise::PassThrough{}}; + // calculate C grid descriptor + constexpr auto DWORD_BYTES = 4; + constexpr auto atomic_vector_size = DWORD_BYTES / sizeof(CDataType); + + constexpr auto CShuffleBlockTransferClusterLengths = [&]() { + if constexpr(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set) + { + return CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{}; + } + // Atomic operation + else + { + return generate_sequence_v2( + [&](auto i) { + if constexpr(i == 3) + { + return Number< + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{} + .At(i) * + CShuffleBlockTransferScalarPerVector_NPerBlock / + atomic_vector_size>{}; + } + else if constexpr(i == 1) + { + return Number< + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{} + .At(i) / + CShuffleBlockTransferScalarPerVector_NPerBlock * + atomic_vector_size>{}; + } + else + { + return Number< + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{} + .At(i)>{}; + } + }, + Number<4>{}); + } + }(); + + constexpr auto CShuffleBlockTransferScalarPerVector = [&]() { + if constexpr(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set) + { + return CShuffleBlockTransferScalarPerVector_NPerBlock; + } + else + { + return atomic_vector_size; + } + }(); + // shuffle: blockwise copy C from LDS to global auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< ThisThreadBlock, // ThreadGroup @@ -2132,15 +2182,15 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, 1, CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - 
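// Worked example for the AtomicAdd path introduced above (assumed types): the code sizes
// the atomic vector as DWORD_BYTES / sizeof(CDataType), i.e. 2 elements for fp16/bf16
// and 1 for fp32, and re-balances the thread cluster so the thread count is unchanged.
// For example, cluster lengths {1, 32, 1, 8} with ScalarPerVector = 8 and a bf16 output
// become {1, 8, 1, 32}: more threads across N, each storing one 4-byte packet, and
// correspondingly fewer across M.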
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + decltype(CShuffleBlockTransferClusterLengths), Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, CShuffleDataType, // typename SrcData, CDataType, // typename DstData, decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector, // index_t ScalarPerVector, true, // bool ThreadTransferSrcResetCoordinateAfterRun, false> // bool ThreadTransferDstResetCoordinateAfterRun> {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index a083293485..62d94c0bf8 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -183,27 +183,28 @@ struct GridwiseMoeGemm static constexpr index_t NumDTensor = DsDataType::Size(); - static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); - static constexpr bool is_single_rate_mfma = - (((is_same::value || is_same::value) && - lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8) || - ((is_same::value || is_same::value) && - lcm_AK1_BK1 < 32)) - ? true - : false; - static constexpr auto is_scale_mfma = false; - static constexpr auto mfma = MfmaSelector{}; - static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk); - static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops(); - static constexpr index_t KRepeat = KPerBlock / KLane / KPack; - static constexpr index_t NLane = NPerXdl; - static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; + using mfma_selector = MfmaSelector; + static constexpr index_t KPack = + math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); + static constexpr index_t KLane = + mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); + + static constexpr index_t KGroup = []() { + if constexpr(is_same_v, f8_t>) + // On gfx950, we have a mfma that required 32 f8 elements as input, + // splited into 2 groups of 16 f8 elements. + // the 2 groups is not contiguous in the B preshuffed layout. + // and we do not want it to be contiguous in the B preshuffled layout + // because a memory instruction can only read 16 f8 elements at a time. + return mfma_selector::selected_mfma.k_per_blk == 32 ? 
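// Note (assumed gfx950 f8 MFMA shape, illustrative only): when k_per_blk is 32 the
// KGroup factor becomes 2, so one lane's per-step B read shrinks to KPack / KGroup = 16
// elements, a single 16-byte buffer load, and the repeat count grows to match, e.g.
// KPerBlock = 128 with KLane = 4 gives KRepeat = 128 / 4 / 16 = 2. For other data types,
// or single-rate MFMAs, KGroup stays 1 and the preshuffled layout is unchanged.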
2 : 1; + else + return 1; + }(); + + static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup); + + static constexpr index_t NLane = NPerXdl; + static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; // static constexpr index_t NumTokens = 1; static constexpr index_t SortedTileSize = MPerBlock; @@ -262,7 +263,7 @@ struct GridwiseMoeGemm } __host__ __device__ static auto CalculateBK0Shuffled(index_t K) { - return math::integer_divide_ceil(K, KLane * KPack); + return math::integer_divide_ceil(K, KLane * KPack / KGroup); } __host__ __device__ static auto CalculateKPadded(index_t K) @@ -404,7 +405,7 @@ struct GridwiseMoeGemm __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); @@ -1314,7 +1315,7 @@ struct GridwiseMoeGemm make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -1360,7 +1361,7 @@ struct GridwiseMoeGemm make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); blockwise_gemm_pipeline.template Run( a_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, @@ -1899,7 +1900,8 @@ struct GridwiseMoeGemm const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( c_grid_desc_m_n, problem.MBlock, problem.NBlock); - const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged"); const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y; if(expert_block_id * MPerBlock >= max_token_id) return; @@ -1908,12 +1910,13 @@ struct GridwiseMoeGemm const auto block_mn = [&]() -> std::pair { if constexpr(NSwizzle) { - const index_t ecnt_prefix = p_max_token_id[1 + expert_id]; - const index_t prefix_block = ecnt_prefix * problem.NBlock; - const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix; - const index_t expert_swizzle = ecnt > 0 ? ecnt : 1; - const index_t bid_new = blockIdx.x - prefix_block; - const index_t nid = __builtin_amdgcn_readfirstlane( + const index_t ecnt_prefix = p_max_token_id[1 + expert_id]; + const index_t prefix_block = ecnt_prefix * problem.NBlock; + const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix; + const index_t expert_swizzle = + ecnt > 0 ? 
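// Hedged trace of the swizzled block mapping computed below (assumed counts): blocks
// owned by one expert are renumbered so that 8 consecutive N tiles are paired with all
// of that expert's M tiles before the next 8 N tiles start. With ecnt = 3, bid_new =
// 0..23 makes nid cycle through 0..7 three times while mid steps through the expert's
// 3 rows, and bid_new = 24 moves on to nid = 8 while mid returns to the expert's first
// row, which keeps the same weight columns resident while the token rows stream through.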
ecnt : 1; // p_max_token_id[expert_id + 1]; // 2 + const index_t bid_new = blockIdx.x - prefix_block; + const index_t nid = __builtin_amdgcn_readfirstlane( bid_new % 8 + bid_new / (8 * expert_swizzle) * 8); const index_t mid = __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle); @@ -1924,9 +1927,9 @@ struct GridwiseMoeGemm return {blockIdx.x, blockIdx.y}; } }(); + const index_t block_n_id = block_mn.first; const index_t block_m_id = block_mn.second; - const index_t token0 = __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); @@ -1938,11 +1941,9 @@ struct GridwiseMoeGemm constexpr auto AMRepeats = MPerBlock / AMThreads; const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; - if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id || - token0 >= problem.NumTokens) + if(token_pos >= max_token_id || token0 >= problem.NumTokens) return; - StaticallyIndexedArray - gather_offsets; //= p_sorted_token_ids[token_pos]; + StaticallyIndexedArray gather_offsets; static_for<0, AMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[token_pos + m0]; index_t token_offset = fused_token & 0xffffff; @@ -1952,7 +1953,8 @@ struct GridwiseMoeGemm } gather_offsets(m0) = static_cast(token_offset) * problem.K; }); - const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K); + const index_t expert_stride = + __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); // N0, K0, Blocksize*KPack const index_t n_block_data_idx_on_grid = @@ -2025,7 +2027,7 @@ struct GridwiseMoeGemm make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -2042,24 +2044,76 @@ struct GridwiseMoeGemm static_assert(std::is_default_constructible_v); auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + decltype(c_thread_buf) c_thread_buf_up; + + StaticBufferTupleOfVector + c_thread_buf_fp32; const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); - blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_bufs, - a_block_slice_copy_step, - b_grid_desc_bpreshuffled, - b_blockwise_copy, - b_grid_buf, - b_block_bufs, - b_block_slice_copy_step, - c_thread_buf, - num_k_block_main_loop); + if constexpr(IsInputGemm) + { + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; + const auto b_grid_buf_up = make_dynamic_buffer( + p_b_grid_up + expert_id * expert_stride / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + 
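// Note on the gate/up path taken when IsInputGemm is true (hedged): the expert's weight
// slab stores the gate and up projections back to back along N, which is why
// expert_stride carries the extra factor of 2 and the up pointer starts at
// p_b_grid + expert_stride / 2 / BPackedSize. The dual-B pipeline invoked here then
// produces two accumulators per tile, c_thread_buf for the gate and c_thread_buf_up for
// the up projection, which the epilogue combines as Silu(gate) * up or Gelu(gate) * up.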
a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_blockwise_copy_up, + b_grid_buf, + b_grid_buf_up, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + c_thread_buf_up, + num_k_block_main_loop); + } + else + { + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + } // shuffle C and write out { @@ -2087,6 +2141,185 @@ struct GridwiseMoeGemm constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + // mul scales + const float* p_sorted_weights_0 = p_ds_grid[I0]; + const float* p_scale_b = p_ds_grid[I1]; + + static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); + static_assert(M4 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; + + if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr) + { + if constexpr(PerTokenQuant) + { + constexpr index_t scale_stride = (IsInputGemm ? 2 : 1); + p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock + + get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl; + } + else + { + p_scale_b += expert_id; + } + + vector_type scale_token_ids; + vector_type topk_weights; + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant]; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(PerTokenQuant) + { + scale_token_ids = + *c_style_pointer_cast*>( + p_sorted_token_ids + m_pos); + } + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + float scale_a = [&]() { + if constexpr(PerTokenQuant) + { + index_t fused_token = scale_token_ids.AsType()[m4]; + const index_t token_offset = fused_token & 0xffffff; + return token_offset < problem.NumTokens + ? 
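// Note (hedged): scale_a is gathered per token through the sorted token ids, and rows
// whose token offset falls outside problem.NumTokens, i.e. padding in the sorted list,
// take the 0.0 branch below so they contribute nothing. scale_b is one value per expert,
// or one value per output column when PerTokenQuant is set; in the per-token case the up
// projection's column scale sits problem.N entries after the gate's, which is the offset
// applied a few lines further down.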
p_sorted_weights_0[token_offset] + : 0.0; + } + else + { + return p_sorted_weights_0[0]; + } + }(); + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = + scale_a * scale_b * c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = c_thread_buf_fp32(cidx) * + topk_weights.AsType()[m4]; + } + } + }); + }); + }); + }); + } + else + { + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = topk_weights.AsType()[m4] * + c_thread_buf_fp32[cidx]; + } + } + }); + }); + }); + }); + } + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); @@ -2184,18 +2417,8 @@ struct GridwiseMoeGemm const auto ds_grid_buf = generate_tuple( 
[&](auto i) { - using DDataType = remove_cvref_t>; - const DDataType* ptr_ = p_ds_grid[i]; - // hack logic here to support different kind of strides. todo fix it. - // ascale t, 1; bscale E, N, 1, move ptr to E - // if(i.value == 1) - // { - // ptr_ += - // expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : - // 1); - // } return make_dynamic_buffer( - ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize()); + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); }, Number{}); @@ -2271,7 +2494,6 @@ struct GridwiseMoeGemm auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = SpaceFillingCurve, Sequence<0, 1, 2, 3, 4, 5, 6, 7>, @@ -2310,7 +2532,7 @@ struct GridwiseMoeGemm block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); static_for<0, EMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; - index_t token_offset = fused_token & 0xffffff; + IndexType token_offset = fused_token & 0xffffff; if constexpr(IsInputGemm) { token_offset = token_offset * problem.TopK + (fused_token >> 24); @@ -2323,7 +2545,7 @@ struct GridwiseMoeGemm // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, + c_thread_buf_fp32, c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_shuffle_block_buf); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp new file mode 100644 index 0000000000..fbfe2509ff --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp @@ -0,0 +1,2668 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp" + +#define DEBUG_LOG 0 + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. 
Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe + +enum Activation +{ + gelu_and_mul = 0, + silu_and_mul = 1 +}; + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_moe_gemm(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + karg.p_a_scale_grid, + karg.p_b_scale_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run_2Lds( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + karg.p_a_scale_grid, + karg.p_b_scale_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseMoeGemmBlockScale +{ + using AScaleType = float; + using BScaleType = float; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = + CDEShuffleBlockTransferScalarPerVectors{}[I0]; + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + static constexpr auto BlockSizeNumber = Number{}; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + using mfma_selector = MfmaSelector; + static constexpr index_t KPack = + math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); + static constexpr index_t KGroup = []() { + if constexpr(is_same_v, f8_t>) + // On gfx950, we have a mfma that required 32 f8 elements as input, + // splited into 2 groups of 16 f8 elements. + // the 2 groups is not contiguous in the B preshuffed layout. 
+ // and we do not want it to be contiguous in the B preshuffled layout + // because a memory instruction can only read 16 f8 elements at a time. + return mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1; + else + return 1; + }(); + static constexpr index_t KLane = + mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); + static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup); + static constexpr index_t NLane = NPerXdl; + static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; + // static constexpr index_t NumTokens = 1; + static constexpr index_t SortedTileSize = MPerBlock; + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + const index_t nblock = math::integer_divide_ceil(N, NPerBlock); + const index_t mblock = math::integer_divide_ceil(M, MPerBlock); + const index_t gridx = NSwizzle ? nblock * mblock : nblock; + const index_t gridy = NSwizzle ? 1 : mblock; + return std::make_tuple(gridx, gridy, 1); + } + + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateBN0Shuffled(index_t N) + { + return math::integer_divide_ceil(N, NLane); + } + __host__ __device__ static auto CalculateBK0Shuffled(index_t K) + { + return math::integer_divide_ceil(K, KLane * KPack / KGroup); + } + + __host__ __device__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ __device__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, 
Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) + { + constexpr index_t NkSwizzleNumber = Number{}; + return make_naive_tensor_descriptor( + make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), + make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return 
make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ static auto MakeCGridDescriptor_M_N( + IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } 
+ }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template + __host__ __device__ static auto + MakeDGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + __host__ __device__ static auto MakeDsGridDescriptor_M_N( + index_t M, index_t MPad, index_t N, index_t NPad, std::array StrideDs) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + return MakeDGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); + }, + Number{}); + } + + template + __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + return generate_tuple( + [&](auto i) { + return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[i], MBlock, NBlock); + }, + Number{}); + } + + using DsGridDesc_M_N = remove_cvref_t; + + struct Problem + { + __host__ __device__ Problem(index_t NumTokens_, + index_t TopK_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + index_t KBatch_) + : NumTokens{NumTokens_}, + TopK{TopK_}, + M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideDs{StrideDs_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)}, + BN0Shuffled{CalculateBN0Shuffled(N_)}, + BK0Shuffled{CalculateBK0Shuffled(K_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "NumTokens:" << NumTokens << ", " + << "TopK:" << TopK << ", " + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t NumTokens; + index_t TopK; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + // FOR PRESHUFFLE ONLY + index_t BN0Shuffled; + index_t BK0Shuffled; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const index_t* 
p_sorted_token_ids_, + const index_t* p_sorted_expert_ids_, + const index_t* p_max_token_id_, + const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + std::array p_ds_grid_, + CDataType* p_c_grid_, + index_t NumTokens_, + index_t TopK_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + const AScaleType* p_a_scale_grid_, + const BScaleType* p_b_scale_grid_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : Problem{NumTokens_, + TopK_, + M_, + N_, + K_, + StrideA_, + StrideB_, + StrideDs_, + StrideC_, + k_batch_}, + p_sorted_token_ids{p_sorted_token_ids_}, + p_sorted_expert_ids{p_sorted_expert_ids_}, + p_max_token_id{p_max_token_id_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_ds_grid{}, + p_c_grid{p_c_grid_}, + p_a_scale_grid{p_a_scale_grid_}, + p_b_scale_grid{p_b_scale_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_} + { + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType_ = remove_cvref_t>; + + // D pointer + p_ds_grid(i) = static_cast(p_ds_grid_[i]); + }); + } + + const index_t* p_sorted_token_ids; + const index_t* p_sorted_expert_ids; + const index_t* p_max_token_id; + const ADataType* p_a_grid; + const BDataType* p_b_grid; + DsGridPointer p_ds_grid; + CDataType* p_c_grid; + + const AScaleType* p_a_scale_grid; + const BScaleType* p_b_scale_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + }; + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(Argument& karg, index_t k_id) + { + if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead / APackedSize; + } + else if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = k_id * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + // KPack * NLane * KLane * K0 * N0 + b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize; + } + + if(k_id < karg.KBatch - 1) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto a_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return a_lds_block_desc_permuted; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. 
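+            // kfold and mpair below are capped so that a folded contiguous span
+            // (AK1 * M0, or AK1 * MPerXdl, elements of LDSTypeA) stays within 128
+            // bytes; the xor_with_modulo transform then swizzles the K/M coordinates
+            // of the column-major A tile, presumably to reduce LDS bank conflicts
+            // for the read pattern of the MFMA lanes.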
+ constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0 + ? M0 + : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful 
of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max(a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize, + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { +#if DEBUG_LOG + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! 
" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + // check gridwise gemm pipeline +#if 0 + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } +#endif + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, + // NPerBlock>; + + template + __device__ static void Run(const index_t* p_sorted_token_ids, + const index_t* p_sorted_expert_ids, + const index_t* p_max_token_id, + const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + const AScaleType* p_a_scale_grid, + const BScaleType* p_b_scale_grid, + void* p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + ignore = b_element_op; + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, + problem.MPadded, + problem.K, + problem.KPadded, + problem.StrideA, + problem.AK0); + const auto b_grid_desc_bpreshuffled = + MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens, + problem.MPadded, + problem.N, + problem.NPadded, + problem.StrideC); + + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(IsInputGemm ? 
problem.NumTokens + : problem.NumTokens * problem.TopK, + ScaleBlockM), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged"); + const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y; + if(expert_block_id * MPerBlock >= max_token_id) + return; + const index_t expert_id = + __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]); + const auto block_mn = [&]() -> std::pair { + if constexpr(NSwizzle) + { + const index_t ecnt_prefix = p_max_token_id[1 + expert_id]; + const index_t prefix_block = ecnt_prefix * problem.NBlock; + const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix; + const index_t expert_swizzle = + ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2 + const index_t bid_new = blockIdx.x - prefix_block; + const index_t nid = __builtin_amdgcn_readfirstlane( + bid_new % 8 + bid_new / (8 * expert_swizzle) * 8); + const index_t mid = + __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle); + return {nid, mid}; + } + else + { + return {blockIdx.x, blockIdx.y}; + } + }(); + const index_t block_n_id = block_mn.first; + const index_t block_m_id = block_mn.second; + const index_t token0 = + __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); + + // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2); + constexpr auto AKThreads = AK0Threads * AK1Threads; + constexpr auto AMRepeats = MPerBlock / AMThreads; + const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; + + if(token_pos >= max_token_id || token0 >= problem.NumTokens) + return; + StaticallyIndexedArray gather_offsets; + static_for<0, AMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[token_pos + m0]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(!IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + gather_offsets(m0) = static_cast(token_offset) * problem.K; + }); + const index_t expert_stride = + __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); + const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane( + math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 
2 : 1) * + math::integer_divide_ceil(problem.K, ScaleBlockK)); + + // N0, K0, Blocksize*KPack + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid + expert_id * static_cast(expert_stride) / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid + expert_id * expert_scale_stride, + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + // dummy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + // A matrix blockwise copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + IndexType, + 1, + BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}, + gather_offsets); + + // Thread-wise copy + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + auto b_block_buf = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + + // LDS allocation for A and B: be careful of alignment + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + decltype(c_thread_buf) c_thread_buf_up; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + constexpr index_t ScaleSliceSizeM = MXdlPerWave; + constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN); + constexpr index_t 
ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK); + + // ScaleSliceSizeK is last dimension in A/B scale for vector memory access + // ScaleSliceSizeK is first dimension in C scale for packed math + constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + auto a_thread_offset = + get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl; + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, Number{})); + + // get each thread's offset in the scale tensor + // A scale + const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM; + + if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens) + return; + StaticallyIndexedArray scale_gather_offsets; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { + const index_t fused_token = + p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(!IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + scale_gather_offsets(m0) = + token_offset * math::integer_divide_ceil(problem.K, ScaleBlockK); + }); + + auto a_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2_gather, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false, + MXdlPerWave>( + a_scale_grid_desc_am_ak, make_multi_index(0, 0), scale_gather_offsets); + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false>( + b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0)); + + // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1); + constexpr auto a_scale_thread_slice_copy_step = + make_tuple(make_multi_index(0, 0), make_multi_index(0, ScaleSliceSizeK)); + constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK); + + constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock); + if constexpr(IsInputGemm) + { + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; + const auto b_grid_buf_up = make_dynamic_buffer( + p_b_grid_up + expert_id * static_cast(expert_stride) / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + const BScaleType* p_b_scale_grid_up = + p_b_scale_grid + expert_scale_stride / 2 / BPackedSize; + const auto b_scale_grid_buf_up = make_dynamic_buffer( + p_b_scale_grid_up + expert_id * expert_scale_stride, + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + auto b_scale_thread_copy_up = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0)); + + 
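+            // GU fusion (first GEMM): each expert's B tensor stores the gate and up
+            // projection weights back to back, which is why expert_stride and
+            // expert_scale_stride above carry a factor of 2. p_b_grid_up and
+            // p_b_scale_grid_up point at the "up" half, and the pipeline is run with
+            // a second B copy and B-scale copy so that a single pass accumulates the
+            // gate result into c_thread_buf and the up result into c_thread_buf_up.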
blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + + b_grid_desc_bpreshuffled, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_blockwise_copy_up, + b_grid_buf, + b_grid_buf_up, + b_block_buf, + b_block_slice_copy_step, + + c_scale_thread_desc, + c_thread_buf, + c_thread_buf_up, + + a_scale_grid_desc_am_ak, + a_scale_thread_desc, + a_scale_thread_copy, + a_scale_grid_buf, + a_scale_thread_slice_copy_step, + + b_scale_grid_desc_bn_ak, + b_scale_thread_desc, + b_scale_thread_copy, + b_scale_thread_copy_up, + b_scale_grid_buf, + b_scale_grid_buf_up, + b_scale_thread_slice_copy_step, + + num_k_block_main_loop); + } + else + { + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + + b_grid_desc_bpreshuffled, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + + c_scale_thread_desc, + c_thread_buf, + + a_scale_grid_desc_am_ak, + a_scale_thread_desc, + a_scale_thread_copy, + a_scale_grid_buf, + a_scale_thread_slice_copy_step, + + b_scale_grid_desc_bn_ak, + b_scale_thread_desc, + b_scale_thread_copy, + b_scale_grid_buf, + b_scale_thread_slice_copy_step, + + num_k_block_main_loop); + } + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + // transposed XDL + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + // TODO: hacky, fix it! 
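+            // Blockscale epilogue: the A/B block scales are consumed inside the
+            // pipeline Run above, so the loop below only fuses the activation for the
+            // gate-up (first) GEMM (SiLU or GELU on the gate half, multiplied by the
+            // up half) and, when MulRoutedWeight is set, multiplies by the routed
+            // top-k weight read from p_ds_grid[I0], before the usual C shuffle through
+            // LDS and the scattered store back to token positions.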
+ // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); + constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); + + static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock); + static_assert(M0 * M1 * M2 == MPerBlock); + static_assert(N4 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m2 = threadIdx.x % get_warp_size() % M2; + + float topk_weight; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + if constexpr(MulRoutedWeight) + { + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2; + topk_weight = p_ds_grid[I0][m_pos]; + } + static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk + static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, n2 * N4 + n4)); + constexpr auto cidx = Number{}; + if constexpr(IsInputGemm) // gu fusion, elementwise + { + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weight; + up = up * topk_weight; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weight; + up = up * topk_weight; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf(cidx) = gate * up; + } + } + else + { + if constexpr(MulRoutedWeight) + { + c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight; + } + } + }); + }); + }); + }); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2)), // M2 = MPerXdl + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2, // N2 * N3 * N4 = NPerXdl + N3, + N4))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4>{}, 
Sequence<>{}, Sequence<1, 3, 5, 6, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I3], + n_thread_data_on_block_idx[I4]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + const DDataType* ptr_ = p_ds_grid[i]; + // hack logic here to support different kind of strides. todo fix it. 
+ // ascale t, 1; bscale E, N, 1, move ptr to E + return make_dynamic_buffer( + ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = + container_concat(make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_m_id, 0, block_n_id, 0); + // return make_multi_index(block_work_idx[I0], 0, + // block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferCluster = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 1; // hack fix felix + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + constexpr auto EMThreads = + 
CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); + constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; + constexpr auto ENThreads = + CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + StaticallyIndexedArray scatter_offsets; + + auto dstidx = sfc_cde_block.GetIndex(access_id); + const index_t c_token_pos = + block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); + static_for<0, EMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + scatter_offsets(m0) = token_offset * problem.N; + }); + + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf), + scatter_offsets); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const index_t* p_sorted_token_ids, + const index_t* p_sorted_expert_ids, + const index_t* p_max_token_id, + const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + const AScaleType* p_a_scale_grid, + const BScaleType* p_b_scale_grid, + void* p_shared, + void* p_shared1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + ignore = b_element_op; + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, + problem.MPadded, + problem.K, + problem.KPadded, + problem.StrideA, + problem.AK0); + const auto b_grid_desc_bpreshuffled = + MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens, + problem.MPadded, + problem.N, + problem.NPadded, + problem.StrideC); + + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(IsInputGemm ? 
problem.NumTokens + : problem.NumTokens * problem.TopK, + ScaleBlockM), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y; + if(expert_block_id * MPerBlock >= max_token_id) + return; + const index_t expert_id = + __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]); + const auto block_mn = [&]() -> std::pair { + if constexpr(NSwizzle) + { + const index_t ecnt_prefix = p_max_token_id[1 + expert_id]; + const index_t prefix_block = ecnt_prefix * problem.NBlock; + const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix; + const index_t expert_swizzle = ecnt > 0 ? ecnt : 1; + const index_t bid_new = blockIdx.x - prefix_block; + const index_t nid = __builtin_amdgcn_readfirstlane( + bid_new % 8 + bid_new / (8 * expert_swizzle) * 8); + const index_t mid = + __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle); + return {nid, mid}; + } + else + { + return {blockIdx.x, blockIdx.y}; + } + }(); + const index_t block_n_id = block_mn.first; + const index_t block_m_id = block_mn.second; + + const index_t token0 = + __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); + + // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2); + constexpr auto AKThreads = AK0Threads * AK1Threads; + constexpr auto AMRepeats = MPerBlock / AMThreads; + const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; + + if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id || + token0 >= problem.NumTokens) + return; + StaticallyIndexedArray + gather_offsets; //= p_sorted_token_ids[token_pos]; + static_for<0, AMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[token_pos + m0]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(!IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + gather_offsets(m0) = static_cast(token_offset) * problem.K; + }); + const index_t expert_stride = + __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); + const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane( + math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 
2 : 1) * + math::integer_divide_ceil(problem.K, ScaleBlockK)); + // N0, K0, Blocksize*KPack + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid + expert_id * static_cast(expert_stride) / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid + expert_id * expert_scale_stride, + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + // dummy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + // A matrix blockwise copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + IndexType, + 1, + BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}, + gather_offsets); + + // Thread-wise copy + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + auto b_block_buf_ping = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_buf_pong = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + + // LDS allocation for A and B: be careful of alignment + // Cast after lds + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + 
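// c_thread_buf_up below is a second accumulator that is only exercised on the fused gate/up
+ // (IsInputGemm) path: the C-shuffle epilogue applies SiLU/GeLU to the gate values held in
+ // c_thread_buf and multiplies them by the corresponding up-projection values before write-out. +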
decltype(c_thread_buf) c_thread_buf_up; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // scale + constexpr index_t ScaleSliceSizeM = MXdlPerWave; + constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN); + constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK); + + // ScaleSliceSizeK is last dimension in A/B scale for vector memory access + // ScaleSliceSizeK is first dimension in C scale for packed math + constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + auto a_thread_offset = + get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl; + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, Number{})); + + // get each thread's offset in the scale tensor + // A scale + const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM; + + if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens) + return; + StaticallyIndexedArray scale_gather_offsets; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { + const index_t fused_token = + p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(!IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + scale_gather_offsets(m0) = static_cast(token_offset) * + math::integer_divide_ceil(problem.K, ScaleBlockK); + }); + + auto a_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2_gather, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false, + MXdlPerWave>( + a_scale_grid_desc_am_ak, make_multi_index(0, 0), scale_gather_offsets); + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false>( + b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0)); + + // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1); + constexpr auto a_scale_thread_slice_copy_step = + make_tuple(make_multi_index(0, 0), make_multi_index(0, ScaleSliceSizeK)); + constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK); + + constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock); + if constexpr(IsInputGemm) + { + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; + const auto b_grid_buf_up = make_dynamic_buffer( + p_b_grid_up + expert_id * static_cast(expert_stride) / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + const BScaleType* p_b_scale_grid_up = + p_b_scale_grid + expert_scale_stride / 2 
/ BPackedSize; + const auto b_scale_grid_buf_up = make_dynamic_buffer( + p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize, + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + auto b_scale_thread_copy_up = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0)); + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_blockwise_copy_up, + b_grid_buf, + b_grid_buf_up, + b_block_bufs, + b_block_slice_copy_step, + c_scale_thread_desc, + c_thread_buf, + c_thread_buf_up, + a_scale_grid_desc_am_ak, + a_scale_thread_desc, + a_scale_thread_copy, + a_scale_grid_buf, + a_scale_thread_slice_copy_step, + b_scale_grid_desc_bn_ak, + b_scale_thread_desc, + b_scale_thread_copy, + b_scale_thread_copy_up, + b_scale_grid_buf, + b_scale_grid_buf_up, + b_scale_thread_slice_copy_step, + num_k_block_main_loop); + } + else + { + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_scale_thread_desc, + c_thread_buf, + a_scale_grid_desc_am_ak, + a_scale_thread_desc, + a_scale_thread_copy, + a_scale_grid_buf, + a_scale_thread_slice_copy_step, + b_scale_grid_desc_bn_ak, + b_scale_thread_desc, + b_scale_thread_copy, + b_scale_grid_buf, + b_scale_thread_slice_copy_step, + num_k_block_main_loop); + } + + // shuffle C and write out + { + + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + // transposed XDL + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + // TODO: hacky, fix it! 
+ // only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); + constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); + + static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock); + static_assert(M0 * M1 * M2 == MPerBlock); + static_assert(N4 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m2 = threadIdx.x % get_warp_size() % M2; + + float topk_weight; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + if constexpr(MulRoutedWeight) + { + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2; + topk_weight = p_ds_grid[I0][m_pos]; + } + static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk + static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, n2 * N4 + n4)); + constexpr auto cidx = Number{}; + if constexpr(IsInputGemm) // gu fusion, elementwise + { + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weight; + up = up * topk_weight; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weight; + up = up * topk_weight; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf(cidx) = gate * up; + } + } + else + { + if constexpr(MulRoutedWeight) + { + c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight; + } + } + + }); + }); + }); + }); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2)), // M2 = MPerXdl + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2, // N2 * N3 * N4 = NPerXdl + N3, + N4))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{})); + + // 
calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I3], + n_thread_data_on_block_idx[I4]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = + container_concat(make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_m_id, 0, block_n_id, 0); + // return make_multi_index(block_work_idx[I0], 0, + // block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferCluster = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + constexpr index_t scatter_weight_idx = IsInputGemm ? 
1 : 1; // hack fix felix + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + constexpr auto EMThreads = + CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); + constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; + constexpr auto ENThreads = + CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + StaticallyIndexedArray + scatter_offsets; //= p_sorted_token_ids[c_token_pos]; + + auto dstidx = sfc_cde_block.GetIndex(access_id); + const index_t c_token_pos = + block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); + static_for<0, EMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + scatter_offsets(m0) = token_offset * problem.N; + }); + + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf), + 
scatter_offsets);
+
+                if constexpr(access_id < num_access - 1)
+                {
+                    constexpr auto cde_lds_and_global_step =
+                        sfc_cde_block.GetForwardStep(access_id);
+
+                    // move on Ds
+                    static_for<0, NumDTensor, 1>{}([&](auto i) {
+                        cde_block_copy_lds_and_global.MoveSrcSliceWindow(
+                            c_ds_desc_refs, i + I1, cde_lds_and_global_step);
+                    });
+
+                    // move on E
+                    cde_block_copy_lds_and_global.MoveDstSliceWindow(
+                        tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                        I0,
+                        cde_lds_and_global_step);
+                }
+            });
+        }
+    }
+};
+
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp
new file mode 100644
index 0000000000..fc156a878f
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp
@@ -0,0 +1,2652 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/multi_index_transform_helper.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
+#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp"
+#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp"
+#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
+#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
+
+#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp"
+
+#define DEBUG_LOG 0
+
+namespace ck {
+
+// Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer pipelines
+// in the same kernel function. Blockers:
+// 1. Two separate declarations of the __shared__ pointer are required to make sure data accesses operate on
+// two distinct LDS chunks.
+// 2.
Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe + +enum Activation +{ + gelu_and_mul = 0, + silu_and_mul = 1 +}; + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_moe_mxgemm(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + // auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run_2Lds( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid, + karg.p_a_scale_grid, + karg.p_b_grid, + karg.p_b_scale_grid, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseMoeGemmMX +{ + using LDSTypeA = ADataType; + using LDSTypeB = BDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = + CDEShuffleBlockTransferScalarPerVectors{}[I0]; + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + static constexpr auto BlockSizeNumber = Number{}; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto MXdlPack = 2; + static constexpr auto NXdlPack = 2; + static constexpr auto KXdlPack = 2; + + static constexpr index_t APackedSize = packed_size_v; + static constexpr index_t BPackedSize = packed_size_v; + + static constexpr bool is_single_rate_mfma = false; + static constexpr auto is_scale_mfma = true; + using mfma_selector = MfmaSelector; + static constexpr index_t KPack = math::max( + math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk / APackedSize); + static 
constexpr index_t KLane = + mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); + + static constexpr index_t KGroup = 1; // mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1; + // static_assert(KGroup == 2, ""); + static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup); + static constexpr index_t NLane = NPerXdl; + static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; + static constexpr index_t MWave = MPerBlock / MPerXdl / MXdlPerWave; + + // static constexpr index_t NumTokens = 1; + static constexpr index_t SortedTileSize = MPerBlock; + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + const index_t nblock = math::integer_divide_ceil(N, NPerBlock); + const index_t mblock = math::integer_divide_ceil(M, MPerBlock); + const index_t gridx = NSwizzle ? nblock * mblock : nblock; + const index_t gridy = NSwizzle ? 1 : mblock; + + return std::make_tuple(gridx, gridy, 1); + } + + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateBN0Shuffled(index_t N) + { + return math::integer_divide_ceil(N, NLane); + } + __host__ __device__ static auto CalculateBK0Shuffled(index_t K) + { + return math::integer_divide_ceil(K, KLane * KPack / KGroup); + } + + __host__ __device__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ __device__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}, + Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + IndexType M, 
IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) + { + constexpr index_t NkSwizzleNumber = Number{}; + return make_naive_tensor_descriptor( + make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber), + make_tuple(NWave * NXdlPack * K0 * NkSwizzleNumber, + NXdlPack * K0 * NkSwizzleNumber, + K0 * NkSwizzleNumber, + NkSwizzleNumber, + I1)); + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + 
using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + static_assert(!(is_same_v, f4x2_pk_t> && + GemmSpec != GemmSpecialization::Default), + "f4x2_pk_t does not support padding"); + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&) + { + return MakeGemmMmaTileDescriptor( + ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&) + { + return MakeGemmMmaTileDescriptor( + BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ static auto MakeCGridDescriptor_M_N( + IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + 
make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template + __host__ __device__ static auto + MakeDGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + __host__ __device__ static auto MakeDsGridDescriptor_M_N( + index_t M, index_t MPad, index_t N, index_t NPad, std::array StrideDs) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + return MakeDGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); + }, + Number{}); + } + + template + __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + return generate_tuple( + [&](auto i) { + return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[i], MBlock, NBlock); + }, + Number{}); + } + + struct Problem + { + __host__ __device__ Problem(index_t NumTokens_, + index_t TopK_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideScaleA_, + index_t StrideB_, + index_t StrideScaleB_, + std::array StrideDs_, + index_t StrideC_, + index_t KBatch_) + : NumTokens{NumTokens_}, + TopK{TopK_}, + M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideScaleA{StrideScaleA_}, + StrideB{StrideB_}, + StrideScaleB{StrideScaleB_}, + StrideDs{StrideDs_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)}, + BN0Shuffled{CalculateBN0Shuffled(N_)}, + BK0Shuffled{CalculateBK0Shuffled(K_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "NumTokens:" << NumTokens << ", " + << "TopK:" << TopK << ", " + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SSCaleA:" << StrideScaleA << ", " + << "SB:" << StrideB << ", " + << "SScaleB:" << StrideScaleB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t NumTokens; + index_t TopK; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideScaleA; + index_t StrideB; + index_t StrideScaleB; + std::array StrideDs; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + // FOR PRESHUFFLE ONLY + index_t BN0Shuffled; + index_t BK0Shuffled; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem 
+ { + __host__ Argument(const index_t* p_sorted_token_ids_, + const index_t* p_sorted_expert_ids_, + const index_t* p_max_token_id_, + const ADataType* p_a_grid_, + const AScaleDataType* p_a_scale_grid_, + const BDataType* p_b_grid_, + const BScaleDataType* p_b_scale_grid_, + std::array p_ds_grid_, + CDataType* p_c_grid_, + index_t NumTokens_, + index_t TopK_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideScaleA_, + index_t StrideB_, + index_t StrideScaleB_, + std::array StrideDs_, + index_t StrideC_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : Problem{NumTokens_, + TopK_, + M_, + N_, + K_ / APackedSize, + StrideA_ / APackedSize, + StrideScaleA_, + StrideB_ / APackedSize, + StrideScaleB_, + StrideDs_, + StrideC_, + k_batch_}, + p_sorted_token_ids{p_sorted_token_ids_}, + p_sorted_expert_ids{p_sorted_expert_ids_}, + p_max_token_id{p_max_token_id_}, + p_a_grid{p_a_grid_}, + p_a_scale_grid{p_a_scale_grid_}, + p_b_grid{p_b_grid_}, + p_b_scale_grid{p_b_scale_grid_}, + p_ds_grid{}, + p_c_grid{p_c_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_} + { + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType_ = remove_cvref_t>; + + // D pointer + p_ds_grid(i) = static_cast(p_ds_grid_[i]); + }); + } + + const index_t* p_sorted_token_ids; + const index_t* p_sorted_expert_ids; + const index_t* p_max_token_id; + const ADataType* p_a_grid; + const AScaleDataType* p_a_scale_grid; + const BDataType* p_b_grid; + const BScaleDataType* p_b_scale_grid; + DsGridPointer p_ds_grid; + CDataType* p_c_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + }; + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(Argument& karg, index_t k_id) + { + if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = k_id * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + // KPack * NLane * KLane * K0 * N0 + b_k_split_offset = k_id * karg.KRead; + } + + // Calculate A scale offset + if constexpr(is_same_v) + { + a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize); + } + else if constexpr(is_same_v) + { + a_scale_k_split_offset = + k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA; + } + + // Calculate B scale offset + if constexpr(is_same_v) + { + b_scale_k_split_offset = + k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB; + } + else if constexpr(is_same_v) + { + b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize); + } + + if(k_id < karg.KBatch - 1) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t a_scale_k_split_offset; + index_t b_scale_k_split_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, 
would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto a_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return a_lds_block_desc_permuted; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0 + ? M0 + : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr 
auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + I1, + Number{}, + Number{}, + Number{})); + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max(a_block_space_size_aligned * sizeof(LDSTypeA), + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0, + "KPerBlock should be multiple of ScaleBlockSize"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg N value is not a multiple of NPerBlock! 
N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { +#if DEBUG_LOG + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! 
" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + // check gridwise gemm pipeline +#if 0 + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } +#endif + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, + // NPerBlock>; + + using mx_scale_t = e8m0_bexp_t; + static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + + template + __device__ static void Run(const index_t* p_sorted_token_ids, + const index_t* p_sorted_expert_ids, + const index_t* p_max_token_id, + const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + ignore = b_element_op; + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, + problem.MPadded, + problem.K, + problem.KPadded, + problem.StrideA, + problem.AK0); + const auto b_grid_desc_bpreshuffled = + MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens, + problem.MPadded, + problem.N, + problem.NPadded, + problem.StrideC); + + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed( + make_tuple((IsInputGemm ? 
problem.NumTokens : problem.M) / (MXdlPack * MPerBlock), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / + (KXdlPack * 64 / MPerXdl), + 64 * KXdlPack * MXdlPack / scale_pack_size_a)); + + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed( + make_tuple(problem.N / (NXdlPack * NPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / + (KXdlPack * 64 / NPerXdl), + 64 * KXdlPack * NXdlPack / scale_pack_size_b)); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged"); + const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y; + if(expert_block_id * MPerBlock >= max_token_id) + return; + const index_t expert_id = + __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]); + + const auto block_mn = [&]() -> std::pair { + if constexpr(NSwizzle) + { + const index_t ecnt_prefix = p_max_token_id[1 + expert_id]; + const index_t prefix_block = ecnt_prefix * problem.NBlock; + const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix; + const index_t expert_swizzle = + ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2 + const index_t bid_new = blockIdx.x - prefix_block; + const index_t nid = __builtin_amdgcn_readfirstlane( + bid_new % 8 + bid_new / (8 * expert_swizzle) * 8); + const index_t mid = + __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle); + return {nid, mid}; + } + else + { + return {blockIdx.x, blockIdx.y}; + } + }(); + + const index_t block_n_id = block_mn.first; + const index_t block_m_id = block_mn.second; + const index_t token0 = + __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); + + // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2); + constexpr auto AKThreads = AK0Threads * AK1Threads; + constexpr auto AMRepeats = MPerBlock / AMThreads; + const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; + + if(token_pos >= max_token_id || token0 >= problem.NumTokens) + return; + StaticallyIndexedArray gather_offsets; + static_for<0, AMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[token_pos + m0]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(!IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + gather_offsets(m0) = static_cast(token_offset) * problem.K / APackedSize; + }); + const index_t expert_stride = + __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); + const index_t expert_scale_stride = + __builtin_amdgcn_readfirstlane(problem.N * (IsInputGemm ? 
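+        // Decoding sketch for the entries of p_sorted_token_ids used in the gather
+        // above (names are the ones in this function): each entry packs the token id
+        // in its low 24 bits and that token's top-k slot in its high 8 bits.
+        //   token  = fused_token & 0xffffff;  // row index among the original tokens
+        //   k_slot = fused_token >> 24;       // which of the TopK copies of that token
+        // gemm1 (IsInputGemm) gathers the original token row, so row = token;
+        // gemm2 gathers the expanded activation row, so row = token * TopK + k_slot.
+        // The per-row element offset into A is then row * K / APackedSize, which is
+        // exactly what gather_offsets stores.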
2 : 1) * + math::integer_divide_ceil(problem.K, ScaleBlockSize)); + + // N0, K0, Blocksize*KPack + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid + expert_id * expert_stride / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + + // A, B scale buffer + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid + expert_id * expert_scale_stride, + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + // dummy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + // A matrix blockwise copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + IndexType, + 1, + BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}, + gather_offsets); + + // Thread-wise copy + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + auto b_block_buf = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto b_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + I1, + Number{}, + Number{}, + Number{}>, + Sequence<1, 2, 0, 3>, + 4, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + + // LDS allocation for A and B: be careful of alignment + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), + a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + decltype(c_thread_buf) c_thread_buf_up; + + StaticBufferTupleOfVector + c_thread_buf_fp32; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // a and b scale processing + const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + static 
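+        // Scale handling sketch: the block scales are e8m0 exponents (1 byte each)
+        // but are loaded through the wider AScaleDataType / BScaleDataType carriers,
+        // so scale_pack_size_a/b count how many e8m0 values arrive per carrier word.
+        // Assuming, for illustration only, a 4-byte carrier (e.g. a packed int32
+        // representation) and KXdlPack = MXdlPack = 2:
+        //   scale_pack_size_a = 4 / 1 = 4 and KXdlPack * MXdlPack = 4,
+        // so each scale thread copy below moves one carrier word per packed Xdl tile,
+        // which is what the earlier static_asserts on the pack sizes guarantee.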
constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; + + auto thread_offset_shuffled = + get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack; + + auto a_thread_offset_m = waveId_m; + + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + AScaleDataType, + AScaleDataType, + decltype(a_scale_grid_desc_am_ak), + decltype(BlockwiseGemmPipe::a_scale_thread_desc), + Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m, + 0, + thread_offset_shuffled / scale_pack_size_a)); + + // B scale load + auto b_thread_offset_n = waveId_n; + + auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + if constexpr(IsInputGemm) + { + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; + const auto b_grid_buf_up = make_dynamic_buffer( + p_b_grid_up + expert_id * expert_stride / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2; + const auto b_scale_grid_buf_up = make_dynamic_buffer( + p_b_scale_grid_up + expert_id * expert_scale_stride, + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_blockwise_copy_up, + b_grid_buf, + b_grid_buf_up, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + c_thread_buf_up, + a_scale_grid_desc_am_ak, + a_scale_thread_copy, + a_scale_grid_buf, + b_scale_grid_desc_bn_ak, + b_scale_thread_copy, + 
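+            // Gate/up ("GU") weight layout sketch for this IsInputGemm branch: each
+            // expert's B tensor stores the gate and up projections back to back, which
+            // is why expert_stride doubles to N * K * 2 and the up pointers set up
+            // above start half-way through the expert's slice (element offsets,
+            // BPackedSize packed elements per stored value):
+            //   b_gate = p_b_grid       + expert_id * expert_stride / BPackedSize;
+            //   b_up   = b_gate         + expert_stride / 2 / BPackedSize;
+            //   s_up   = p_b_scale_grid + expert_scale_stride / 2
+            //                           + expert_id * expert_scale_stride;
+            // Both halves run through the same pipeline, filling c_thread_buf (gate)
+            // and c_thread_buf_up (up) for the activation epilogue.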
b_scale_thread_copy_up, + b_scale_grid_buf, + b_scale_grid_buf_up, + num_k_block_main_loop); + } + else + { + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + a_scale_grid_desc_am_ak, + a_scale_thread_copy, + a_scale_grid_buf, + b_scale_grid_desc_bn_ak, + b_scale_thread_copy, + b_scale_grid_buf, + num_k_block_main_loop); + } + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + // mul scales + static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); + static_assert(M4 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; + + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + 
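+                            // Activation epilogue sketch: with w the routed top-k weight
+                            // (applied only when MulRoutedWeight), the fused gemm1 path
+                            // above computes, per accumulator element,
+                            //   silu_and_mul: out = silu(gate * w) * (up * w),
+                            //                 silu(x) = x / (1 + exp(-x))
+                            //   gelu_and_mul: out = gelu(gate * w) * (up * w)
+                            // while this gemm2 path simply copies the accumulator and
+                            // scales it by w.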
c_thread_buf_fp32(cidx) = + topk_weights.AsType()[m4] * c_thread_buf_fp32[cidx]; + } + } + }); + }); + }); + }); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return 
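+            // Write-out is staged in two hops: accumulators are first rearranged from
+            // the MFMA register layout into LDS one CShuffleMXdlPerWavePerShuffle x
+            // CShuffleNXdlPerWavePerShuffle sub-tile at a time, and only then copied
+            // from LDS to global memory together with the D tensors; the intent, as in
+            // the other CShuffle epilogues in this library, appears to be turning the
+            // scattered per-lane accumulator layout into wide, contiguous global stores.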
ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = + container_concat(make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_m_id, 0, block_n_id, 0); + // return make_multi_index(block_work_idx[I0], 0, + // block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferCluster = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + constexpr index_t scatter_weight_idx = 1; // hack fix felix + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + constexpr auto EMThreads = + CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); + constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; + constexpr auto ENThreads = + CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + StaticallyIndexedArray scatter_offsets; + + auto dstidx = sfc_cde_block.GetIndex(access_id); + const index_t c_token_pos = + block_m_id * MPerBlock + 
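+            // The store side mirrors the gather on the input side: rows of the C tile
+            // are scattered to per-token destinations instead of being written densely.
+            // Indexing sketch (names as in this function):
+            //   EMThreads = threads covering the M extent of one shuffled C tile,
+            //   EMRepeats = rows of that tile handled by each thread,
+            //   row       = p_sorted_token_ids[c_token_pos + m0], decoded as for the gather,
+            //   scatter_offsets(m0) = row * problem.N  // element offset of the row in C
+            // For gemm1 the destination is the expanded (token * TopK + k_slot) row, so
+            // every top-k copy of a token gets its own output row for gemm2 to consume.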
threadIdx.x / ENThreads * EMRepeats + dstidx(I1); + static_for<0, EMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; + IndexType token_offset = fused_token & 0xffffff; + if constexpr(IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + scatter_offsets(m0) = static_cast(token_offset) * problem.N; + }); + + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf_fp32, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf), + scatter_offsets); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const index_t* p_sorted_token_ids, + const index_t* p_sorted_expert_ids, + const index_t* p_max_token_id, + const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + void* p_shared1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + ignore = b_element_op; + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, + problem.MPadded, + problem.K, + problem.KPadded, + problem.StrideA, + problem.AK0); + const auto b_grid_desc_bpreshuffled = + MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens, + problem.MPadded, + problem.N, + problem.NPadded, + problem.StrideC); + + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed( + make_tuple((IsInputGemm ? problem.NumTokens : problem.M) / (MXdlPack * MPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / + (KXdlPack * 64 / MPerXdl), + 64 * KXdlPack * MXdlPack / scale_pack_size_a)); + + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed( + make_tuple(problem.N / (NXdlPack * NPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / + (KXdlPack * 64 / NPerXdl), + 64 * KXdlPack * NXdlPack / scale_pack_size_b)); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged"); + const index_t expert_block_id = NSwizzle ? 
blockIdx.x / problem.NBlock : blockIdx.y; + if(expert_block_id * MPerBlock >= max_token_id) + return; + const index_t expert_id = + __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]); + const auto block_mn = [&]() -> std::pair { + if constexpr(NSwizzle) + { + const index_t ecnt_prefix = p_max_token_id[1 + expert_id]; + const index_t prefix_block = ecnt_prefix * problem.NBlock; + const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix; + const index_t expert_swizzle = + ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2 + const index_t bid_new = blockIdx.x - prefix_block; + const index_t nid = __builtin_amdgcn_readfirstlane( + bid_new % 8 + bid_new / (8 * expert_swizzle) * 8); + const index_t mid = + __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle); + return {nid, mid}; + } + else + { + return {blockIdx.x, blockIdx.y}; + } + }(); + + const index_t block_n_id = block_mn.first; + const index_t block_m_id = block_mn.second; + const index_t token0 = + __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); + + // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2); + constexpr auto AKThreads = AK0Threads * AK1Threads; + constexpr auto AMRepeats = MPerBlock / AMThreads; + const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; + + if(token_pos >= max_token_id || token0 >= problem.NumTokens) + return; + StaticallyIndexedArray gather_offsets; + static_for<0, AMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[token_pos + m0]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(!IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + gather_offsets(m0) = static_cast(token_offset) * problem.K; + }); + + const index_t expert_stride = + __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 
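+        // Block-id swizzling sketch (NSwizzle == true): blockIdx.x enumerates all
+        // (m-tile, n-tile) pairs; with ecnt the number of sorted-token m-tiles owned
+        // by this expert (clamped to >= 1) and ecnt_prefix the number of m-tiles of
+        // all preceding experts,
+        //   bid_new = blockIdx.x - ecnt_prefix * NBlock        // block id within the expert
+        //   nid     = bid_new % 8 + (bid_new / (8 * ecnt)) * 8 // n-tile, in groups of 8
+        //   mid     = ecnt_prefix + (bid_new / 8) % ecnt       // m-tile
+        // so consecutive blocks sweep a group of 8 n-tiles, revisit the same group for
+        // the expert's next m-tile, and only then advance to the next 8 n-tiles.
+        // Without NSwizzle the mapping is simply (nid, mid) = (blockIdx.x, blockIdx.y).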
2 : 1)); + const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane( + problem.N * math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize)); + + // N0, K0, Blocksize*KPack + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize()); + + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType), + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + // dummy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + // A matrix blockwise copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + IndexType, + 1, + BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}, + gather_offsets); + + // Thread-wise copy + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + auto b_block_buf_ping = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_buf_pong = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + auto b_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + I1, + Number{}, + Number{}, + Number{}>, + Sequence<1, 2, 0, 3, 4>, + 4, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + 0, + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + + // LDS allocation for A and B: be careful of alignment + // Cast after lds + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = 
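+        // Run_2Lds buffering sketch: the A tile ping-pongs between two distinct LDS
+        // allocations (p_shared / p_shared1) and the preshuffled B tile ping-pongs
+        // between two register-resident static buffers, so the pipeline can overlap
+        // the prefetch of the next K tile with the MFMA work on the current one:
+        //   a_block_bufs = (view of p_shared, view of p_shared1)  // ping, pong
+        //   b_block_bufs = (b_block_buf_ping, b_block_buf_pong)   // ping, pong
+        // Two separate __shared__ allocations appear to be needed so the two A tiles
+        // are guaranteed to live in different LDS chunks, since LDS occupied by the
+        // kernel is not released (or re-carved) until the whole shader ends.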
blockwise_gemm_pipeline.GetCThreadBuffer(); + decltype(c_thread_buf) c_thread_buf_up; + + StaticBufferTupleOfVector + c_thread_buf_fp32; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // a and b scale processing + const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + auto thread_offset_shuffled = + get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack; + + auto a_thread_offset_m = waveId_m; + + // get each thread's offset int the scale tensor + const index_t token_scale_pos = block_m_id * MPerBlock; + if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens) + return; + + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + AScaleDataType, + AScaleDataType, + decltype(a_scale_grid_desc_am_ak), + decltype(BlockwiseGemmPipe::a_scale_thread_desc), + Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m, + 0, + thread_offset_shuffled / scale_pack_size_a)); + + // B scale load + auto b_thread_offset_n = waveId_n; + + auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + if constexpr(IsInputGemm) + { + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; + const auto b_grid_buf_up = make_dynamic_buffer( + p_b_grid_up + expert_id * expert_stride / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2; + const auto b_scale_grid_buf_up = make_dynamic_buffer( + p_b_scale_grid_up + expert_id * expert_scale_stride, + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * 
NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_blockwise_copy_up, + b_grid_buf, + b_grid_buf_up, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + c_thread_buf_up, + a_scale_grid_desc_am_ak, + a_scale_thread_copy, + a_scale_grid_buf, + b_scale_grid_desc_bn_ak, + b_scale_thread_copy, + b_scale_thread_copy_up, + b_scale_grid_buf, + b_scale_grid_buf_up, + num_k_block_main_loop); + } + else + { + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + a_scale_grid_desc_am_ak, + a_scale_thread_copy, + a_scale_grid_buf, + b_scale_grid_desc_bn_ak, + b_scale_thread_copy, + b_scale_grid_buf, + num_k_block_main_loop); + } + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + // mul scales + + static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); + static_assert(M4 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; + + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0 / MXdlPack, + n0 / NXdlPack, + m0 % MXdlPack, + n0 % NXdlPack, + m2 * M4 + m4)); + constexpr auto cidx = Number{}; + + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if 
constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = + topk_weights.AsType()[m4] * c_thread_buf_fp32[cidx]; + } + } + }); + }); + }); + }); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, 
problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = + container_concat(make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_m_id, 0, block_n_id, 0); + // return make_multi_index(block_work_idx[I0], 0, + // block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferCluster = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + constexpr index_t scatter_weight_idx = 3; // hack fix felix + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + 
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + constexpr auto EMThreads = + CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); + constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; + constexpr auto ENThreads = + CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + StaticallyIndexedArray scatter_offsets; + + auto dstidx = sfc_cde_block.GetIndex(access_id); + const index_t c_token_pos = + block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); + static_for<0, EMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; + IndexType token_offset = fused_token & 0xffffff; + if constexpr(IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + scatter_offsets(m0) = static_cast(token_offset) * problem.N; + }); + + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf_fp32, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf), + scatter_offsets); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp new file mode 100644 index 0000000000..7238917920 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp @@ -0,0 +1,2849 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
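+// This header defines GridwiseMoeGemmMXBNS, which appears to be the
+// "B not (pre)shuffled" variant of the gridwise MoE MX GEMM above: B is read
+// through a regular BK0/N/BK1 grid descriptor and the mx_moe_nbs pipeline
+// selector rather than the preshuffled-B path, while the overall structure
+// (token gather, e8m0 block-scale handling, CShuffle scatter epilogue) mirrors
+// the preshuffled version.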
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp" + +#define DEBUG_LOG 0 + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe + +enum Activation +{ + gelu_and_mul = 0, + silu_and_mul = 1 +}; + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_moe_mxgemm(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +#if 0 +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + // auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run_2Lds( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid, + karg.p_a_scale_grid, + karg.p_b_grid, + karg.p_b_scale_grid, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} +#endif + +template +struct GridwiseMoeGemmMXBNS +{ + using LDSTypeA = ADataType; + using LDSTypeB = BDataType; + + static constexpr auto I0 = 
Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + static constexpr auto I8 = Number<8>{}; + static constexpr auto I9 = Number<9>{}; + + static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = + CDEShuffleBlockTransferScalarPerVectors{}[I0]; + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto MXdlPack = 2; + static constexpr auto NXdlPack = 2; + static constexpr auto KXdlPack = 2; + + static constexpr index_t APackedSize = packed_size_v; + static constexpr index_t BPackedSize = packed_size_v; + + static constexpr bool is_single_rate_mfma = false; + static constexpr auto is_scale_mfma = true; + using mfma_selector = MfmaSelector; + static constexpr index_t KPack = math::max( + math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk / APackedSize); + + // static constexpr index_t NumTokens = 1; + static constexpr index_t SortedTileSize = MPerBlock; + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + const index_t nblock = math::integer_divide_ceil(N, NPerBlock); + const index_t mblock = math::integer_divide_ceil(M, MPerBlock); + const index_t gridx = NSwizzle ? nblock * mblock : nblock; + const index_t gridy = NSwizzle ? 
1 : mblock; + + return std::make_tuple(gridx, gridy, 1); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}, + Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, 
Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + static_assert(!(is_same_v, f4x2_pk_t> && + GemmSpec != GemmSpecialization::Default), + "f4x2_pk_t does not support padding"); + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto 
b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor( + ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor( + BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ static auto MakeCGridDescriptor_M_N( + IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template + __host__ __device__ static auto + MakeDGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + __host__ __device__ static auto MakeDsGridDescriptor_M_N( + index_t M, index_t MPad, index_t N, index_t NPad, std::array StrideDs) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + return MakeDGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); + }, + Number{}); + } + + template + __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + return generate_tuple( + [&](auto i) { + return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[i], MBlock, NBlock); + }, + Number{}); + } + + struct Problem + { + __host__ Problem(index_t NumTokens_, + index_t TopK_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideScaleA_, + index_t StrideB_, + index_t StrideScaleB_, + std::array StrideDs_, + index_t StrideC_, + index_t KBatch_) + : 
NumTokens{NumTokens_}, + TopK{TopK_}, + M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideScaleA{StrideScaleA_}, + StrideB{StrideB_}, + StrideScaleB{StrideScaleB_}, + StrideDs{StrideDs_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "NumTokens:" << NumTokens << ", " + << "TopK:" << TopK << ", " + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SScaleA:" << StrideScaleA << ", " + << "SB:" << StrideB << ", " + << "SScaleB:" << StrideScaleB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t NumTokens; + index_t TopK; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideScaleA; + index_t StrideB; + index_t StrideScaleB; + std::array StrideDs; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const index_t* p_sorted_token_ids_, + const index_t* p_sorted_expert_ids_, + const index_t* p_max_token_id_, + const ADataType* p_a_grid_, + const AScaleDataType* p_a_scale_grid_, + const BDataType* p_b_grid_, + const BScaleDataType* p_b_scale_grid_, + std::array p_ds_grid_, + CDataType* p_c_grid_, + index_t NumTokens_, + index_t TopK_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideScaleA_, + index_t StrideB_, + index_t StrideScaleB_, + std::array StrideDs_, + index_t StrideC_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : Problem{NumTokens_, + TopK_, + M_, + N_, + K_ / APackedSize, + StrideA_ / APackedSize, + StrideScaleA_, + StrideB_ / BPackedSize, + StrideScaleB_, + StrideDs_, + StrideC_, + k_batch_}, + p_sorted_token_ids{p_sorted_token_ids_}, + p_sorted_expert_ids{p_sorted_expert_ids_}, + p_max_token_id{p_max_token_id_}, + p_a_grid{p_a_grid_}, + p_a_scale_grid{p_a_scale_grid_}, + p_b_grid{p_b_grid_}, + p_b_scale_grid{p_b_scale_grid_}, + p_ds_grid{}, + p_c_grid{p_c_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_} + { + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType_ = remove_cvref_t>; + + // D pointer + p_ds_grid(i) = static_cast(p_ds_grid_[i]); + }); + } + + const index_t* p_sorted_token_ids; + const index_t* p_sorted_expert_ids; + const index_t* p_max_token_id; + const ADataType* p_a_grid; + const AScaleDataType* p_a_scale_grid; + const BDataType* p_b_grid; + const BScaleDataType* p_b_scale_grid; + DsGridPointer p_ds_grid; + CDataType* p_c_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + }; + + struct SplitKBatchOffset + { + __device__ 
SplitKBatchOffset(Argument& karg, index_t k_id) + { + if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = k_id * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + // KPack * NLane * KLane * K0 * N0 + b_k_split_offset = k_id * karg.KRead; + } + + // Calculate A scale offset + if constexpr(is_same_v) + { + a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize); + } + else if constexpr(is_same_v) + { + a_scale_k_split_offset = + k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA; + } + + // Calculate B scale offset + if constexpr(is_same_v) + { + b_scale_k_split_offset = + k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB; + } + else if constexpr(is_same_v) + { + b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize); + } + + if(k_id < karg.KBatch - 1) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t a_scale_k_split_offset; + index_t b_scale_k_split_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto a_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return a_lds_block_desc_permuted; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto WaveSize = 64; + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = WaveSize / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? 
M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto b_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return b_lds_block_desc_permuted; + } + else // RowMajor B + { + constexpr auto WaveSize = 64; + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = WaveSize / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? 
KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 + ? N0 + : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + 
constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + if constexpr(IsInputGemm) + { + return math::max((a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)) * + 2, + c_block_size * sizeof(CShuffleDataType)); + } + else + { + return math::max((a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)), + c_block_size * sizeof(CShuffleDataType)); + } + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0, + "KPerBlock should be multiple of ScaleBlockSize"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! 
" + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + + return false; + } + } + } + + // check gridwise gemm pipeline +#if 0 + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } +#endif + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, + // NPerBlock>; + + using mx_scale_t = e8m0_bexp_t; + static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + + template + __device__ static void Run(const index_t* p_sorted_token_ids, + const index_t* p_sorted_expert_ids, + const index_t* p_max_token_id, + const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + ignore = b_element_op; + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, + problem.MPadded, + problem.K, + problem.KPadded, + problem.StrideA, + problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + IsInputGemm ? 
problem.NumTokens * problem.TopK : problem.NumTokens, + problem.MPadded, + problem.N, + problem.NPadded, + problem.StrideC); + + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed( + make_tuple(problem.M / (MXdlPack * MPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / + (KXdlPack * 64 / MPerXdl), + 64 * KXdlPack * MXdlPack / scale_pack_size_a)); + + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed( + make_tuple(problem.N / (NXdlPack * NPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / + (KXdlPack * 64 / NPerXdl), + 64 * KXdlPack * NXdlPack / scale_pack_size_b)); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y; + if(expert_block_id * MPerBlock >= max_token_id) + return; + const index_t expert_id = + __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]); + + const auto block_mn = [&]() -> std::pair { + if constexpr(NSwizzle) + { + const index_t ecnt_prefix = p_max_token_id[1 + expert_id]; + const index_t prefix_block = ecnt_prefix * problem.NBlock; + const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix; + const index_t expert_swizzle = + ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2 + const index_t bid_new = blockIdx.x - prefix_block; + const index_t nid = __builtin_amdgcn_readfirstlane( + bid_new % 8 + bid_new / (8 * expert_swizzle) * 8); + const index_t mid = + __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle); + return {nid, mid}; + } + else + { + return {blockIdx.x, blockIdx.y}; + } + }(); + + const index_t block_n_id = block_mn.first; + const index_t block_m_id = block_mn.second; + const index_t token0 = + __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); + + // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2); + constexpr auto AKThreads = AK0Threads * AK1Threads; + constexpr auto AMRepeats = MPerBlock / AMThreads; + const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; + + if(token_pos >= max_token_id || token0 >= problem.NumTokens) + return; + StaticallyIndexedArray gather_offsets; + static_for<0, AMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[token_pos + m0]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(!IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + gather_offsets(m0) = static_cast(token_offset) * problem.K; + }); + + const index_t expert_stride = + __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); + const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane( + problem.N * (IsInputGemm ? 
2 : 1) * + math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize)); + + // N0, K0, Blocksize*KPack + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // Gride buffer creation + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + // A, B scale buffer + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType), + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + IndexType, + 1, + BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}, + gather_offsets); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + + a_block_space_size_aligned * sizeof(ADataType)), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // 
Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + decltype(c_thread_buf) c_thread_buf_up; + + StaticBufferTupleOfVector + c_thread_buf_fp32; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // a and b scale processing + const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + auto thread_offset_shuffled = + get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack; + + auto a_thread_offset_m = waveId_m; + + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + AScaleDataType, + AScaleDataType, + decltype(a_scale_grid_desc_am_ak), + decltype(BlockwiseGemmPipe::a_scale_thread_desc), + Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m, + 0, + thread_offset_shuffled / scale_pack_size_a)); + + // B scale load + auto b_thread_offset_n = waveId_n; + + auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + if constexpr(IsInputGemm) + { + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + auto b_block_buf_up = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + + a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2; + const auto b_grid_buf_up = make_dynamic_buffer( + p_b_grid_up + expert_id * expert_stride, + b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto b_blockwise_copy_up = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + const BScaleDataType* p_b_scale_grid_up = + p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType); + const auto b_scale_grid_buf_up = make_dynamic_buffer( + p_b_scale_grid_up + 
expert_id * expert_scale_stride / sizeof(BScaleDataType), + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + blockwise_gemm_pipeline.template Run( + // A + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + // Gate and Up + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_blockwise_copy_up, + b_grid_buf, + b_grid_buf_up, + b_block_buf, + b_block_buf_up, + b_block_slice_copy_step, + // C + c_thread_buf, + c_thread_buf_up, + // A scale + a_scale_grid_desc_am_ak, + a_scale_thread_copy, + a_scale_grid_buf, + // Gate and Up scale + b_scale_grid_desc_bn_ak, + b_scale_thread_copy, + b_scale_thread_copy_up, + b_scale_grid_buf, + b_scale_grid_buf_up, + num_k_block_main_loop); + } + else + { + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, // A + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, // B + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, // C + a_scale_grid_desc_am_ak, // A scale + a_scale_thread_copy, + a_scale_grid_buf, + b_scale_grid_desc_bn_ak, // B scale + b_scale_thread_copy, + b_scale_grid_buf, + num_k_block_main_loop); + } + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 && + CShuffleNXdlPerWavePerShuffle % NXdlPack == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + + // TODO: hacky, fix it! 
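+            // Axis meaning of the lengths queried below (see the unmerge transforms and
+            // static_asserts that follow): M1 = MWave, M2 = MXdlPack, M3 * M4 * M5 = MPerXdl
+            // (with M5 == 4), N1 = NWave, N2 = NXdlPack, N3 = NPerXdl; M0/N0 are the Xdl
+            // repeats per wave, so M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock.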
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9); + + // mul scales + static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock); + static_assert(M5 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; // Mwave id + const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl; + + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack + static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack + static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + + m0 * M2 * M1 * M3 * M4 * M5 + + m1 * M2 * M3 * M4 * M5 + + imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5; + + if constexpr(MulRoutedWeight) + { + topk_weights = + *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5)); + constexpr auto cidx = Number{}; + + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == + Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m5]; + up = up * topk_weights.AsType()[m5]; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m5]; + up = up * topk_weights.AsType()[m5]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + + /*float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m5]; + //up = up * topk_weights.AsType()[m5]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = up;*/ + } + } + else + { + c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = + topk_weights.AsType()[m5] * + c_thread_buf_fp32[cidx]; + } + } + }); + }); + }); + }); + }); + }); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = 
make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) + // per shuffle + M1, // M1 = MWave + M2, // M2 = MXdlPack + M3, // M3 * M4 * M5 = MPerXdl + M4, + M5)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) + // per shuffle + N1, // N1 = NWave + N2, // N2 = NXdlPack + N3))), // N3 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 6, 7, 8>{}, + Sequence<>{}, + Sequence<1, 3, 5, 9>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), + make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + 9, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + m_thread_data_on_block_idx[I5], + n_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // 
return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = + container_concat(make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_m_id, 0, block_n_id, 0); + // return make_multi_index(block_work_idx[I0], 0, + // block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferCluster = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + constexpr index_t scatter_weight_idx = 3; // hack fix felix + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + constexpr auto EMThreads = + CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); + constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; + constexpr auto ENThreads = + CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + StaticallyIndexedArray scatter_offsets; + + auto dstidx = sfc_cde_block.GetIndex(access_id); 
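+                // Each entry of p_sorted_token_ids packs the token index in its low 24 bits
+                // and the topk slot in its high 8 bits; the decode below turns that into an
+                // output row (token * TopK + slot for the gate/up input GEMM), and the
+                // scatter offset is that row times N.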
+ const index_t c_token_pos = + block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); + static_for<0, EMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; + IndexType token_offset = fused_token & 0xffffff; + if constexpr(IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + scatter_offsets(m0) = static_cast(token_offset) * problem.N; + }); + + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf_fp32, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf), + scatter_offsets); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + +#if 0 + template + __device__ static void Run_2Lds(const index_t* p_sorted_token_ids, + const index_t* p_sorted_expert_ids, + const index_t* p_max_token_id, + const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + void* p_shared1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + ignore = b_element_op; + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, + problem.MPadded, + problem.K, + problem.KPadded, + problem.StrideA, + problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens, + problem.MPadded, + problem.N, + problem.NPadded, + problem.StrideC); + + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed( + make_tuple((IsInputGemm ? 
problem.NumTokens : problem.M) / (MXdlPack * MPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / + (KXdlPack * 64 / MPerXdl), + 64 * KXdlPack * MXdlPack / scale_pack_size_a)); + + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed( + make_tuple(problem.N / (NXdlPack * NPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / + (KXdlPack * 64 / NPerXdl), + 64 * KXdlPack * NXdlPack / scale_pack_size_b)); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged"); + const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y; + if(expert_block_id * MPerBlock >= max_token_id) + return; + const index_t expert_id = + __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]); + const auto block_mn = [&]() -> std::pair { + if constexpr(NSwizzle) + { + const index_t ecnt_prefix = p_max_token_id[1 + expert_id]; + const index_t prefix_block = ecnt_prefix * problem.NBlock; + const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix; + const index_t expert_swizzle = + ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2 + const index_t bid_new = blockIdx.x - prefix_block; + const index_t nid = __builtin_amdgcn_readfirstlane( + bid_new % 8 + bid_new / (8 * expert_swizzle) * 8); + const index_t mid = + __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle); + return {nid, mid}; + } + else + { + return {blockIdx.x, blockIdx.y}; + } + }(); + + const index_t block_n_id = block_mn.first; + const index_t block_m_id = block_mn.second; + const index_t token0 = + __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); + + // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2); + constexpr auto AKThreads = AK0Threads * AK1Threads; + constexpr auto AMRepeats = MPerBlock / AMThreads; + const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; + + if(token_pos >= max_token_id || token0 >= problem.NumTokens) + return; + StaticallyIndexedArray gather_offsets; + static_for<0, AMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[token_pos + m0]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(!IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + gather_offsets(m0) = static_cast(token_offset) * problem.K; + }); + + const index_t expert_stride = + __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 
2 : 1)); + const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane( + problem.N * math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize)); + + // N0, K0, Blocksize*KPack + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize()); + + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType), + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + // dummy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + // A matrix blockwise copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + IndexType, + 1, + BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}, + gather_offsets); + + // Thread-wise copy + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + auto b_block_buf_ping = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_buf_pong = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + auto b_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + I1, + Number{}, + Number{}, + Number{}>, + Sequence<1, 2, 0, 3, 4>, + 4, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + 0, + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + + // LDS allocation for A and B: be careful of alignment + // Cast after lds + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = 
blockwise_gemm_pipeline.GetCThreadBuffer(); + decltype(c_thread_buf) c_thread_buf_up; + + StaticBufferTupleOfVector + c_thread_buf_fp32; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // a and b scale processing + const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + auto thread_offset_shuffled = + get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack; + + auto a_thread_offset_m = waveId_m; + + // get each thread's offset int the scale tensor + const index_t token_scale_pos = block_m_id * MPerBlock; + if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens) + return; + + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + AScaleDataType, + AScaleDataType, + decltype(a_scale_grid_desc_am_ak), + decltype(BlockwiseGemmPipe::a_scale_thread_desc), + Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m, + 0, + thread_offset_shuffled / scale_pack_size_a)); + + // B scale load + auto b_thread_offset_n = waveId_n; + + auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + if constexpr(IsInputGemm) + { + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; + const auto b_grid_buf_up = make_dynamic_buffer( + p_b_grid_up + expert_id * expert_stride / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2; + const auto b_scale_grid_buf_up = make_dynamic_buffer( + p_b_scale_grid_up + expert_id * expert_scale_stride, + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * 
NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_blockwise_copy_up, + b_grid_buf, + b_grid_buf_up, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + c_thread_buf_up, + a_scale_grid_desc_am_ak, + a_scale_thread_copy, + a_scale_grid_buf, + b_scale_grid_desc_bn_ak, + b_scale_thread_copy, + b_scale_thread_copy_up, + b_scale_grid_buf, + b_scale_grid_buf_up, + num_k_block_main_loop); + } + else + { + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + a_scale_grid_desc_am_ak, + a_scale_thread_copy, + a_scale_grid_buf, + b_scale_grid_desc_bn_ak, + b_scale_thread_copy, + b_scale_grid_buf, + num_k_block_main_loop); + } + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + // mul scales + + static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); + static_assert(M4 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; + + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0 / MXdlPack, + n0 / NXdlPack, + m0 % MXdlPack, + n0 % NXdlPack, + m2 * M4 + m4)); + constexpr auto cidx = Number{}; + + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if 
constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = + topk_weights.AsType()[m4] * c_thread_buf_fp32[cidx]; + } + } + }); + }); + }); + }); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple(make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per + // shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, 
problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = + container_concat(make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_m_id, 0, block_n_id, 0); + // return make_multi_index(block_work_idx[I0], 0, + // block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferCluster = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + constexpr index_t scatter_weight_idx = 3; // hack fix felix + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + 
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + constexpr auto EMThreads = + CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); + constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; + constexpr auto ENThreads = + CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + StaticallyIndexedArray scatter_offsets; + + auto dstidx = sfc_cde_block.GetIndex(access_id); + const index_t c_token_pos = + block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); + static_for<0, EMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; + IndexType token_offset = fused_token & 0xffffff; + if constexpr(IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + scatter_offsets(m0) = static_cast(token_offset) * problem.N; + }); + + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf_fp32, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf), + scatter_offsets); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } +#endif +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index c17b88ccea..4e4c92de40 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -580,11 +580,6 @@ struct ThreadwiseTensorSliceTransfer_v2_gather }); }); - // printf("blockIdx.y: %d, tid: %d, dst_buf<%f>\n", - // blockIdx.y, - // threadIdx.x, - // dst_buf(Number<0>{})); - // move src coordinate back to slice origin (or not) if constexpr(SrcResetCoordinateAfterRun) { diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 7da353d9ad..1dd766eca0 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -1146,15 +1146,6 @@ struct MfmaSelector #endif } - // Use single rate mfma instruction for this special case A (f8_t) * B (pk_i4_t) - // See example gemm_xdl_fp8_pk_i4_bpreshuffle_v3 - // TODO: explore optimization opportunity by using new mfma instructions on gfx950 - template <> - constexpr auto GetMfma() - { - return MfmaInstr::mfma_f32_32x32x16f8f8; - } - template <> constexpr auto GetMfma() { diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 
9a28c5f332..56da5c1dc8 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -881,11 +881,6 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32, OpselA, OpselB> #endif } }; -#define BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS 1 - -#ifndef BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS -#define BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS 0 -#endif template struct intrin_mfma_scale_f32_16x16x128f8f6f4; @@ -893,48 +888,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4; template struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB> { - -#define V_MFMA_SCALE_F32_16X16X128_F8F6F4(OPF_F8F6F4_CTRL_A, \ - OPF_F8F6F4_CTRL_B, \ - F8F6F4_VEC_TYPE_A, \ - F8F6F4_VEC_TYPE_B, \ - OPSEL_A_L, \ - OPSEL_A_H, \ - OPSEL_B_L, \ - OPSEL_B_H) \ - if constexpr((OpselA == 1 * OPSEL_A_L + 2 * OPSEL_A_H) && \ - (OpselB == 1 * OPSEL_B_L + 2 * OPSEL_B_H)) \ - asm volatile("v_mfma_scale_f32_16x16x128_f8f6f4 %0, %1, %2, %3, %4, %5 " \ - "op_sel:[" #OPSEL_A_L "," #OPSEL_A_H "] " \ - "op_sel_hi:[" #OPSEL_B_L "," #OPSEL_B_H "] " \ - "cbsz:" #OPF_F8F6F4_CTRL_A " blgp:" #OPF_F8F6F4_CTRL_B \ - : "+v"(reg_c.template AsType()(Number<0>{})) \ - : "v"(bit_cast(reg_a)), \ - "v"(bit_cast(reg_b)), \ - "v"(reg_c.template AsType()[Number<0>{}]), \ - "v"(scale_a), \ - "v"(scale_b)) -#define BOOL4_CASES(F) \ - do \ - { \ - F(0, 0, 0, 0); \ - F(0, 0, 0, 1); \ - F(0, 0, 1, 0); \ - F(0, 0, 1, 1); \ - F(0, 1, 0, 0); \ - F(0, 1, 0, 1); \ - F(0, 1, 1, 0); \ - F(0, 1, 1, 1); \ - F(1, 0, 0, 0); \ - F(1, 0, 0, 1); \ - F(1, 0, 1, 0); \ - F(1, 0, 1, 1); \ - F(1, 1, 0, 0); \ - F(1, 1, 0, 1); \ - F(1, 1, 1, 0); \ - F(1, 1, 1, 1); \ - } while(0) - template __device__ static void Run(const f8x32_t& reg_a, const int32_t& scale_a, @@ -943,7 +896,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB> FloatC& reg_c) { #if defined(__gfx950__) -#if BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( @@ -956,11 +908,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB> scale_a, OpselB, // OPSEL scale_b); -#else -#define f8_cases(...) V_MFMA_SCALE_F32_16X16X128_F8F6F4(0, 0, int32x8_t, int32x8_t, __VA_ARGS__) - BOOL4_CASES(f8_cases); -#undef f8_cases -#endif #else ignore = reg_a; ignore = scale_a; @@ -978,7 +925,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB> FloatC& reg_c) { #if defined(__gfx950__) -#if BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( @@ -991,10 +937,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB> scale_a, OpselB, // OPSEL scale_b); -#else -#define bf8_cases(...) 
V_MFMA_SCALE_F32_16X16X128_F8F6F4(1, 1, int32x8_t, int32x8_t, __VA_ARGS__) - BOOL4_CASES(bf8_cases); -#endif #else ignore = reg_a; ignore = scale_a; @@ -1012,7 +954,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB> FloatC& reg_c) { #if defined(__gfx950__) -#if BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( @@ -1025,11 +966,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB> scale_a, OpselB, // OPSEL scale_b); -#else -#define f8bf8_cases(...) V_MFMA_SCALE_F32_16X16X128_F8F6F4(0, 1, int32x8_t, int32x8_t, __VA_ARGS__) - BOOL4_CASES(f8bf8_cases); -#undef f8bf8_cases -#endif #else ignore = reg_a; ignore = scale_a; @@ -1047,7 +983,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB> FloatC& reg_c) { #if defined(__gfx950__) -#if BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( @@ -1060,11 +995,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB> scale_a, OpselB, // OPSEL scale_b); -#else -#define bf8f8_cases(...) V_MFMA_SCALE_F32_16X16X128_F8F6F4(1, 0, int32x8_t, int32x8_t, __VA_ARGS__) - BOOL4_CASES(bf8f8_cases); -#undef bf8f8_cases -#endif #else ignore = reg_a; ignore = scale_a; @@ -1141,24 +1071,13 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB> } template - __device__ static void - Run(const f4x32_t& reg_a, // misalignment between pk_f4_t, 32 and f4_t, 32 - const int32_t scale_a, - const f4x32_t& reg_b, - const int32_t scale_b, - FloatC& reg_c) + __device__ static void Run(const f4x32_t& reg_a, + const int32_t scale_a, + const f4x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) { -#if 0 - if(get_thread_local_1d_id()){ - printf("Tid: %03d, Scale A: %08x, Scale B: %08x, OpSelA: %d, OpSelB: %d\n", - get_thread_local_1d_id(), - *reinterpret_cast(&scale_a), *reinterpret_cast(&scale_b), - OpselA, OpselB); - } -#endif #if defined(__gfx950__) -#if BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS int32x4_t arg_a = bit_cast(reg_a); int32x4_t arg_b = bit_cast(reg_b); using arg_type = int32x8_t; @@ -1173,11 +1092,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB> scale_a, OpselB, // OPSEL scale_b); -#else -#define f4_cases(...) 
V_MFMA_SCALE_F32_16X16X128_F8F6F4(4, 4, int32x4_t, int32x4_t, __VA_ARGS__) - BOOL4_CASES(f4_cases); -#undef f4_cases -#endif #else ignore = reg_a; ignore = scale_a; @@ -1186,9 +1100,7 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB> ignore = reg_c; #endif } -#undef BOOL4_CASES -#undef V_MFMA_SCALE_F32_16X16X128_F8F6F4 -}; // namespace ck +}; template struct intrin_mfma_f32_16x16x128f8f6f4; diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index ad9bb45158..51da18cd2b 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -165,17 +165,6 @@ inline constexpr bool is_native_type() is_same::value || is_same::value || is_same::value; } -template -struct is_f8f6f4 -{ - static constexpr bool value = - is_same_v || is_same_v || is_same_v || is_same_v || - is_same_v || is_same_v || is_same_v || - is_same_v || is_same_v || is_same_v; -}; -template -inline constexpr bool is_f8f6f4_v = is_f8f6f4::value; - // scalar_type template struct scalar_type; diff --git a/include/ck/utility/debug.hpp b/include/ck/utility/debug.hpp index 2b247cc02a..45d443ae49 100644 --- a/include/ck/utility/debug.hpp +++ b/include/ck/utility/debug.hpp @@ -87,6 +87,19 @@ __device__ static bool is_thread_local_1d_id_idx() return ((tid == Ids) || ...); } +// Use `CK_PRINT()` to inspect values of type T1, T2, ... +// Use `CK_PRINT()` to inspect constexpr values of val1, val2, ... of the same type +// In a non-evaluated context, you can use `using _dummy = decltype(CK_PRINT<...>());` +// Set BUILD_DEV to OFF to avoid enabling Werror +template +[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT() +{ +} +template +[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT() +{ +} + } // namespace debug } // namespace ck diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index 049221cea1..0891a7ccf4 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -1036,11 +1036,11 @@ struct vector_type()>> StaticallyIndexedArray d32x4_; StaticallyIndexedArray d64x2_; StaticallyIndexedArray d128x1_; - } data_; + } data_ = {d128_t{0}}; - __host__ __device__ constexpr vector_type() : data_{type{0}} {} + __attribute__((host)) __attribute__((device)) constexpr vector_type() {} - __host__ __device__ constexpr vector_type(type v) : data_{v} {} + __attribute__((host)) __attribute__((device)) constexpr vector_type(type v) { (void)v; } template __host__ __device__ constexpr const auto& AsType() const @@ -1164,11 +1164,11 @@ struct vector_type()>> StaticallyIndexedArray d64x4_; StaticallyIndexedArray d128x2_; StaticallyIndexedArray d256x1_; - } data_; + } data_ = {d256_t{0}}; - __host__ __device__ constexpr vector_type() : data_{type{0}} {} + __attribute__((host)) __attribute__((device)) constexpr vector_type() {} - __host__ __device__ constexpr vector_type(type v) : data_{v} {} + __attribute__((host)) __attribute__((device)) constexpr vector_type(type v) { (void)v; } template __host__ __device__ constexpr const auto& AsType() const @@ -2228,7 +2228,9 @@ using f6x32_t = typename vector_type::type; using bf6x16_t = typename vector_type::type; using bf6x32_t = typename vector_type::type; +// e8m0 using e8m0x4_bexp_t = typename vector_type::type; + // pack int4 using pk_i4x2_t = typename vector_type::type; using pk_i4x4_t = typename vector_type::type; diff --git a/include/ck/utility/functional2.hpp b/include/ck/utility/functional2.hpp index 
16213173f3..ef8b5a435c 100644 --- a/include/ck/utility/functional2.hpp +++ b/include/ck/utility/functional2.hpp @@ -6,6 +6,7 @@ #include "ck/utility/functional.hpp" #include "ck/utility/sequence.hpp" #include "ck/utility/tuple.hpp" +#include "ck/utility/type.hpp" namespace ck { @@ -107,7 +108,7 @@ struct identity template __host__ __device__ constexpr T&& operator()(T&& arg) const noexcept { - return forward(arg); + return ck::forward(arg); } }; diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp new file mode 100644 index 0000000000..eedd687bde --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp @@ -0,0 +1,280 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceMoeGemm1BlockScale : public device::BaseOperator +{ + // Argument + static constexpr auto ActivationType = ActivationType_; + struct Argument : public device::BaseArgument + { + Argument(const Tensor& sorted_token_ids, + const Tensor& expert_ids, + const Tensor& max_token_id, + const index_t sorted_tile_size, + const Tensor& a_t_k, + const Tensor& b_e_n_k, + const Tensor& d2, + Tensor& c_t_k_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : sorted_token_ids_{sorted_token_ids}, + expert_ids_{expert_ids}, + max_token_id_{max_token_id}, + sorted_tile_size_{sorted_tile_size}, + a_t_k_{a_t_k}, + b_e_n_k_{b_e_n_k}, + d2_{d2}, + c_t_k_n_{c_t_k_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& sorted_token_ids_; + const Tensor& expert_ids_; + const Tensor& max_token_id_; + index_t sorted_tile_size_; + const Tensor& a_t_k_; + const Tensor& b_e_n_k_; + const Tensor& d2_; + Tensor& c_t_k_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceMoeGemm1BlockScale::Argument; + + float Run(const Argument& arg) + { + static_assert(ActivationType < 2, "Not supported activation type"); + const int full_n = arg.c_t_k_n_.mDesc.GetLengths()[2]; + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_t_k_.mDesc.GetLengths()[1]; + AccDataType v_acc_up{0}; + ComputeTypeB v_b_up{0}; + AccDataType v_acc{0}; + + ComputeTypeA v_a{0}; + ComputeTypeB v_b{0}; + + const int t = arg.sorted_token_ids_(m) & 0xffffff; + const int topk_id = (arg.sorted_token_ids_(m) & 0xff000000) >> 24; + const int e = arg.expert_ids_(m / arg.sorted_tile_size_); + const int token_cnt = arg.a_t_k_.mDesc.GetLengths()[0]; + D2DataType v_topk_w = arg.d2_(m, 0); // expert + if(t < token_cnt) + { + for(int k = 0; k < K; ++k) + { + if constexpr(is_same_v) + { + uint8_t i4x2 = arg.a_t_k_(t, k).data; + uint8_t i4 = 0; + if(k % 2 == 1) + i4 = (i4x2 >> 0) & 0xf; + else + i4 = (i4x2 >> 4) & 0xf; +#if CK_USE_PK4_LAYOUT_SHUFFLE + v_a = i4_to_f32_gfx9(i4); +#else + v_a = i4 - 8; +#endif + } + else + { + 
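+                        // Native (non pk_i4) A values only pass through the A elementwise op.
+                        // The pk_i4 branch above instead unpacks two 4-bit values per byte
+                        // (high nibble for even k, low nibble for odd k) and decodes them via
+                        // i4_to_f32_gfx9 (a signed-int4 / 16 lookup defined at the end of this
+                        // header) or as (i4 - 8) when CK_USE_PK4_LAYOUT_SHUFFLE is disabled.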
arg.a_element_op_(v_a, arg.a_t_k_(t, k)); + } + // same for B matrix + if constexpr(is_same_v) + { + uint8_t i4x2 = arg.b_e_n_k_(e, k, n).data; + uint8_t i4x2_up = arg.b_e_n_k_(e, k, n + full_n).data; + uint8_t i4 = 0; + uint8_t i4_up = 0; + if(k % 2 == 1) + { + i4 = (i4x2 >> 0) & 0xf; + i4_up = (i4x2_up >> 0) & 0xf; + } + else + { + i4 = (i4x2 >> 4) & 0xf; + i4_up = (i4x2_up >> 4) & 0xf; + } +#if CK_USE_PK4_LAYOUT_SHUFFLE + v_b = i4_to_f32_gfx9(i4); + v_b_up = i4_to_f32_gfx9(i4_up); +#else + v_b = i4 - 8; + v_b_up = i4_up - 8; +#endif + } + else + { + arg.b_element_op_(v_b, arg.b_e_n_k_(e, k, n)); + arg.b_element_op_(v_b_up, arg.b_e_n_k_(e, k, n + full_n)); + } + + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); + v_acc_up += ck::type_convert(v_a) * + ck::type_convert(v_b_up); + } + CDataType v_c{0}; + CDataType v_c_up{0}; + if constexpr(MulRoutedWeight) + { + v_acc *= v_topk_w; + v_acc_up *= v_topk_w; + } + + arg.c_element_op_(v_c, v_acc); + arg.c_element_op_(v_c_up, v_acc_up); + if constexpr(ActivationType == 1) + { + if constexpr(is_same_v) + { + v_c_up *= 16; + v_c *= 16; + } + tensor_operation::element_wise::Silu{}(v_c, v_c); + arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up; + } + else if constexpr(ActivationType == 0) + { + if constexpr(is_same_v) + { + v_c_up *= 16; + v_c *= 16; + } + tensor_operation::element_wise::Gelu{}(v_c, v_c); + arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up; + } + } + }; + + const ck::index_t max_token_id = arg.max_token_id_(0); + make_ParallelTensorFunctor(f_mk_kn_mn, max_token_id, full_n)( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& sorted_token_ids, + const Tensor& expert_ids, + const Tensor& max_token_id, + const index_t sorted_tile_size, + const Tensor& a_t_k, + const Tensor& b_e_n_k, + const Tensor& d2, + Tensor& c_t_k_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{sorted_token_ids, + expert_ids, + max_token_id, + sorted_tile_size, + a_t_k, + b_e_n_k, + d2, + c_t_k_n, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceMoeGemm1BlaockScale" + << std::endl; + // clang-format on + + return str.str(); + } + + static float i4_to_f32_gfx9(uint8_t i4) + { + static std::unordered_map u = {{0b1000, -0.5000f}, + {0b1001, -0.4375f}, + {0b1010, -0.3750f}, + {0b1011, -0.3125f}, + {0b1100, -0.2500f}, + {0b1101, -0.1875f}, + {0b1110, -0.1250f}, + {0b1111, -0.0625f}, + {0b0, +0.0000f}, + {0b1, +0.0625f}, + {0b10, +0.1250f}, + {0b11, +0.1875f}, + {0b100, +0.2500f}, + {0b101, +0.3125f}, + {0b110, +0.3750f}, + {0b111, +0.4375f}}; + + return u[i4]; + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp 
b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp index 5c932fcb18..583d704040 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp @@ -156,9 +156,14 @@ struct ReferenceMoeGemm2 : public device::BaseOperator } }; - const ck::index_t max_token_id = arg.max_token_id_(0); - make_ParallelTensorFunctor(f_mk_kn_mn, max_token_id, arg.c_t_n_.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); + const std::size_t max_token_id = arg.max_token_id_(0); + // avoid parallelizing over the m dim to prevent data race + make_ParallelTensorFunctor( + [&](auto n) { + for(std::size_t m = 0; m < max_token_id; ++m) + f_mk_kn_mn(m, n); + }, + arg.c_t_n_.mDesc.GetLengths()[1])(std::thread::hardware_concurrency()); return 0; } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2_blockscale.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2_blockscale.hpp new file mode 100644 index 0000000000..a10ef88557 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2_blockscale.hpp @@ -0,0 +1,248 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceMoeGemm2BlockScale : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& sorted_token_ids, + const Tensor& expert_ids, + const Tensor& max_token_id, + const index_t sorted_tile_size, + const Tensor& a_t_k_k, + const Tensor& b_e_n_k, + const Tensor& d2, + Tensor& c_t_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : sorted_token_ids_{sorted_token_ids}, + expert_ids_{expert_ids}, + max_token_id_{max_token_id}, + sorted_tile_size_{sorted_tile_size}, + a_t_k_k_{a_t_k_k}, + b_e_n_k_{b_e_n_k}, + d2_{d2}, + c_t_n_{c_t_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& sorted_token_ids_; + const Tensor& expert_ids_; + const Tensor& max_token_id_; + index_t sorted_tile_size_; + const Tensor& a_t_k_k_; + const Tensor& b_e_n_k_; + const Tensor& d2_; + Tensor& c_t_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceMoeGemm2BlockScale::Argument; + + float Run(const Argument& arg) + { + arg.c_t_n_.SetZero(); + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_t_k_k_.mDesc.GetLengths()[2]; + AccDataType v_acc{0}; + ComputeTypeA v_a{0}; + ComputeTypeB v_b{0}; + const int t = arg.sorted_token_ids_(m) & 0xffffff; + const int topk_id = arg.sorted_token_ids_(m) >> 24; + const int e = arg.expert_ids_(m / arg.sorted_tile_size_); + const int token_cnt = arg.c_t_n_.mDesc.GetLengths()[0]; + AccDataType v_topk_w = arg.d2_(m, 0); // expert + + if(t < token_cnt) + { + for(int k = 0; k < K; ++k) + { + if constexpr(is_same_v) + { + uint8_t i4x2 = arg.a_t_k_(t, 
topk_id, k).data; + uint8_t i4 = 0; + if(k % 2 == 1) + i4 = (i4x2 >> 0) & 0xf; + else + i4 = (i4x2 >> 4) & 0xf; +#if CK_USE_PK4_LAYOUT_SHUFFLE + v_a = i4_to_f32_gfx9(i4); +#else + v_a = i4 - 8; +#endif + } + else + { + arg.a_element_op_(v_a, arg.a_t_k_k_(t, topk_id, k)); + } + if constexpr(is_same_v) + { + uint8_t i4x2 = arg.b_e_n_k_(e, k, n).data; + uint8_t i4 = 0; + if(k % 2 == 1) + i4 = (i4x2 >> 0) & 0xf; + else + i4 = (i4x2 >> 4) & 0xf; +#if CK_USE_PK4_LAYOUT_SHUFFLE + v_b = i4_to_f32_gfx9(i4); +#else + v_b = i4 - 8; +#endif + } + else + { + arg.b_element_op_(v_b, arg.b_e_n_k_(e, k, n)); + } + + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); + } + CDataType v_c{0}; + if constexpr(MulRoutedWeight) + { + arg.c_element_op_(v_c, v_acc, v_topk_w); + } + else + { + arg.c_element_op_(v_c, v_acc, 1.f); + } + arg.c_t_n_(t, n) += v_c; + } + }; + + const std::size_t max_token_id = arg.max_token_id_(0); + // avoid parallelizing over the m dim to prevent data race + make_ParallelTensorFunctor( + [&](auto n) { + for(std::size_t m = 0; m < max_token_id; ++m) + f_mk_kn_mn(m, n); + }, + arg.c_t_n_.mDesc.GetLengths()[1])(std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& sorted_token_ids, + const Tensor& expert_ids, + const Tensor& max_token_id, + const index_t sorted_tile_size, + const Tensor& a_t_k_k, + const Tensor& b_e_n_k, + const Tensor& d2, + Tensor& c_t_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{sorted_token_ids, + expert_ids, + max_token_id, + sorted_tile_size, + a_t_k_k, + b_e_n_k, + d2, + c_t_n, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceMoeGemm2" + << std::endl; + // clang-format on + + return str.str(); + } + +#if CK_USE_PK4_LAYOUT_SHUFFLE + static float i4_to_f32_gfx9(uint8_t i4) + { + static std::unordered_map u = {{0b1000, -0.5000f}, + {0b1001, -0.4375f}, + {0b1010, -0.3750f}, + {0b1011, -0.3125f}, + {0b1100, -0.2500f}, + {0b1101, -0.1875f}, + {0b1110, -0.1250f}, + {0b1111, -0.0625f}, + {0b0, +0.0000f}, + {0b1, +0.0625f}, + {0b10, +0.1250f}, + {0b11, +0.1875f}, + {0b100, +0.2500f}, + {0b101, +0.3125f}, + {0b110, +0.3750f}, + { 0b111, + +0.4375f }}; + + return u[i4]; + } +#endif +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp new file mode 100644 index 0000000000..4dd331bc19 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp @@ -0,0 +1,264 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
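+// Host reference for the first MoE GEMM (gate/up projection) with MX block
+// scales: each sorted row decodes a token index (low 24 bits) and top-k slot
+// (high 8 bits), dequantizes the f4 A/B values with their per-SCALE_BLOCK
+// scales, accumulates the gate and up dot products (the up weights sit at
+// column offset full_n), optionally multiplies in the routed weight, applies
+// Silu (ActivationType == 1) or Gelu (ActivationType == 0) to the gate and
+// stores gate * up into c_t_k_n(token, topk, n).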
+ +#pragma once + +#include +#include +#include + +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceMoeMXGemm1 : public device::BaseOperator +{ + // Argument + static constexpr auto ActivationType = ActivationType_; + struct Argument : public device::BaseArgument + { + Argument(const Tensor& sorted_token_ids, + const Tensor& expert_ids, + const Tensor& max_token_id, + const index_t sorted_tile_size, + const Tensor& a_t_k, + const Tensor& a_t_k_scale, + const Tensor& b_e_n_k, + const Tensor& b_e_n_k_scale, + const Tensor& d2, + Tensor& c_t_k_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : sorted_token_ids_{sorted_token_ids}, + expert_ids_{expert_ids}, + max_token_id_{max_token_id}, + sorted_tile_size_{sorted_tile_size}, + a_t_k_{a_t_k}, + a_t_k_scale_{a_t_k_scale}, + b_e_n_k_{b_e_n_k}, + b_e_n_k_scale_{b_e_n_k_scale}, + d2_{d2}, + c_t_k_n_{c_t_k_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& sorted_token_ids_; + const Tensor& expert_ids_; + const Tensor& max_token_id_; + index_t sorted_tile_size_; + const Tensor& a_t_k_; + const Tensor& a_t_k_scale_; + const Tensor& b_e_n_k_; + const Tensor& b_e_n_k_scale_; + const Tensor& d2_; + Tensor& c_t_k_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceMoeMXGemm1::Argument; + + float Run(const Argument& arg) + { + static_assert(ActivationType < 2, "Not supported activation type"); + const int full_n = arg.c_t_k_n_.mDesc.GetLengths()[2]; + arg.c_t_k_n_.SetZero(); + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_t_k_.mDesc.GetLengths()[1]; + const ck::index_t SCALE_BLOCK = K / arg.b_e_n_k_scale_.mDesc.GetLengths()[1]; + AccDataType v_acc{0}; + AccDataType v_acc_up{0}; + ComputeTypeA v_a{0}; + ComputeTypeB v_b{0}; + ComputeTypeB v_b_up{0}; + const int t = arg.sorted_token_ids_(m) & 0xffffff; + const int topk_id = arg.sorted_token_ids_(m) >> 24; + const int e = arg.expert_ids_(m / arg.sorted_tile_size_); + const int token_cnt = arg.c_t_k_n_.mDesc.GetLengths()[0]; + D0DataType v_topk_w = arg.d2_(m, 0); // expert + + if(t < token_cnt) + { + for(int k = 0; k < K; ++k) + { + auto a_f4x2 = arg.a_t_k_(t, k).data; + auto a_scale = arg.a_t_k_scale_(t, k / SCALE_BLOCK); + if constexpr(is_same_v) + { + + f4_t f4 = 0; + if(k % 2 == 1) + f4 = (a_f4x2 >> 0) & 0xf; + else + f4 = (a_f4x2 >> 4) & 0xf; + v_a = type_convert(f4) * + type_convert(a_scale); + } + else + { + v_a = type_convert(a_f4x2) * + type_convert(a_scale); + arg.a_element_op_(v_a, v_a); + } + auto b_f4x2 = arg.b_e_n_k_(e, k, n).data; + auto b_f4x2_up = arg.b_e_n_k_(e, k, n + full_n).data; + auto b_scale = arg.b_e_n_k_scale_(e, k / SCALE_BLOCK, n); + auto b_scale_up = arg.b_e_n_k_scale_(e, k / SCALE_BLOCK, n + full_n); + if constexpr(is_same_v) + { + + f4_t f4 = 0; + f4_t f4_up = 0; + if(k % 2 == 1) + { + f4 = (b_f4x2 >> 0) & 0xf; + f4_up = (b_f4x2_up >> 0) & 0xf; + } + else + { + f4 = (b_f4x2 >> 4) & 0xf; + f4_up = (b_f4x2_up >> 4) & 0xf; + } + v_b = type_convert(f4) * + type_convert(b_scale); + v_b_up = type_convert(f4_up) * + type_convert(b_scale_up); 
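+                            // Each block scale covers SCALE_BLOCK consecutive K elements;
+                            // converting an e8m0 scale to float yields a power of two
+                            // (a stored byte e represents 2^(e - 127), e.g. 127 -> 1.0 and
+                            // 130 -> 8.0, with 0xFF reserved for NaN) per the OCP MX
+                            // convention followed by the e8m0 types used here.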
+ } + else + { + v_b = type_convert(b_f4x2) * + type_convert(b_scale); + v_b_up = type_convert(b_f4x2_up) * + type_convert(b_scale_up); + arg.b_element_op_(v_b, v_b); + arg.b_element_op_(v_b_up, v_b_up); + } + + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); + v_acc_up += ck::type_convert(v_a) * + ck::type_convert(v_b_up); + } + CDataType v_c{0}; + CDataType v_c_up{0}; + if constexpr(MulRoutedWeight) + { + v_acc *= v_topk_w; + v_acc_up *= v_topk_w; + } + arg.c_element_op_(v_c, v_acc); + arg.c_element_op_(v_c_up, v_acc_up); + if constexpr(ActivationType == 1) + { + tensor_operation::element_wise::Silu{}(v_c, v_c); + arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up; + } + else if constexpr(ActivationType == 0) + { + tensor_operation::element_wise::Gelu{}(v_c, v_c); + arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up; + } + } + }; + + const ck::index_t max_token_id = arg.max_token_id_(0); + make_ParallelTensorFunctor(f_mk_kn_mn, max_token_id, full_n)( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& sorted_token_ids, + const Tensor& expert_ids, + const Tensor& max_token_id, + const index_t sorted_tile_size, + const Tensor& a_t_k, + const Tensor& a_t_k_scale, + const Tensor& b_e_n_k, + const Tensor& b_e_n_k_scale, + const Tensor& d2, + Tensor& c_t_k_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{sorted_token_ids, + expert_ids, + max_token_id, + sorted_tile_size, + a_t_k, + a_t_k_scale, + b_e_n_k, + b_e_n_k_scale, + d2, + c_t_k_n, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceMoeMxGemm1" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp new file mode 100644 index 0000000000..74f25f0f91 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp @@ -0,0 +1,238 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
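+// Host reference for the second MoE GEMM (down projection) with MX block
+// scales: A is indexed as (token, topk, k) and dequantized together with the
+// expert's B values using per-SCALE_BLOCK scales, and the per-(token, n)
+// partial results of all top-k experts are accumulated into c_t_n. The
+// invoker only parallelizes over n and walks m serially so the += into the
+// same output row cannot race.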
+ +#pragma once + +#include +#include +#include + +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceMoeMXGemm2 : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& sorted_token_ids, + const Tensor& expert_ids, + const Tensor& max_token_id, + const index_t sorted_tile_size, + const Tensor& a_t_k_k, + const Tensor& a_t_k_k_scale, + const Tensor& b_e_n_k, + const Tensor& b_e_n_k_scale, + const Tensor& d2, + Tensor& c_t_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : sorted_token_ids_{sorted_token_ids}, + expert_ids_{expert_ids}, + max_token_id_{max_token_id}, + sorted_tile_size_{sorted_tile_size}, + a_t_k_k_{a_t_k_k}, + a_t_k_k_scale_{a_t_k_k_scale}, + b_e_n_k_{b_e_n_k}, + b_e_n_k_scale_{b_e_n_k_scale}, + d2_{d2}, + c_t_n_{c_t_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& sorted_token_ids_; + const Tensor& expert_ids_; + const Tensor& max_token_id_; + index_t sorted_tile_size_; + const Tensor& a_t_k_k_; + const Tensor& a_t_k_k_scale_; + const Tensor& b_e_n_k_; + const Tensor& b_e_n_k_scale_; + const Tensor& d2_; + Tensor& c_t_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceMoeMXGemm2::Argument; + + float Run(const Argument& arg) + { + arg.c_t_n_.SetZero(); + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_t_k_k_.mDesc.GetLengths()[2]; + const ck::index_t SCALE_BLOCK = K / arg.b_e_n_k_scale_.mDesc.GetLengths()[1]; + AccDataType v_acc{0}; + ComputeTypeA v_a{0}; + ComputeTypeB v_b{0}; + const int t = arg.sorted_token_ids_(m) & 0xffffff; + const int topk_id = arg.sorted_token_ids_(m) >> 24; + const int e = arg.expert_ids_(m / arg.sorted_tile_size_); + const int token_cnt = arg.c_t_n_.mDesc.GetLengths()[0]; + D0DataType v_topk_w = arg.d2_(m, 0); // expert + + if(t < token_cnt) + { + for(int k = 0; k < K; ++k) + { + if constexpr(is_same_v) + { + auto f4x2 = arg.a_t_k_k_(t, topk_id, k).data; + auto a_scale = arg.a_t_k_k_scale_(t, topk_id, k / SCALE_BLOCK); + + f4_t f4 = 0; + if(k % 2 == 1) + f4 = (f4x2 >> 0) & 0xf; + else + f4 = (f4x2 >> 4) & 0xf; + + v_a = type_convert(f4) * + type_convert(a_scale); + } + else + { + arg.a_element_op_( + v_a, type_convert(arg.a_t_k_k_(t, topk_id, k))); + } + if constexpr(is_same_v) + { + auto f4x2 = arg.b_e_n_k_(e, k, n).data; + auto b_scale = arg.b_e_n_k_scale_(e, k / SCALE_BLOCK, n); + + f4_t f4 = 0; + if(k % 2 == 1) + f4 = (f4x2 >> 0) & 0xf; + else + f4 = (f4x2 >> 4) & 0xf; + + v_b = type_convert(f4) * + type_convert(b_scale); + } + else + { + arg.b_element_op_(v_b, + type_convert(arg.b_e_n_k_(e, k, n))); + } + + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); + } + CDataType v_c{0}; + if constexpr(MulRoutedWeight) + { + arg.c_element_op_(v_c, v_acc, 1.f, 1.f, v_topk_w); // hacky, need to fix + } + else + { + arg.c_element_op_(v_c, v_acc, 1.f, 1.f, 1.f); + } + arg.c_t_n_(t, n) += v_c; + } + }; + + const std::size_t max_token_id = arg.max_token_id_(0); + // avoid parallelizing over the m dim to prevent data race + 
make_ParallelTensorFunctor( + [&](auto n) { + for(std::size_t m = 0; m < max_token_id; ++m) + f_mk_kn_mn(m, n); + }, + arg.c_t_n_.mDesc.GetLengths()[1])(std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& sorted_token_ids, + const Tensor& expert_ids, + const Tensor& max_token_id, + const index_t sorted_tile_size, + const Tensor& a_t_k_k, + const Tensor& a_t_k_k_scale, + const Tensor& b_e_n_k, + const Tensor& b_e_n_k_scale, + const Tensor& d2, + Tensor& c_t_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{sorted_token_ids, + expert_ids, + max_token_id, + sorted_tile_size, + a_t_k_k, + a_t_k_k_scale, + b_e_n_k, + b_e_n_k_scale, + d2, + c_t_n, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceMoeGemm2" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp b/library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp index a20e608868..a74401ff76 100644 --- a/library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -22,11 +22,16 @@ void add_device_operation_instances(std::vector>& op_ins const auto new_op_instance = std::get(new_op_instances); using NewOpInstance = remove_cvref_t; - - static_assert(std::is_base_of_v, - "wrong! NewOpInstance should be derived from BaseOp"); - - op_instances.push_back(std::make_unique(new_op_instance)); + if constexpr(std::is_same_v) + { + return; // We can use nullptr_t to enable trailing comma + } + else + { + static_assert(std::is_base_of_v, + "wrong! NewOpInstance should be derived from BaseOp"); + op_instances.push_back(std::make_unique(new_op_instance)); + } }); } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_blockscale_wp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_blockscale_wp.hpp new file mode 100644 index 0000000000..ae496e01d3 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_blockscale_wp.hpp @@ -0,0 +1,172 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
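+// Usage sketch (illustrative): the factory specialization in this header hands back
+// type-erased pointers to the ScaleBlockM/N/K = 1/128/128 f8 x f8 -> bf16
+// weight-preshuffled GEMM instances; the blockscale-wp profiler added later in this
+// patch consumes them with exactly this pattern.
+//
+//   using DeviceOp = DeviceGemmMultipleD_BlockScale_BPreshuffle</* layouts, data types,
+//                                                                  1, 128, 128, element ops */>;
+//   const auto op_ptrs = DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
+//   for(auto& op_ptr : op_ptrs)
+//   {
+//       auto argument_ptr = op_ptr->MakeArgumentPointer(/* device buffers, M, N, K, strides,
+//                                                          scale pointers, element ops */);
+//       if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+//           op_ptr->MakeInvokerPointer()->Run(argument_ptr.get(), StreamConfig{});
+//   }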
+ +#pragma once + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) +void add_device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& + instances); + +void add_device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& + instances); + +void add_device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& + instances); + +void add_device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& + instances); +#endif + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleD_BlockScale_BPreshuffle< + ALayout, + BLayout, + Tuple<>, + CLayout, + A0DataType, + A1DataType, + B0DataType, + B1DataType, + Tuple<>, + CDataType, + 1, + 128, + 128, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough>> +{ + using DeviceOp = + DeviceGemmMultipleD_BlockScale_BPreshuffle, + CLayout, + A0DataType, + A1DataType, + B0DataType, + B1DataType, + Tuple<>, + CDataType, + 1, + 128, + 128, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough>; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( + op_ptrs); + // add_device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances( + // op_ptrs); + + add_device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances( + op_ptrs); + // add_device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances( + // op_ptrs); + } + } +#endif + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt new file mode 100644 index 0000000000..57cbd725aa --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt @@ -0,0 +1,16 @@ +# ONLY XDL_KERNELS +set(GEMM_BLOCKSCALE_WP_INSTANCES) + +list(APPEND GEMM_BLOCKSCALE_WP_INSTANCES + 
device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp + device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp + device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp + device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp + ) + +set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--slp-threshold=-32;-mllvm;--misched-bottomup=1") +set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--slp-threshold=-32;-mllvm;--misched-bottomup=1") +set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--slp-threshold=-32;-mllvm;--misched-bottomup=1") +set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--slp-threshold=-32;-mllvm;--misched-bottomup=1") + +add_instance_library(device_gemm_blockscale_wp_instance ${GEMM_BLOCKSCALE_WP_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp new file mode 100644 index 0000000000..68bc25dbfb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
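+// Reading the instance tables in this header (worked example against the column header of
+// each table): the first "compute friendly" entry is a 256-thread block with
+// ScaleBlockM/N/K = 1/128/128, a 128x128x128 MPerBlock x NPerBlock x KPerBlock tile,
+// AK1 = BK1 = 16, 16x16 XDL instructions laid out as 8x2 Xdl-per-wave, CShuffle 2x1, and
+// the Intrawave v3 block-GEMM pipeline with F8 compute. The "comp" tuple targets large,
+// compute-bound tiles; the "mem" tuple covers small-M, memory-bound shapes
+// (MPerBlock 16/32/48/64) with the v1 pipeline.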
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +template +using device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances = std::tuple< + // clang-format off + //################################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //################################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //################################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Compute friendly + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 
16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; + +template +using device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances = std::tuple< + // clang-format off + //#######################################################| ALayout| BLayout| DsLayout| ELayout|AData | BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#######################################################| | | | | Type | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#######################################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#######################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Memory friendly + // 16x + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 256, 128, 8, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 128, 8, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 64, 128, 8, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + 
DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 64, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + //32x + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 256, 128, 16, 16, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 128, 16, 16, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 256, 16, 16, 16, 16, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 256, 16, 16, 16, 16, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + //48x + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 48, 256, 128, 8, 16, 16, 16, 3, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 48, 128, 128, 8, 16, 16, 16, 3, 2, S<16, 16, 1>, S<1, 0, 2>, 
S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 48, 64, 128, 8, 16, 16, 16, 3, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 48, 128, 256, 16, 16, 16, 16, 3, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 48, 64, 256, 16, 16, 16, 16, 3, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + //64x + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 256, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 256, 16, 16, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 256, 16, 16, 16, 16, 4, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} 
// namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp new file mode 100644 index 0000000000..d745724c35 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp new file mode 100644 index 0000000000..a2e6c4a43c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp new file mode 100644 index 0000000000..91434863fe --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp new file mode 100644 index 0000000000..cc9a734659 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn.hpp index 03ea71883a..40bacb3ee9 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn.hpp @@ -46,25 +46,26 @@ using device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_instances = std::tuple< //#####################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#####################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 4, S<1, 8, 1, 32>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 
16, true, 2, 2, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 4, S<1, 8, 1, 32>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 4, S<1, 8, 1, 32>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 4, S<1, 8, 1, 32>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 4, S<1, 8, 1, 32>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 4, S<1, 8, 1, 32>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 
16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3> + DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 4, S<1, 8, 1, 32>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 4, S<1, 8, 1, 32>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + std::nullptr_t // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn.hpp index 1ebb400fdd..2b4c18787a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn.hpp @@ -56,7 +56,8 @@ using device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_instances = std::tuple< DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 256, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, 128, 16, 16, 16, 16, 2, 2, S<8, 8, 1>, 
S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3> + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, 128, 16, 16, 16, 16, 2, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + std::nullptr_t // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp index 3645026c60..aa4704530d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp @@ -49,7 +49,8 @@ using device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_instances = std::tuple< DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 64, 256, 16, 16, 16, 16, 4, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, 256, 16, 16, 16, 16, 2, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 32, 256, 16, 16, 16, 16, 4, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, 256, 16, 16, 16, 16, 2, 2, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3> + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, 256, 16, 16, 16, 16, 2, 2, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + std::nullptr_t // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp index f7ef5562e4..1371e419ea 100644 --- 
a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp @@ -49,7 +49,8 @@ using device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_instances = std::tuple< DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 64, 256, 16, 16, 16, 16, 4, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, 256, 16, 16, 16, 16, 2, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 32, 256, 16, 16, 16, 16, 4, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, 256, 16, 16, 16, 16, 2, 2, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3> + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, 256, 16, 16, 16, 16, 2, 2, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + std::nullptr_t // clang-format on >; diff --git a/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp new file mode 100644 index 0000000000..53073a6c75 --- /dev/null +++ b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp @@ -0,0 +1,415 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
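+// B-preshuffle layout (worked example for the preShuffleBuffer() helper defined below;
+// the numbers assume NXdl = 16 and K = 128):
+//   KPack = 16, NLane = NXdl = 16, KLane = 64 / NLane = 4, K0 = K / (KLane * KPack) = 2
+//   element (n, k) = (17, 5):
+//     n0 = 17 / 16 = 1,  n1 = 17 % 16 = 1
+//     k0 = 5 / 64  = 0,  k1 = (5 % 64) / 16 = 0,  k2 = 5 % 16 = 5
+//     outputIndex = n0*KPack*NLane*KLane*K0 + k0*KPack*NLane*KLane + k1*KPack*NLane + n1*KPack + k2
+//                 = 1*2048 + 0 + 0 + 16 + 5 = 2069
+// i.e. row-major (N, K) data is regrouped as N0 x K0 x KLane x NLane x KPack, matching the
+// "N, K -> N0 K0 KLane NLane KPack" comment inside the helper.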
+ +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_blockscale_wp.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +void preShuffleBuffer(const InOutDataType* src, InOutDataType* dst, int N, int K, int NXdl) +{ + int KPack = 16; + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K + k]; + } + } +} + +template +bool profile_gemm_blockscale_weighpreshuffle_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideE, + int n_warmup, + int n_iter, + uint64_t rotating = 0) +{ + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + ck::index_t Scale_Stride_AM = ((M + ScaleBlockM - 1) / ScaleBlockM); + ck::index_t Scale_Stride_BN = ck::is_same_v + ? 
((K + ScaleBlockK - 1) / ScaleBlockK) + : ((N + ScaleBlockN - 1) / ScaleBlockN); + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor a1_m_k(f_host_tensor_descriptor((M + ScaleBlockM - 1) / ScaleBlockM, + (K + ScaleBlockK - 1) / ScaleBlockK, + Scale_Stride_AM, + ck::tensor_layout::gemm::ColumnMajor{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_preshuffled_mfma16( + f_host_tensor_descriptor(K, N, StrideB, BLayout{})); // use layout only for size + Tensor b_preshuffled_mfma32( + f_host_tensor_descriptor(K, N, StrideB, BLayout{})); // use layout only for size + Tensor b1_k_n(f_host_tensor_descriptor((K + ScaleBlockK - 1) / ScaleBlockK, + (N + ScaleBlockN - 1) / ScaleBlockN, + Scale_Stride_BN, + BLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + int total_gemm_needed = + a0_m_k.GetElementSpaceSizeInBytes() + b0_k_n.GetElementSpaceSizeInBytes() + + a1_m_k.GetElementSpaceSizeInBytes() + b1_k_n.GetElementSpaceSizeInBytes(); + int rotating_count = std::max( + 1, + std::min(n_iter, + static_cast(std::ceil(static_cast(rotating) / total_gemm_needed)))); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; + std::cout << "rotating count: " << rotating_count << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + } + + preShuffleBuffer(b0_k_n.mData.data(), b_preshuffled_mfma16.mData.data(), N, K, 16); + preShuffleBuffer(b0_k_n.mData.data(), b_preshuffled_mfma32.mData.data(), N, K, 32); + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CElementOp = PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf_mfma16(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf_mfma32(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(A1DataType) * a1_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + b_device_buf_mfma16.ToDevice(b_preshuffled_mfma16.mData.data()); + b_device_buf_mfma32.ToDevice(b_preshuffled_mfma32.mData.data()); + a1_device_buf.ToDevice(a1_m_k.mData.data()); + b1_device_buf.ToDevice(b1_k_n.mData.data()); + + using DeviceOp = + ck::tensor_operation::device::DeviceGemmMultipleD_BlockScale_BPreshuffle, + ELayout, + 
A0DataType, + A1DataType, + B0DataType, + B1DataType, + ck::Tuple<>, + EDataType, + ScaleBlockM, + ScaleBlockN, + ScaleBlockK, + AElementOp, + BElementOp, + CElementOp>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // Run reference GEMM + if(do_verification) + { + Tensor c_m_n({M, N}); + Tensor a_m_k({M, K}); + Tensor b_k_n({K, N}); + + for(int m = 0; m < M; m++) + { + for(int k = 0; k < K; k++) + { + a_m_k(m, k) = ck::type_convert(a0_m_k(m, k)) * + a1_m_k(m / ScaleBlockM, k / ScaleBlockK); + } + } + + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + b_k_n(k, n) = ck::type_convert(b0_k_n(k, n)) * + b1_k_n(k / ScaleBlockK, n / ScaleBlockN); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + e_m_n_host_result(m, n) = ck::type_convert(c_m_n(m, n)); + } + } + } + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& op_ptr : op_ptrs) + { + int NPerXdl = op_ptr->GetPreShuffleParameters(); + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(a0_device_buf.GetDeviceBuffer()), + static_cast(NPerXdl == 16 ? b_device_buf_mfma16.GetDeviceBuffer() + : b_device_buf_mfma32.GetDeviceBuffer()), + std::array{}, + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + std::array{}, + StrideE, + a1_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + + // re-init C to zero before profiling next kernel + c_device_buf.SetZero(); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false, 0, n_warmup, n_iter}); + + if(do_verification) + { + c_device_buf.FromDevice(e_m_n_device_result.mData.data()); + +#if defined CK_ENABLE_FP8 + // set softer tolerances for fp8 + if constexpr(is_same_v || is_same_v || + is_same_v) + { + std::string msg = "Error: Incorrect results!"; + double rtol = 5e-2; + double atol = 5e-2; + bool current_pass = ck::utils::check_err( + e_m_n_device_result, e_m_n_host_result, msg, rtol, atol); + pass = pass & current_pass; + if(!current_pass) + { + std::cout << op_ptr->GetTypeString() << " failed" << std::endl; + } + } + else + { +#endif + pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); + if(!pass) + { + std::cout << op_ptr->GetTypeString() << " failed" << std::endl; + } +#if defined CK_ENABLE_FP8 + } +#endif + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a0_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b0_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host : ", e_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", e_m_n_device_result.mData, ",") + << std::endl; + } + } + + std::string op_name = op_ptr->GetTypeString(); + + float ave_time = invoker_ptr->Run( + argument_ptr.get(), + StreamConfig{ + nullptr, time_kernel, 
0, n_warmup, n_iter, rotating_count > 1, rotating_count}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + } + + if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f32"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = bf16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = int8"; + } + + if constexpr(is_same::value) + { + std::cout << " ALayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " ALayout = ColumnMajor"; + } + + if constexpr(is_same::value) + { + std::cout << " BLayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " BLayout = ColumnMajor"; + } + + std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA + << " StrideB = " << StrideB << " StrideE = " << StrideE << " : " << best_ave_time + << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_mx_impl.hpp b/profiler/include/profiler/profile_gemm_mx_impl.hpp index 8135bf4475..4df2348700 100644 --- a/profiler/include/profiler/profile_gemm_mx_impl.hpp +++ b/profiler/include/profiler/profile_gemm_mx_impl.hpp @@ -226,6 +226,8 @@ bool profile_gemm_mx_impl(int do_verification, return ck::type_convert(x); }; + using int_distr = std::uniform_int_distribution; + using float_distr = std::uniform_real_distribution; switch(init_method) { case 0: // Initializations for development and debugging @@ -245,21 +247,19 @@ bool profile_gemm_mx_impl(int do_verification, case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-4, 5}); // Z[-4,4] - b_k_n->GenerateTensorValue(GeneratorTensor_2{-4, 5}); // Z[-4,4] + a_m_k.GenerateTensorDistr(int_distr{-4, 5}); // Z[-4,4] + b_k_n->GenerateTensorDistr(int_distr{-4, 5}); // Z[-4,4] - a_m_k_scale.GenerateTensorValue( - GeneratorTensor_2{125, 129}); // scales: {0.25, 0.5, 1, 2} - b_k_n_scale.GenerateTensorValue( - GeneratorTensor_2{125, 129}); // scales: {0.25, 0.5, 1, 2} + a_m_k_scale.GenerateTensorDistr(int_distr{125, 129}); // scales: {0.25, 0.5, 1, 2} + b_k_n_scale.GenerateTensorDistr(int_distr{125, 129}); // scales: {0.25, 0.5, 1, 2} break; default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); - a_m_k_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); + a_m_k.GenerateTensorDistr(float_distr{-2.0, 2.0}); + a_m_k_scale.GenerateTensorDistr(float_distr{powf(2.0f, -125.0f), 1.0f}); - b_k_n->GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); - b_k_n_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); + b_k_n->GenerateTensorDistr(float_distr{-2.0, 2.0}); + b_k_n_scale.GenerateTensorDistr(float_distr{powf(2.0f, 
 
         break;
     }
diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt
index 2cfb5581ea..fef09315d5 100644
--- a/profiler/src/CMakeLists.txt
+++ b/profiler/src/CMakeLists.txt
@@ -62,6 +62,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
     list(APPEND PROFILER_OPS profile_gemm_multiply_multiply.cpp)
     list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp)
     list(APPEND PROFILER_OPS profile_gemm_ab_scale.cpp)
+    list(APPEND PROFILER_OPS profile_gemm_blockscale_wp.cpp)
 endif()
 if(SUPPORTED_GPU_TARGETS MATCHES "gfx95")
     list(APPEND PROFILER_OPS profile_gemm_mx.cpp)
@@ -170,6 +171,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
     list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_instance)
     list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance)
     list(APPEND DEVICE_INSTANCES device_gemm_ab_scale_instance)
+    list(APPEND DEVICE_INSTANCES device_gemm_blockscale_wp_instance)
 endif()
 if(SUPPORTED_GPU_TARGETS MATCHES "gfx95")
     list(APPEND DEVICE_INSTANCES device_gemm_mx_instance)
diff --git a/profiler/src/profile_gemm_blockscale_wp.cpp b/profiler/src/profile_gemm_blockscale_wp.cpp
new file mode 100644
index 0000000000..e6a2fbb8f6
--- /dev/null
+++ b/profiler/src/profile_gemm_blockscale_wp.cpp
@@ -0,0 +1,184 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+#include <initializer_list>
+#include <iostream>
+#include <numeric>
+
+#include "profiler/profile_gemm_blockscale_wp_impl.hpp"
+#include "profiler_operation_registry.hpp"
+
+enum struct GemmMatrixLayout
+{
+    MK_KN_MN, // 0
+    MK_NK_MN, // 1
+    KM_KN_MN, // 2
+    KM_NK_MN, // 3
+};
+
+enum struct GemmDataType
+{
+    F32_F32_F32,    // 0
+    F16_F16_F16,    // 1
+    BF16_BF16_BF16, // 2
+    INT8_INT8_INT8, // 3
+    F8_F16_F16,     // 4
+    F16_F8_F16,     // 5
+    F16_F16_F16_F8, // 6
+    F8_F8_BF16,     // 7
+};
+
+enum struct ScaleBlockTile
+{
+    Tile_128_128_128, // 0
+    Tile_1_128_128,   // 1
+};
+
+#define OP_NAME "gemm_blockscale_wp"
+#define OP_DESC "GEMM_BlockScale_WeightPreshuffle"
+
+int profile_gemm_blockscale_weighpreshuffle(int argc, char* argv[])
+{
+    if(argc != 15 && argc != 18)
+    {
+        printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
+        printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; "
+               "6: f16->f8; 7: f8->bf16, comp f8)\n");
+        printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
+        printf("                     1: A[m, k] * B[n, k] = C[m, n];\n");
+        printf("                     2: A[k, m] * B[k, n] = C[m, n];\n");
+        printf("                     3: A[k, m] * B[n, k] = C[m, n])\n");
+        printf("arg4: scale block tile (0: ScaleBlockM/N/K = [128, 128, 128]; "
+               "1: ScaleBlockM/N/K = [1, 128, 128])\n");
+        printf("arg5: verification (0: no; 1: yes)\n");
+        printf("arg6: initialization (0: no init; 1: integer value; 2: decimal value)\n");
+        printf("arg7: print tensor value (0: no; 1: yes)\n");
+        printf("arg8: time kernel (0: no; 1: yes)\n");
+        printf("arg9 to 14: M, N, K, StrideA, StrideB, StrideE\n");
+        printf("optional:\n");
+        printf("arg15: number of warm-up cycles (default 1)\n");
+        printf("arg16: number of iterations (default 10)\n");
+        printf("arg17: memory for rotating buffer (default 0, size in MB)\n");
+        exit(1);
+    }
+
+    const auto data_type        = static_cast<GemmDataType>(std::stoi(argv[2]));
+    const auto layout           = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
+    const auto scale_block_tile = static_cast<ScaleBlockTile>(std::stoi(argv[4]));
+    const bool do_verification  = std::stoi(argv[5]);
+    const int init_method       = std::stoi(argv[6]);
+    const bool do_log           = std::stoi(argv[7]);
+    const bool time_kernel      = std::stoi(argv[8]);
+
+    const int M = std::stoi(argv[9]);
+    const int N = std::stoi(argv[10]);
+    const int K = std::stoi(argv[11]);
+
+    const int StrideA = std::stoi(argv[12]);
+    const int StrideB = std::stoi(argv[13]);
+    const int StrideE = std::stoi(argv[14]);
+
+    int n_warmup      = 1;
+    int n_iter        = 10;
+    uint64_t rotating = 0;
+    if(argc == 18)
+    {
+        n_warmup = std::stoi(argv[15]);
+        n_iter   = std::stoi(argv[16]);
+        rotating = std::stoull(argv[17]) * 1024 * 1024;
+    }
+
+    using F32  = float;
+    using BF16 = ck::bhalf_t;
+    using F8   = ck::f8_t;
+
+    using Row = ck::tensor_layout::gemm::RowMajor;
+    using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+    auto profile = [&](auto a0_type,
+                       auto a1_type,
+                       auto b0_type,
+                       auto b1_type,
+                       auto comp_type,
+                       auto acc_type,
+                       auto c_type,
+                       auto scale_block_m,
+                       auto scale_block_n,
+                       auto scale_block_k,
+                       auto a_layout,
+                       auto b_layout,
+                       auto e_layout) {
+        using A0DataType      = decltype(a0_type);
+        using A1DataType      = decltype(a1_type);
+        using B0DataType      = decltype(b0_type);
+        using B1DataType      = decltype(b1_type);
+        using ComputeDataType = decltype(comp_type);
+        using AccDataType     = decltype(acc_type);
+        using EDataType       = decltype(c_type);
+
+        using ALayout = decltype(a_layout);
+        using BLayout = decltype(b_layout);
+        using ELayout = decltype(e_layout);
+
+        const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
+        const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
+        const int DefaultStrideE = ck::is_same_v<ELayout, Row> ? N : M;
+
+        bool pass = ck::profiler::profile_gemm_blockscale_weighpreshuffle_impl(
+            do_verification,
+            init_method,
+            do_log,
+            time_kernel,
+            M,
+            N,
+            K,
+            (StrideA < 0) ? DefaultStrideA : StrideA,
+            (StrideB < 0) ? DefaultStrideB : StrideB,
+            (StrideE < 0) ? DefaultStrideE : StrideE,
+            n_warmup,
+            n_iter,
+            rotating);
+
+        return pass ? 0 : 1;
+    };
+
+    if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_NK_MN &&
+       scale_block_tile == ScaleBlockTile::Tile_1_128_128)
+    {
+        return profile(F8{},
+                       F32{},
+                       F8{},
+                       F32{},
+                       F8{},
+                       F32{},
+                       BF16{},
+                       ck::Number<1>{},
+                       ck::Number<128>{},
+                       ck::Number<128>{},
+                       Row{},
+                       Col{},
+                       Row{});
+    }
+    else
+    {
+        std::cout << "this data_type & layout is not implemented" << std::endl;
+
+        return 1;
+    }
+}
+
+REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_blockscale_weighpreshuffle);
diff --git a/test/mx_mfma_op/mx_mfma_op.hpp b/test/mx_mfma_op/mx_mfma_op.hpp
index 21a0484d19..4bb38a0c16 100644
--- a/test/mx_mfma_op/mx_mfma_op.hpp
+++ b/test/mx_mfma_op/mx_mfma_op.hpp
@@ -1225,18 +1225,18 @@ struct TestMXMFMA
         {
         case 0:
             a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f});
-            a_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{0.5f}});
+            a_scales.GenerateTensorValue(GeneratorTensor_1{0.5f});
 
             // NOTE: not all numbers are representable in FP8, BF8, etc.
             // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 16 18 20 20 20 22 24 24 24 26 28 28 28 30 32
             b_n_k.GenerateTensorValue(GeneratorTensor_Sequential{});
-            b_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{1.0f}});
+            b_scales.GenerateTensorValue(GeneratorTensor_1{1.0f});
             break;
         case 1: // results in C = {K}
             a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f});
-            a_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{512.0f}});
+            a_scales.GenerateTensorValue(GeneratorTensor_1{512.0f});
             b_n_k.GenerateTensorValue(GeneratorTensor_1{1.0f});
-            b_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{1.0f / 512}});
+            b_scales.GenerateTensorValue(GeneratorTensor_1{1.0f / 512});
             break;
         case 2: // expect small round off errors