Add missing kernel arguments for group mode

This commit is contained in:
PoYen, Chen
2024-08-06 14:54:07 +00:00
parent db31475e07
commit b98985262d
2 changed files with 6 additions and 3 deletions

View File

@@ -108,8 +108,8 @@ template<>
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
{{
if constexpr({F_mode} == false) {{ // batch mode
if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0
&& a.seqlen_k % (a.num_splits * {F_bk1}) == 0) {{
// make sure F_bn0 is divisible by F_bk1
if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
kernel_runner<false>::run(s, a);
}} else {{
kernel_runner<true>::run(s, a);
@@ -527,7 +527,8 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if dtype in ['fp16', 'bf16']:
# splitkv kernel donot support dropout
for mask, bias, lse, dropout, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["f"], ["t", "f"]):
if hdim == 256 or hdim == 32 or hdim == 64 or hdim == 128:
# TODO: use async pipeline when compiler is more stable
if hdim == 256 or hdim in [32, 64, 128]:
# if True:
pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, pagedkv, mask))

View File

@@ -336,6 +336,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_randval,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.split_stride_lse_acc,