MFMA_32x32x16 for gfx950 (#2121)

* Enable MFMA_32x32x16 for fp16/BF16 for gfx950

* clang formatted
This commit is contained in:
Khushbu Agarwal
2025-04-24 10:20:22 -07:00
committed by GitHub
parent 01cb8379cd
commit a2ed34a112

View File

@@ -20,9 +20,15 @@ using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl<
using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
#else
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
#endif
#if defined(__gfx950__)
using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<
@@ -105,9 +111,15 @@ using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl<
using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
#else
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
#endif
#if defined(__gfx950__)
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<