mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Merge remote-tracking branch 'origin/feature/cond-add-splitkv' into feature/fmha-fwd-appendkv
This commit is contained in:
@@ -271,7 +271,9 @@ class FmhaBwdApiPool:
|
||||
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
|
||||
if not per_dtypes:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_dtypes += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes)
|
||||
|
||||
# GEMM0: Q@K=S^T
|
||||
|
||||
@@ -278,6 +278,9 @@ class FmhaFwdApiPool:
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_dtypes += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -331,6 +331,9 @@ class FmhaFwdSplitKVApiPool:
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_dtypes += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch = per_dtypes)
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -383,6 +383,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(num_splits != 1)
|
||||
{
|
||||
std::cerr << "split-kv is not supported. ignoring the 'num_splits' option" << std::endl;
|
||||
num_splits = 1;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
Reference in New Issue
Block a user