From 769fbb62d563fd4e0bf35de69bf746d4d31a3bd4 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Wed, 23 Jul 2025 03:40:15 +0000 Subject: [PATCH] epilogue switched to cshuffle --- example/ck_tile/01_fmha/CMakeLists.txt | 2 +- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 2 +- .../01_fmha/codegen/ops/fmha_fwd_decode.py | 25 +++++++++++-- example/ck_tile/01_fmha/fmha_fwd.hpp | 1 + .../unary_element_wise_operation.hpp | 13 +++---- .../ops/epilogue/cshuffle_epilogue.hpp | 36 ++++++++----------- .../fmha/kernel/fmha_fwd_decode_kernel.hpp | 2 +- 7 files changed, 49 insertions(+), 32 deletions(-) diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 79fed0898d..664cfc80e0 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -30,7 +30,7 @@ string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}") set(FMHA_FWD_CODE_GEN_COMMON_ARGS ${CMAKE_CURRENT_LIST_DIR}/generate.py --api ${FMHA_FWD_APIS} - # --filter @fmha_fwd_decode_d64_bf16_batch_b16x32x64x64x32x64_r1x1x1_r1x1x1_w16x16x32_w16x16x32_decode_qr_vr_nlogits_nbias_nmask_nlse_nsquant_npagedkv + # --filter @fmha_fwd_decode_d128_bf16_batch_b16x32x128x128x32x128_r1x1x1_r1x1x1_w16x16x32_w16x16x32_decode_qr_vr_npad_nlogits_nbias_nmask_nlse_nsquant_npagedkv ) set(FMHA_BWD_CODE_GEN_COMMON_ARGS ${CMAKE_CURRENT_LIST_DIR}/generate.py diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 78cec40aa8..145cd1d002 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -31,7 +31,7 @@ K0_MAX_SUBMAX_MAP = { } FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n // auto generated by generate.py #include "ck_tile/ops/fmha/block/variants.hpp" #include "fmha_fwd.hpp" diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_decode.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_decode.py index fcd32cb15d..d580bafe48 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_decode.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_decode.py @@ -101,9 +101,30 @@ using fmha_pipeline = {F_pipeline}< /// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving /// store_tile_raw() data corruption issue using fmha_epilogue = - ck_tile::Default2DEpilogue::OaccDataType, + ck_tile::CShuffleEpilogue::PDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::VDataType, + ck_tile::tuple<>, + typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType, typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, - false, false>>; + ck_tile::tuple<>, + ck_tile::tensor_layout::gemm::RowMajor, + ck_tile::element_wise::PassThrough, + {F_rm0}*{F_rn0}*64, + {F_bm0}, + {F_bn1}, + {F_rm1}, + {F_rn1}, + {F_wm1}, + {F_wn1}, + {F_wk1}, + true, + ck_tile::integral_constant{{}}.value, + 1, + true, + 16/sizeof(typename FmhaFwdTypeConfig<{F_dtype}>::ODataType)>>; +static_assert(16/sizeof(typename FmhaFwdTypeConfig<{F_dtype}>::ODataType)==8); using fmha_kernel = ck_tile::FmhaFwdDecodeKernel; diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 4b0df4bf0c..33f8d7b515 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/elementwise.hpp" #include "ck_tile/ops/fmha.hpp" #include "bias.hpp" diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index a3fe5045cf..8937f4ca8f 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -240,12 +240,13 @@ struct PassThrough y = type_convert(x); } - template <> - CK_TILE_HOST_DEVICE void - operator()(ck_tile::bf16_t& y, const ck_tile::fp16_t& x) const - { - y = type_convert(x); - } + // template <> + // CK_TILE_HOST_DEVICE void + // operator()(ck_tile::bf16_t& y, const ck_tile::fp16_t& x) + // const + // { + // y = type_convert(x); + // } template <> CK_TILE_HOST_DEVICE void operator()(float& y, diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index bf58544259..109b8d3645 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -161,29 +161,23 @@ struct CShuffleEpilogue * - NumNXdlPerWavePerShuffle: Number of XDL tiles in N dimension processed per wave */ static constexpr auto shuffle_tile_tuple = [] { - constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size(); - if constexpr(elem_per_thread >= GetVectorSizeC()) + constexpr index_t memory_friendly_bulk_length = 128 /sizeof(ODataType); + + if constexpr(std::is_same_v) { - return std::make_tuple(1, 1); - } - else - { - constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread; - if constexpr(std::is_same_v) - { - static_assert((kMPerBlock % (MPerXdl * MWave) == 0) && - (kMPerBlock % num_xdl_shuffles == 0), - "kMPerBlock must be divisible by MPerXdl*MWave and " - "num_xdl_shuffles for CShuffleEpilogue"); - return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1); + if constexpr((kNPerBlock % memory_friendly_bulk_length) !=0){ + return std::make_tuple(1, 1); } - else - { - static_assert((kNPerBlock % (NPerXdl * NWave) == 0) && - (kNPerBlock % num_xdl_shuffles == 0), - "kNPerBlock must be divisible by NPerXdl*NWave and " - "num_xdl_shuffles for CShuffleEpilogue"); - return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave))); + else{ + return std::make_tuple(1, max(1, memory_friendly_bulk_length / (NPerXdl * NWave))); + } + } + else{ + if constexpr((kMPerBlock % memory_friendly_bulk_length) !=0){ + return std::make_tuple(1, 1); + } + else{ + return std::make_tuple(max(1, memory_friendly_bulk_length / (MPerXdl * MWave)), 1); } } }(); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp index 3be5acedf2..f5c2f9a717 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp @@ -1116,7 +1116,7 @@ struct FmhaFwdDecodeKernel make_tuple(number{}, number{}), {i_m0, i_n1}); - EpiloguePipeline{}(o_acc_dram_window, o_acc_tile); + EpiloguePipeline{}(o_acc_dram_window, o_acc_tile, ck_tile::tuple<>{}, smem_ptr); } };