mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
FMHA fwd decode kernel: epilogue switched from Default2DEpilogue to CShuffleEpilogue
This commit is contained in:
@@ -30,7 +30,7 @@ string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
|
||||
set(FMHA_FWD_CODE_GEN_COMMON_ARGS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${FMHA_FWD_APIS}
|
||||
# --filter @fmha_fwd_decode_d64_bf16_batch_b16x32x64x64x32x64_r1x1x1_r1x1x1_w16x16x32_w16x16x32_decode_qr_vr_nlogits_nbias_nmask_nlse_nsquant_npagedkv
|
||||
# --filter @fmha_fwd_decode_d128_bf16_batch_b16x32x128x128x32x128_r1x1x1_r1x1x1_w16x16x32_w16x16x32_decode_qr_vr_npad_nlogits_nbias_nmask_nlse_nsquant_npagedkv
|
||||
)
|
||||
set(FMHA_BWD_CODE_GEN_COMMON_ARGS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
|
||||
@@ -31,7 +31,7 @@ K0_MAX_SUBMAX_MAP = {
|
||||
}
|
||||
|
||||
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// auto generated by generate.py
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "fmha_fwd.hpp"
|
||||
|
||||
@@ -101,9 +101,30 @@ using fmha_pipeline = {F_pipeline}<
|
||||
/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving
|
||||
/// store_tile_raw() data corruption issue
|
||||
using fmha_epilogue =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
typename FmhaFwdTypeConfig<{F_dtype}>::PDataType,
|
||||
typename FmhaFwdTypeConfig<{F_dtype}>::VDataType,
|
||||
ck_tile::tuple<>,
|
||||
typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
|
||||
false, false>>;
|
||||
ck_tile::tuple<>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
{F_rm0}*{F_rn0}*64,
|
||||
{F_bm0},
|
||||
{F_bn1},
|
||||
{F_rm1},
|
||||
{F_rn1},
|
||||
{F_wm1},
|
||||
{F_wn1},
|
||||
{F_wk1},
|
||||
true,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{{}}.value,
|
||||
1,
|
||||
true,
|
||||
16/sizeof(typename FmhaFwdTypeConfig<{F_dtype}>::ODataType)>>;
|
||||
static_assert(16/sizeof(typename FmhaFwdTypeConfig<{F_dtype}>::ODataType)==8);
|
||||
|
||||
using fmha_kernel =
|
||||
ck_tile::FmhaFwdDecodeKernel<fmha_pipeline, fmha_epilogue>;
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/ops/fmha.hpp"
|
||||
|
||||
#include "bias.hpp"
|
||||
|
||||
@@ -240,12 +240,13 @@ struct PassThrough
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::bf16_t, ck_tile::fp16_t>(ck_tile::bf16_t& y, const ck_tile::fp16_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::bf16_t>(x);
|
||||
}
|
||||
// template <>
|
||||
// CK_TILE_HOST_DEVICE void
|
||||
// operator()<ck_tile::bf16_t, ck_tile::fp16_t>(ck_tile::bf16_t& y, const ck_tile::fp16_t& x)
|
||||
// const
|
||||
// {
|
||||
// y = type_convert<ck_tile::bf16_t>(x);
|
||||
// }
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp16_t>(float& y,
|
||||
|
||||
@@ -161,29 +161,23 @@ struct CShuffleEpilogue
|
||||
* - NumNXdlPerWavePerShuffle: Number of XDL tiles in N dimension processed per wave
|
||||
*/
|
||||
static constexpr auto shuffle_tile_tuple = [] {
|
||||
constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
|
||||
if constexpr(elem_per_thread >= GetVectorSizeC())
|
||||
constexpr index_t memory_friendly_bulk_length = 128 /sizeof(ODataType);
|
||||
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return std::make_tuple(1, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
|
||||
(kMPerBlock % num_xdl_shuffles == 0),
|
||||
"kMPerBlock must be divisible by MPerXdl*MWave and "
|
||||
"num_xdl_shuffles for CShuffleEpilogue");
|
||||
return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
|
||||
if constexpr((kNPerBlock % memory_friendly_bulk_length) !=0){
|
||||
return std::make_tuple(1, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert((kNPerBlock % (NPerXdl * NWave) == 0) &&
|
||||
(kNPerBlock % num_xdl_shuffles == 0),
|
||||
"kNPerBlock must be divisible by NPerXdl*NWave and "
|
||||
"num_xdl_shuffles for CShuffleEpilogue");
|
||||
return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
|
||||
else{
|
||||
return std::make_tuple(1, max(1, memory_friendly_bulk_length / (NPerXdl * NWave)));
|
||||
}
|
||||
}
|
||||
else{
|
||||
if constexpr((kMPerBlock % memory_friendly_bulk_length) !=0){
|
||||
return std::make_tuple(1, 1);
|
||||
}
|
||||
else{
|
||||
return std::make_tuple(max(1, memory_friendly_bulk_length / (MPerXdl * MWave)), 1);
|
||||
}
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -1116,7 +1116,7 @@ struct FmhaFwdDecodeKernel
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
|
||||
{i_m0, i_n1});
|
||||
|
||||
EpiloguePipeline{}(o_acc_dram_window, o_acc_tile);
|
||||
EpiloguePipeline{}(o_acc_dram_window, o_acc_tile, ck_tile::tuple<>{}, smem_ptr);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user