From e868ffa3904820e950d785162520125269a77116 Mon Sep 17 00:00:00 2001 From: Jingwei Liao Date: Wed, 24 Sep 2025 15:28:39 +0800 Subject: [PATCH] add fmha dtype fp32 (#2914) [ROCm/composable_kernel commit: 68056847887d7479a6055db6579739f555348c69] --- example/ck_tile/01_fmha/fmha_bwd.hpp | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index f1f8eee5e4..378ff9c9f8 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -15,6 +15,10 @@ #include #include +struct FmhaBwdFp32 +{ +}; + struct FmhaBwdFp16 { }; @@ -26,6 +30,26 @@ struct FmhaBwdBf16 template struct FmhaBwdTypeConfig; +template <> +struct FmhaBwdTypeConfig +{ + using QDataType = float; + using KDataType = float; + using VDataType = float; + using GemmDataType = float; + using BiasDataType = float; + using LSEDataType = float; + using AccDataType = float; // data type for gemm accumulation + using DDataType = float; + using RandValOutputDataType = uint8_t; + using ODataType = float; + using OGradDataType = float; + using QGradDataType = float; + using KGradDataType = float; + using VGradDataType = float; + using BiasGradDataType = float; +}; + template <> struct FmhaBwdTypeConfig {