Files
composable_kernel/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
Po Yen Chen c2ac7aa7b0 [rocm-libraries] ROCm/rocm-libraries#6051 (commit f0838b2)
[CK] Add FP8 per-tensor quantization support for FMHA V3
 pipeline (#6051)

## Motivation

The existing FMHA V3 pipeline only supports fp16/bf16 data types. This
PR extends V3 to handle FP8 inputs with per-tensor descaling on gfx950,
enabling higher throughput for
  FP8 inference workloads using the assembly-optimized V3 code path.

  ## Technical Details

  **Warp GEMM:**
- Add FP8 32x32x32 warp gemm with C-transposed distribution
(`WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed`) and dispatcher entries

  **V3 Kernel (`fmha_fwd_v3_kernel.hpp`):**
- Add per-tensor descale support for Q, K, V tensors, passing descale
pointers through to pipeline kargs

  **V3 Pipeline (`block_fmha_fwd_v3_pipeline.hpp`):**
  - Add FP8 data path with dtype-aware type selection
  - Add asm volatile P matrix conversion from f32 to fp8
  - Add FP8-aware instruction scheduling in `CoreLoopScheduler`

**V3 Pipeline Policy
(`block_fmha_fwd_v3_pipeline_default_policy.hpp`):**
- Add FP8 QK warp gemm selection (SwizzleB variant for V tile
distribution compatibility)

  **Codegen (`fmha_fwd.py`):**
  - Add gfx950 FP8BF16 V3 tile size (256x64x128x128x64x128)
- Add FP8BF16 V3 pipeline variants (mask: no/causal, qscale:
no/pertensor)
  - Extend `can_dispatch_v3` condition for fp8bf16 + pertensor

  **Misc:**
- Add LLVM scheduler `TRANS` mask to `LLVMSchedGroupMask` enum
(`arch.hpp`)
- Fix `mask_info` default initialization for `no_mask` case (`mask.hpp`)

V3 dispatch for FP8 is disabled by default (`F_is_v3_enabled=false`)
pending further validation.

## Performance: fmha_fwd V3 FP8 (avg runs 2-6, stock ROCm 7.1.1, gfx950)

  | Problem | Regular (TFlops) | Varlen (TFlops) |
  |---|---:|---:|
  | batch=1 heads=6/1 seqlen=1024 causal | 48.9 | 47.6 |
  | batch=1 heads=6/1 seqlen=2048 causal | 119.8 | 117.4 |
  | batch=1 heads=6/1 seqlen=4096 causal | 263.7 | 259.2 |
  | batch=1 heads=6/1 seqlen=8192 causal | 548.9 | 543.6 |
  | batch=1 heads=6/1 seqlen=16384 causal | 1043.0 | 1063.7 |
  | batch=1 heads=6/1 seqlen=32768 causal | 1237.2 | 1279.6 |
  | batch=1 heads=6/1 seqlen=65536 causal | 1315.4 | 1382.7 |
  | batch=1 heads=6/1 seqlen=131072 causal | 1326.3 | 1402.2 |
  | batch=1 heads=16/1 seqlen=65536 causal | 1298.7 | 1388.4 |
  | batch=1 heads=40/40 seqlen=37200 non-causal | 1248.9 | 1326.1 |

## Test Plan

Tested with aiter's `test_mha_fp8.py` test suite (176 cases) covering
batch sizes (1-2), sequence lengths (113-4096), head counts (5/8/32/40),
GQA ratios (1:1, 1:8), and
causal/non-causal modes. Verified all cases dispatch to the V3 pipeline
by enabling `F_is_v3_enabled` and confirming kernel names contain
`qr_async_trload_v3`.

  ## Test Result

176/176 tests passed with V3 enabled. All cases correctly dispatched to
V3 pipeline with `pertensor` quantization.

  ## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-07 14:20:43 +00:00

534 lines
24 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp"
namespace ck_tile {
// fp32
using WarpGemmMfmaF32F32F32M16N16K4 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF32F32F32M16N16K4<WGAttrCtlEnum::Default_>>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF32F32F32M16N16K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplF32F32F32M16N16K4<WGAttrCtlEnum::Default_>,
4,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF32F32F32M16N16K8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplF32F32F32M16N16K4<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF32F32F32M32N32K8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplF32F32F32M32N32K2<WGAttrCtlEnum::Default_>,
4,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF32F32F32M32N32K4 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplF32F32F32M32N32K2<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplF32F32F32M16N16K4<WGAttrCtlEnum::Default_>,
4,
AttrNumAccess>>;
// tf32
// On gfx950: uses 3x bf16 MFMA emulation (no native xf32 support)
#if defined(CK_GFX950_SUPPORT)
// gfx950: tf32 emulated using 3x bf16 MFMA
using WarpGemmMfmaTf32Tf32F32M32N32K16Native = WarpGemmImpl<WarpGemmAttributeMfma<
WarpGemmAttributeMfmaImplF32F32F32M32N32K16Tf32Gfx950<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaTf32Tf32F32M16N16K32Native = WarpGemmImpl<WarpGemmAttributeMfma<
WarpGemmAttributeMfmaImplF32F32F32M16N16K32Tf32Gfx950<WGAttrCtlEnum::Default_>>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaTf32Tf32F32M32N32K16 = WarpGemmImpl<WarpGemmAttributeMfma<
WarpGemmAttributeMfmaImplF32F32F32M32N32K16Tf32Gfx950<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaTf32Tf32F32M16N16K32 = WarpGemmImpl<WarpGemmAttributeMfma<
WarpGemmAttributeMfmaImplF32F32F32M16N16K32Tf32Gfx950<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#endif
// fp16
using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
#endif
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
#endif
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAttributeMfmaIterateK_SwizzleA<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
1>>;
using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<WarpGemmAttributeMfmaIterateK_SwizzleA<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
#endif
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
#endif
#if defined(__gfx950__)
using WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
1>>;
using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
1>>;
#endif
using WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
#else
using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
#endif
using WarpGemmMfmaF16F16F32M4N64K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplF16F16F32M4N64K4<WGAttrCtlEnum::Default_>,
4>>;
using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplF16F16F32M64N4K4<WGAttrCtlEnum::Default_>,
4>>;
// fp16 2:4 structured sparsity
using WarpGemmSmfmacF16F16F32M32N32K16 = WarpGemmSmfmacImpl<WarpGemmAttributeSmfmac<
WarpGemmAttributeSmfmacImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmSmfmacImpl<WarpGemmAttributeSmfmac<
WarpGemmAttributeSmfmacImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>>>;
// bf16
using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
#endif
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
AttrNumAccessA,
AttrNumAccessB>>;
template <WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
using WarpGemmMfmaBf16Bf16F32M16N16K64 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
2,
AttrNumAccessA,
AttrNumAccessB>>;
#else
template <WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
2,
AttrNumAccessA>>;
template <WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
using WarpGemmMfmaBf16Bf16F32M16N16K64 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
4,
AttrNumAccessA,
AttrNumAccessB>>;
#endif
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAttributeMfmaIterateK_SwizzleA<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
1>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA =
WarpGemmImpl<WarpGemmAttributeMfmaIterateK_SwizzleA<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
#endif
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
#endif
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
#else
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
#endif
using WarpGemmMfmaBf16Bf16F32M4N64K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4<WGAttrCtlEnum::Default_>,
4>>;
using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4<WGAttrCtlEnum::Default_>,
4>>;
// fp8
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
2>>;
#endif
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
2>>;
#endif
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>,
2>>;
#endif
using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_16x16x32_fp8_bf8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
using WarpGemmMfma_f32_32x32x32_fp8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfma_f32_16x16x32_fp8_fp8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfma_f32_16x16x64_bf8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>,
2>>;
template <typename A, typename B, WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_f8f6f4 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<A, B>, AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_fp8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<fp8_t, fp8_t>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_fp8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<fp8_t, bf8_t>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<bf8_t, fp8_t>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<bf8_t, bf8_t>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_fp4_fp4_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<pk_fp4_t, pk_fp4_t>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_fp8_fp8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_fp8_bf8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_bf8_fp8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_bf8_bf8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_fp8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_fp8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_bf8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_bf8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_fp4_fp4_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x64_f8f6f4<pk_fp4_t, pk_fp4_t>,
AttrNumAccess>>;
using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
template <index_t swizzle_factor = 2>
using WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
2,
swizzle_factor>>;
// int8
using WarpGemmMfma_i32_32x32x16_i8_i8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_i32_16x16x32_i8_i8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
} // namespace ck_tile