temp save

This commit is contained in:
aska-0096
2025-07-17 10:06:09 +00:00
parent 7e330553dc
commit 94b6430489
11 changed files with 298 additions and 325 deletions

View File

@@ -30,7 +30,7 @@ string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
set(FMHA_FWD_CODE_GEN_COMMON_ARGS
${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${FMHA_FWD_APIS}
# --filter fmha_fwd...
--filter fmha_fwd_decode_d64_bf16_batch_b16x32x64x64x32x64_r1x1x1_r1x1x1_w16x16x32_w16x16x32_decode_qr_vr_pddv_nlogits_nbias_nmask_nlse_nsquant_npagedkv@fmha_fwd_decode_d64_bf16_batch_b16x32x64x64x32x64_r1x1x1_r1x1x1_w16x16x32_w16x16x32_decode_qr_vr_pddv_nlogits_nbias_nmask_nlse_nsquant_npagedkv
)
set(FMHA_BWD_CODE_GEN_COMMON_ARGS
${CMAKE_CURRENT_LIST_DIR}/generate.py

3
example/ck_tile/01_fmha/fmha_fwd.cpp Executable file → Normal file
View File

@@ -1042,7 +1042,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.drop_seed_offset = std::make_pair(drop_seed, drop_offset);
}
}
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>> || std::is_same_v<fmha_fwd_decode_args, std::decay_t<decltype(args)>>)
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>> ||
std::is_same_v<fmha_fwd_decode_args, std::decay_t<decltype(args)>>)
{
args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer();
args.o_acc_ptr = o_acc_buf.GetDeviceBuffer();

View File

@@ -995,7 +995,7 @@ auto fmha_fwd_decode_create_kargs_and_grids(fmha_fwd_decode_args args)
args.v_ptr,
args.bias_ptr,
args.lse_acc_ptr,
// args.o_acc_ptr,
// args.o_acc_ptr,
args.o_ptr, // hardcoding
args.batch,
args.seqlen_q,
@@ -1625,10 +1625,7 @@ struct fmha_fwd_decode_traits
bool do_fp8_static_quant;
// TODO: padding check is inside this api
};
float fmha_fwd_decode(fmha_fwd_decode_traits,
fmha_fwd_decode_args,
const ck_tile::stream_config&);
float fmha_fwd_decode(fmha_fwd_decode_traits, fmha_fwd_decode_args, const ck_tile::stream_config&);
struct fmha_fwd_appendkv_traits
{