add fmha dtype fp32 (#2914)

[ROCm/composable_kernel commit: 6805684788]
This commit is contained in:
Jingwei Liao
2025-09-24 15:28:39 +08:00
committed by GitHub
parent c5a3d4c765
commit d5b5e4ef95

View File

@@ -15,6 +15,10 @@
#include <utility>
#include <variant>
struct FmhaBwdFp32
{
};
struct FmhaBwdFp16
{
};
@@ -26,6 +30,26 @@ struct FmhaBwdBf16
template <typename DataType>
struct FmhaBwdTypeConfig;
template <>
struct FmhaBwdTypeConfig<FmhaBwdFp32>
{
using QDataType = float;
using KDataType = float;
using VDataType = float;
using GemmDataType = float;
using BiasDataType = float;
using LSEDataType = float;
using AccDataType = float; // data type for gemm accumulation
using DDataType = float;
using RandValOutputDataType = uint8_t;
using ODataType = float;
using OGradDataType = float;
using QGradDataType = float;
using KGradDataType = float;
using VGradDataType = float;
using BiasGradDataType = float;
};
template <>
struct FmhaBwdTypeConfig<FmhaBwdFp16>
{