epilogue switched to cshuffle

This commit is contained in:
aska-0096
2025-07-23 03:40:15 +00:00
parent 14e0ab70c6
commit 769fbb62d5
7 changed files with 49 additions and 32 deletions

View File

@@ -30,7 +30,7 @@ string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
set(FMHA_FWD_CODE_GEN_COMMON_ARGS
${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${FMHA_FWD_APIS}
# --filter @fmha_fwd_decode_d64_bf16_batch_b16x32x64x64x32x64_r1x1x1_r1x1x1_w16x16x32_w16x16x32_decode_qr_vr_nlogits_nbias_nmask_nlse_nsquant_npagedkv
# --filter @fmha_fwd_decode_d128_bf16_batch_b16x32x128x128x32x128_r1x1x1_r1x1x1_w16x16x32_w16x16x32_decode_qr_vr_npad_nlogits_nbias_nmask_nlse_nsquant_npagedkv
)
set(FMHA_BWD_CODE_GEN_COMMON_ARGS
${CMAKE_CURRENT_LIST_DIR}/generate.py

View File

@@ -31,7 +31,7 @@ K0_MAX_SUBMAX_MAP = {
}
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "fmha_fwd.hpp"

View File

@@ -101,9 +101,30 @@ using fmha_pipeline = {F_pipeline}<
/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving
/// store_tile_raw() data corruption issue
using fmha_epilogue =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
typename FmhaFwdTypeConfig<{F_dtype}>::PDataType,
typename FmhaFwdTypeConfig<{F_dtype}>::VDataType,
ck_tile::tuple<>,
typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
false, false>>;
ck_tile::tuple<>,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::element_wise::PassThrough,
{F_rm0}*{F_rn0}*64,
{F_bm0},
{F_bn1},
{F_rm1},
{F_rn1},
{F_wm1},
{F_wn1},
{F_wk1},
true,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{{}}.value,
1,
true,
16/sizeof(typename FmhaFwdTypeConfig<{F_dtype}>::ODataType)>>;
static_assert(16/sizeof(typename FmhaFwdTypeConfig<{F_dtype}>::ODataType)==8);
using fmha_kernel =
ck_tile::FmhaFwdDecodeKernel<fmha_pipeline, fmha_epilogue>;

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "bias.hpp"

View File

@@ -240,12 +240,13 @@ struct PassThrough
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf16_t, ck_tile::fp16_t>(ck_tile::bf16_t& y, const ck_tile::fp16_t& x) const
{
y = type_convert<ck_tile::bf16_t>(x);
}
// template <>
// CK_TILE_HOST_DEVICE void
// operator()<ck_tile::bf16_t, ck_tile::fp16_t>(ck_tile::bf16_t& y, const ck_tile::fp16_t& x)
// const
// {
// y = type_convert<ck_tile::bf16_t>(x);
// }
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp16_t>(float& y,

View File

@@ -161,29 +161,23 @@ struct CShuffleEpilogue
* - NumNXdlPerWavePerShuffle: Number of XDL tiles in N dimension processed per wave
*/
static constexpr auto shuffle_tile_tuple = [] {
constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
if constexpr(elem_per_thread >= GetVectorSizeC())
constexpr index_t memory_friendly_bulk_length = 128 /sizeof(ODataType);
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return std::make_tuple(1, 1);
}
else
{
constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
(kMPerBlock % num_xdl_shuffles == 0),
"kMPerBlock must be divisible by MPerXdl*MWave and "
"num_xdl_shuffles for CShuffleEpilogue");
return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
if constexpr((kNPerBlock % memory_friendly_bulk_length) !=0){
return std::make_tuple(1, 1);
}
else
{
static_assert((kNPerBlock % (NPerXdl * NWave) == 0) &&
(kNPerBlock % num_xdl_shuffles == 0),
"kNPerBlock must be divisible by NPerXdl*NWave and "
"num_xdl_shuffles for CShuffleEpilogue");
return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
else{
return std::make_tuple(1, max(1, memory_friendly_bulk_length / (NPerXdl * NWave)));
}
}
else{
if constexpr((kMPerBlock % memory_friendly_bulk_length) !=0){
return std::make_tuple(1, 1);
}
else{
return std::make_tuple(max(1, memory_friendly_bulk_length / (MPerXdl * MWave)), 1);
}
}
}();

View File

@@ -1116,7 +1116,7 @@ struct FmhaFwdDecodeKernel
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
{i_m0, i_n1});
EpiloguePipeline{}(o_acc_dram_window, o_acc_tile);
EpiloguePipeline{}(o_acc_dram_window, o_acc_tile, ck_tile::tuple<>{}, smem_ptr);
}
};