mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 21:39:15 +00:00
temp save
This commit is contained in:
@@ -30,7 +30,7 @@ string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
|
||||
set(FMHA_FWD_CODE_GEN_COMMON_ARGS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${FMHA_FWD_APIS}
|
||||
# --filter fmha_fwd...
|
||||
--filter fmha_fwd_decode_d64_bf16_batch_b16x32x64x64x32x64_r1x1x1_r1x1x1_w16x16x32_w16x16x32_decode_qr_vr_pddv_nlogits_nbias_nmask_nlse_nsquant_npagedkv@fmha_fwd_decode_d64_bf16_batch_b16x32x64x64x32x64_r1x1x1_r1x1x1_w16x16x32_w16x16x32_decode_qr_vr_pddv_nlogits_nbias_nmask_nlse_nsquant_npagedkv
|
||||
)
|
||||
set(FMHA_BWD_CODE_GEN_COMMON_ARGS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
|
||||
3
example/ck_tile/01_fmha/fmha_fwd.cpp
Executable file → Normal file
3
example/ck_tile/01_fmha/fmha_fwd.cpp
Executable file → Normal file
@@ -1042,7 +1042,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
args.drop_seed_offset = std::make_pair(drop_seed, drop_offset);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>> || std::is_same_v<fmha_fwd_decode_args, std::decay_t<decltype(args)>>)
|
||||
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>> ||
|
||||
std::is_same_v<fmha_fwd_decode_args, std::decay_t<decltype(args)>>)
|
||||
{
|
||||
args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer();
|
||||
args.o_acc_ptr = o_acc_buf.GetDeviceBuffer();
|
||||
|
||||
@@ -995,7 +995,7 @@ auto fmha_fwd_decode_create_kargs_and_grids(fmha_fwd_decode_args args)
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.lse_acc_ptr,
|
||||
// args.o_acc_ptr,
|
||||
// args.o_acc_ptr,
|
||||
args.o_ptr, // hardcoding
|
||||
args.batch,
|
||||
args.seqlen_q,
|
||||
@@ -1625,10 +1625,7 @@ struct fmha_fwd_decode_traits
|
||||
bool do_fp8_static_quant;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
float fmha_fwd_decode(fmha_fwd_decode_traits,
|
||||
fmha_fwd_decode_args,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
float fmha_fwd_decode(fmha_fwd_decode_traits, fmha_fwd_decode_args, const ck_tile::stream_config&);
|
||||
|
||||
struct fmha_fwd_appendkv_traits
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user