make fmha bwd api template for v2 & v3 (#1918)

* use template fmha_bwd function

* update

---------

Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>

[ROCm/composable_kernel commit: a9bcd3c98d]
This commit is contained in:
slippedJim
2025-02-27 19:26:19 +08:00
committed by GitHub
parent c5acb522de
commit ddc3ff9878
2 changed files with 3 additions and 1 deletions

View File

@@ -176,7 +176,8 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
);
}}
float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
template <>
float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
float r = -1;
{F_dispatch}
return r;

View File

@@ -452,4 +452,5 @@ struct fmha_bwd_traits
bool is_deterministic;
// TODO: padding check is inside this api
};
template <int Version = 2>
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);