diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index a2c320f3e6..a7695af409 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -378,18 +378,34 @@ using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = WarpGemmImpl>>; +using WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed = WarpGemmImpl, + 2>>; + using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed = WarpGemmImpl>>; +using WarpGemmMfma_f32_32x32x32_fp8_bf8_CTransposed = WarpGemmImpl, + 2>>; + using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed = WarpGemmImpl>>; +using WarpGemmMfma_f32_32x32x32_bf8_fp8_CTransposed = WarpGemmImpl, + 2>>; + using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed = WarpGemmImpl>>; +using WarpGemmMfma_f32_32x32x32_bf8_bf8_CTransposed = WarpGemmImpl, + 2>>; + template using WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution = WarpGemmImpl struct Dispatcher { using Ty template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_bf8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_bf8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_bf8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8_CTransposed; }; // scale mfma based f8f6f4 template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8; };