From a8bf44629bb954c4c59abf805efb1293fdb4eaf5 Mon Sep 17 00:00:00 2001 From: slippedJim Date: Thu, 27 Feb 2025 19:26:19 +0800 Subject: [PATCH] make fmha bwd api template for v2 & v3 (#1918) * use template fmha_bwd function * update --------- Co-authored-by: Po Yen Chen [ROCm/composable_kernel commit: a9bcd3c98d54d0e1e44569cfd0d7a5246f31e340] --- example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 3 ++- example/ck_tile/01_fmha/fmha_bwd.hpp | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 17f9c64843..8082523f1b 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -176,7 +176,8 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) ); }} -float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ +template <> +float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ float r = -1; {F_dispatch} return r; diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 6204cbcfa8..9179dbd9be 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -452,4 +452,5 @@ struct fmha_bwd_traits bool is_deterministic; // TODO: padding check is inside this api }; +template float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);