// NOTE(review): This header is damaged and will not compile as shown. Every
// `<...>` template-argument list appears to have been stripped (e.g.
// `WarpGemmImpl< WarpGemmAttributeMfma>>;` has one `<` but two `>`, and the
// alias templates read `template using ...` with no parameter list), and the
// file has been collapsed onto a few physical lines. Because the first of
// those lines begins with `//`, that entire line -- including `#pragma once`,
// the `#include` directives, `namespace ck_tile {` and the fp32 / first fp16
// aliases -- is a single line comment. The missing template arguments cannot
// be recovered from this text; restore the file from upstream rather than
// hand-patching it.
//
// Purpose (from the visible names and section comments): declares type aliases
// for warp-level GEMM inside namespace `ck_tile`. Each alias is built on
// `WarpGemmImpl` (the `WarpGemmMfma*` names) or `WarpGemmSmfmacImpl` (the
// "fp16 2:4 structured sparsity" section, `WarpGemmSmfmac*` names), grouped by
// section comment: fp32, fp16, bf16, fp8 (fp8/bf8 operands) and int8. Alias
// names spell out operand/accumulator types and the M/N/K (or MxNxK) tile,
// plus suffixes such as SwizzleA, SwizzleB, TransposedCDistribution and
// CTransposed.
//
// Several aliases have two definitions chosen by `#if defined(__gfx950__)`;
// in the non-gfx950 branch the surviving text carries an extra `2` argument
// (e.g. `WarpGemmImpl, 2, AttrNumAccess>>`) -- presumably composing two
// smaller-K instructions; confirm against upstream. Many aliases reference
// `AttrNumAccess` (and one fp8 alias references `swizzle_factor`), which were
// presumably the stripped template parameters -- TODO confirm.
// SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp" namespace ck_tile { // fp32 using WarpGemmMfmaF32F32F32M16N16K4 = WarpGemmImpl< WarpGemmAttributeMfma>>; template using WarpGemmMfmaF32F32F32M16N16K16 = WarpGemmImpl, 4, AttrNumAccess>>; template using WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution = WarpGemmImpl, 4, AttrNumAccess>>; // fp16 using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl< WarpGemmAttributeMfma>>; using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl< WarpGemmAttributeMfma>>; #if defined(__gfx950__) template using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl< WarpGemmAttributeMfma, AttrNumAccess>>; #else template using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl, 2, AttrNumAccess>>; #endif #if defined(__gfx950__) template using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl< WarpGemmAttributeMfma, AttrNumAccess>>; #else template using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl, 2, AttrNumAccess>>; #endif using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl, 1>>; using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl, 2>>; using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution = WarpGemmImpl>>; using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution = WarpGemmImpl>>; #if defined(__gfx950__) template using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = WarpGemmImpl, AttrNumAccess>>; #else template using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = WarpGemmImpl, 2, AttrNumAccess>>; #endif #if defined(__gfx950__) template using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = WarpGemmImpl, AttrNumAccess>>; #else template using 
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = WarpGemmImpl, 2, AttrNumAccess>>; #endif #if defined(__gfx950__) using WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution = WarpGemmImpl, 1>>; using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution = WarpGemmImpl, 1>>; #endif using WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution = WarpGemmImpl>>; #if defined(__gfx950__) using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl>>; #else using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl, 2>>; #endif using WarpGemmMfmaF16F16F32M4N64K16 = WarpGemmImpl, 4>>; using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl, 4>>; // fp16 2:4 structured sparsity using WarpGemmSmfmacF16F16F32M32N32K16 = WarpGemmSmfmacImpl>>; using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmSmfmacImpl>>; // bf16 using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl< WarpGemmAttributeMfma>>; using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl< WarpGemmAttributeMfma>>; #if defined(__gfx950__) template using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl< WarpGemmAttributeMfma, AttrNumAccess>>; #else template using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl, 2, AttrNumAccess>>; #endif #if defined(__gfx950__) template using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl< WarpGemmAttributeMfma, AttrNumAccess>>; #else template using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl, 2, AttrNumAccess>>; #endif using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl, 1>>; using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl, 2>>; using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution = WarpGemmImpl>>; using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution = WarpGemmImpl>>; #if defined(__gfx950__) template using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution = WarpGemmImpl, AttrNumAccess>>; #else template using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution = 
WarpGemmImpl, 2, AttrNumAccess>>; #endif #if defined(__gfx950__) template using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = WarpGemmImpl, AttrNumAccess>>; #else template using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = WarpGemmImpl, 2, AttrNumAccess>>; #endif using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution = WarpGemmImpl>>; #if defined(__gfx950__) using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl>>; #else using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl, 2>>; #endif using WarpGemmMfmaBf16Bf16F32M4N64K16 = WarpGemmImpl, 4>>; using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl, 4>>; // fp8 using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl< WarpGemmAttributeMfma>>; using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl< WarpGemmAttributeMfma>>; using WarpGemmMfma_f32_16x16x32_fp8_bf8 = WarpGemmImpl< WarpGemmAttributeMfma>>; using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl< WarpGemmAttributeMfma>>; using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl< WarpGemmAttributeMfma>>; using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl, 2>>; using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl, 2>>; using WarpGemmMfma_f32_32x32x32_fp8_bf8 = WarpGemmImpl, 2>>; using WarpGemmMfma_f32_16x16x32_fp8_fp8 = WarpGemmImpl< WarpGemmAttributeMfma>>; using WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed = WarpGemmImpl>>; using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl< WarpGemmAttributeMfma>>; using WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed = WarpGemmImpl>>; using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl, 2>>; using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl, 2>>; template using WarpGemmMfma_f32_16x16x128_fp4 = WarpGemmImpl< WarpGemmAttributeMfma, AttrNumAccess>>; template using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl< WarpGemmAttributeMfma, AttrNumAccess>>; template using WarpGemmMfma_f32_16x16x128_fp8_bf8 = 
WarpGemmImpl< WarpGemmAttributeMfma, AttrNumAccess>>; template using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl< WarpGemmAttributeMfma, AttrNumAccess>>; template using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl< WarpGemmAttributeMfma, AttrNumAccess>>; template using WarpGemmMfma_f32_16x16x128_fp8_fp8_CTransposed = WarpGemmImpl, AttrNumAccess>>; template using WarpGemmMfma_f32_16x16x128_fp8_bf8_CTransposed = WarpGemmImpl, AttrNumAccess>>; template using WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed = WarpGemmImpl, AttrNumAccess>>; template using WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed = WarpGemmImpl, AttrNumAccess>>; template using WarpGemmMfma_f32_32x32x64_fp8_fp8 = WarpGemmImpl< WarpGemmAttributeMfma, AttrNumAccess>>; template using WarpGemmMfma_f32_32x32x64_fp8_bf8 = WarpGemmImpl< WarpGemmAttributeMfma, AttrNumAccess>>; template using WarpGemmMfma_f32_32x32x64_bf8_fp8 = WarpGemmImpl< WarpGemmAttributeMfma, AttrNumAccess>>; template using WarpGemmMfma_f32_32x32x64_bf8_bf8 = WarpGemmImpl< WarpGemmAttributeMfma, AttrNumAccess>>; using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = WarpGemmImpl>>; using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed = WarpGemmImpl>>; using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed = WarpGemmImpl>>; using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed = WarpGemmImpl>>; template using WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution = WarpGemmImpl, 2, swizzle_factor>>; // int8 using WarpGemmMfma_i32_32x32x16_i8_i8 = WarpGemmImpl< WarpGemmAttributeMfma>>; using WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed = WarpGemmImpl>>; using WarpGemmMfma_i32_16x16x32_i8_i8 = WarpGemmImpl< WarpGemmAttributeMfma>>; using WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed = WarpGemmImpl>>; } // namespace ck_tile