mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
[CK] Add FP8 per-tensor quantization support for FMHA V3 pipeline (#6051)

## Motivation

The existing FMHA V3 pipeline only supports fp16/bf16 data types. This PR extends V3 to handle FP8 inputs with per-tensor descaling on gfx950, enabling higher throughput for FP8 inference workloads using the assembly-optimized V3 code path.

## Technical Details

**Warp GEMM:**
- Add FP8 32x32x32 warp gemm with C-transposed distribution (`WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed`) and dispatcher entries

**V3 Kernel (`fmha_fwd_v3_kernel.hpp`):**
- Add per-tensor descale support for Q, K, V tensors, passing descale pointers through to pipeline kargs

**V3 Pipeline (`block_fmha_fwd_v3_pipeline.hpp`):**
- Add FP8 data path with dtype-aware type selection
- Add `asm volatile` P-matrix conversion from f32 to fp8
- Add FP8-aware instruction scheduling in `CoreLoopScheduler`

**V3 Pipeline Policy (`block_fmha_fwd_v3_pipeline_default_policy.hpp`):**
- Add FP8 QK warp gemm selection (SwizzleB variant for V tile distribution compatibility)

**Codegen (`fmha_fwd.py`):**
- Add gfx950 FP8BF16 V3 tile size (256x64x128x128x64x128)
- Add FP8BF16 V3 pipeline variants (mask: no/causal, qscale: no/pertensor)
- Extend `can_dispatch_v3` condition for fp8bf16 + pertensor

**Misc:**
- Add LLVM scheduler `TRANS` mask to `LLVMSchedGroupMask` enum (`arch.hpp`)
- Fix `mask_info` default initialization for `no_mask` case (`mask.hpp`)

V3 dispatch for FP8 is disabled by default (`F_is_v3_enabled=false`) pending further validation.
## Performance: fmha_fwd V3 FP8 (average of runs 2-6, stock ROCm 7.1.1, gfx950)

| Problem | Regular (TFlops) | Varlen (TFlops) |
|---|---:|---:|
| batch=1 heads=6/1 seqlen=1024 causal | 48.9 | 47.6 |
| batch=1 heads=6/1 seqlen=2048 causal | 119.8 | 117.4 |
| batch=1 heads=6/1 seqlen=4096 causal | 263.7 | 259.2 |
| batch=1 heads=6/1 seqlen=8192 causal | 548.9 | 543.6 |
| batch=1 heads=6/1 seqlen=16384 causal | 1043.0 | 1063.7 |
| batch=1 heads=6/1 seqlen=32768 causal | 1237.2 | 1279.6 |
| batch=1 heads=6/1 seqlen=65536 causal | 1315.4 | 1382.7 |
| batch=1 heads=6/1 seqlen=131072 causal | 1326.3 | 1402.2 |
| batch=1 heads=16/1 seqlen=65536 causal | 1298.7 | 1388.4 |
| batch=1 heads=40/40 seqlen=37200 non-causal | 1248.9 | 1326.1 |

## Test Plan

Tested with aiter's `test_mha_fp8.py` test suite (176 cases) covering batch sizes (1-2), sequence lengths (113-4096), head counts (5/8/32/40), GQA ratios (1:1, 1:8), and causal/non-causal modes. Verified that all cases dispatch to the V3 pipeline by enabling `F_is_v3_enabled` and confirming that kernel names contain `qr_async_trload_v3`.

## Test Result

176/176 tests passed with V3 enabled. All cases correctly dispatched to the V3 pipeline with `pertensor` quantization.

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
534 lines · 24 KiB · C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
#include "ck_tile/core.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
|
|
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp"
|
|
|
|
namespace ck_tile {
|
|
|
|
// fp32
|
|
|
|
using WarpGemmMfmaF32F32F32M16N16K4 = WarpGemmImpl<
|
|
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF32F32F32M16N16K4<WGAttrCtlEnum::Default_>>>;
|
|
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfmaF32F32F32M16N16K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
|
WarpGemmAttributeMfmaImplF32F32F32M16N16K4<WGAttrCtlEnum::Default_>,
|
|
4,
|
|
AttrNumAccess>>;
|
|
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfmaF32F32F32M16N16K8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
|
WarpGemmAttributeMfmaImplF32F32F32M16N16K4<WGAttrCtlEnum::Default_>,
|
|
2,
|
|
AttrNumAccess>>;
|
|
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfmaF32F32F32M32N32K8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
|
WarpGemmAttributeMfmaImplF32F32F32M32N32K2<WGAttrCtlEnum::Default_>,
|
|
4,
|
|
AttrNumAccess>>;
|
|
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfmaF32F32F32M32N32K4 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
|
WarpGemmAttributeMfmaImplF32F32F32M32N32K2<WGAttrCtlEnum::Default_>,
|
|
2,
|
|
AttrNumAccess>>;
|
|
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImplF32F32F32M16N16K4<WGAttrCtlEnum::Default_>,
|
|
4,
|
|
AttrNumAccess>>;
|
|
|
|
// ===========================================================================
// tf32 x tf32 -> fp32 accumulate.
// Declared only when CK_GFX950_SUPPORT is defined. gfx950 has no native xf32
// MFMA, so the *Tf32Gfx950 impls emulate tf32 with 3x bf16 MFMA.
// NOTE(review): this section is gated on CK_GFX950_SUPPORT, while the rest of
// the file gates on __gfx950__. Presumably this keeps the aliases visible
// outside device passes; confirm before unifying.
// ===========================================================================
#if defined(CK_GFX950_SUPPORT)

// Fixed-access ("Native") variants: no AttrNumAccess parameter.
using WarpGemmMfmaTf32Tf32F32M32N32K16Native =
    WarpGemmImpl<WarpGemmAttributeMfma<
        WarpGemmAttributeMfmaImplF32F32F32M32N32K16Tf32Gfx950<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfmaTf32Tf32F32M16N16K32Native =
    WarpGemmImpl<WarpGemmAttributeMfma<
        WarpGemmAttributeMfmaImplF32F32F32M16N16K32Tf32Gfx950<WGAttrCtlEnum::Default_>>>;

// Same impls, with a configurable access count.
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaTf32Tf32F32M32N32K16 =
    WarpGemmImpl<WarpGemmAttributeMfma<
        WarpGemmAttributeMfmaImplF32F32F32M32N32K16Tf32Gfx950<WGAttrCtlEnum::Default_>,
        AttrNumAccess>>;

template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaTf32Tf32F32M16N16K32 =
    WarpGemmImpl<WarpGemmAttributeMfma<
        WarpGemmAttributeMfmaImplF32F32F32M16N16K32Tf32Gfx950<WGAttrCtlEnum::Default_>,
        AttrNumAccess>>;

#endif // CK_GFX950_SUPPORT
|
|
|
|
// fp16
|
|
|
|
using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl<
|
|
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
|
|
|
|
using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl<
|
|
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
|
|
|
#if defined(__gfx950__)
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<
|
|
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>,
|
|
AttrNumAccess>>;
|
|
#else
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
|
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
|
2,
|
|
AttrNumAccess>>;
|
|
#endif
|
|
|
|
#if defined(__gfx950__)
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<
|
|
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
|
AttrNumAccess>>;
|
|
#else
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
|
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
|
2,
|
|
AttrNumAccess>>;
|
|
#endif
|
|
|
|
// fp16 SwizzleA variants: M32N32K8 impl iterated 1x (K8) or 2x (K16) along K.
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA =
    WarpGemmImpl<WarpGemmAttributeMfmaIterateK_SwizzleA<
        WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
        1>>;

using WarpGemmMfmaF16F16F32M32N32K16SwizzleA =
    WarpGemmImpl<WarpGemmAttributeMfmaIterateK_SwizzleA<
        WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
        2>>;

// fp16 single-step aliases with the C distribution transposed.
using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution =
    WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution =
    WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
|
|
|
#if defined(__gfx950__)
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>,
|
|
AttrNumAccess>>;
|
|
#else
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
|
2,
|
|
AttrNumAccess>>;
|
|
#endif
|
|
|
|
#if defined(__gfx950__)
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
|
AttrNumAccess>>;
|
|
#else
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
|
2,
|
|
AttrNumAccess>>;
|
|
#endif
|
|
|
|
#if defined(__gfx950__)
|
|
using WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
|
|
WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
|
1>>;
|
|
|
|
using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
|
|
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
|
1>>;
|
|
#endif
|
|
|
|
using WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
|
|
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
|
|
|
|
#if defined(__gfx950__)
|
|
using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
|
|
WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
|
|
#else
|
|
using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
|
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
|
2>>;
|
|
#endif
|
|
|
|
using WarpGemmMfmaF16F16F32M4N64K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
|
WarpGemmAttributeMfmaImplF16F16F32M4N64K4<WGAttrCtlEnum::Default_>,
|
|
4>>;
|
|
|
|
using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
|
WarpGemmAttributeMfmaImplF16F16F32M64N4K4<WGAttrCtlEnum::Default_>,
|
|
4>>;
|
|
|
|
// fp16 2:4 structured sparsity
|
|
using WarpGemmSmfmacF16F16F32M32N32K16 = WarpGemmSmfmacImpl<WarpGemmAttributeSmfmac<
|
|
WarpGemmAttributeSmfmacImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
|
|
|
|
using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmSmfmacImpl<WarpGemmAttributeSmfmac<
|
|
WarpGemmAttributeSmfmacImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>>>;
|
|
|
|
// bf16
|
|
using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl<
|
|
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
|
|
|
|
using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl<
|
|
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
|
|
|
#if defined(__gfx950__)
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<
|
|
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>,
|
|
AttrNumAccess>>;
|
|
#else
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
|
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
|
2,
|
|
AttrNumAccess>>;
|
|
#endif
|
|
|
|
#if defined(__gfx950__)
|
|
template <WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
|
|
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
|
|
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<
|
|
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
|
AttrNumAccessA,
|
|
AttrNumAccessB>>;
|
|
|
|
template <WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
|
|
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
|
|
using WarpGemmMfmaBf16Bf16F32M16N16K64 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
|
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
|
2,
|
|
AttrNumAccessA,
|
|
AttrNumAccessB>>;
|
|
#else
|
|
template <WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
|
|
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
|
|
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
|
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
|
2,
|
|
AttrNumAccessA>>;
|
|
|
|
template <WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
|
|
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
|
|
using WarpGemmMfmaBf16Bf16F32M16N16K64 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
|
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
|
4,
|
|
AttrNumAccessA,
|
|
AttrNumAccessB>>;
|
|
#endif
|
|
|
|
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAttributeMfmaIterateK_SwizzleA<
|
|
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
|
1>>;
|
|
|
|
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaIterateK_SwizzleA<
|
|
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
|
2>>;
|
|
|
|
using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
|
|
|
|
using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
|
|
|
#if defined(__gfx950__)
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>,
|
|
AttrNumAccess>>;
|
|
#else
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
|
2,
|
|
AttrNumAccess>>;
|
|
#endif
|
|
|
|
#if defined(__gfx950__)
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
|
AttrNumAccess>>;
|
|
#else
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
|
2,
|
|
AttrNumAccess>>;
|
|
#endif
|
|
|
|
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
|
|
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
|
|
|
|
#if defined(__gfx950__)
|
|
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
|
|
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
|
|
#else
|
|
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
|
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
|
2>>;
|
|
#endif
|
|
|
|
using WarpGemmMfmaBf16Bf16F32M4N64K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
|
WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4<WGAttrCtlEnum::Default_>,
|
|
4>>;
|
|
|
|
using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
|
WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4<WGAttrCtlEnum::Default_>,
|
|
4>>;
|
|
|
|
// fp8
|
|
#if defined(__gfx950__)
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
|
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
|
|
2,
|
|
AttrNumAccess>>;
|
|
#else
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
|
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
|
|
2>>;
|
|
#endif
|
|
|
|
#if defined(__gfx950__)
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
|
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
|
|
2,
|
|
AttrNumAccess>>;
|
|
#else
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
|
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
|
|
2>>;
|
|
#endif
|
|
|
|
#if defined(__gfx950__)
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
|
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>,
|
|
2,
|
|
AttrNumAccess>>;
|
|
#else
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
|
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>,
|
|
2>>;
|
|
#endif
|
|
|
|
using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<
|
|
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
|
|
|
|
using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl<
|
|
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
|
|
|
|
using WarpGemmMfma_f32_16x16x32_fp8_bf8 = WarpGemmImpl<
|
|
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_bf8<WGAttrCtlEnum::Default_>>>;
|
|
|
|
using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl<
|
|
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
|
|
|
|
using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl<
|
|
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
|
|
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
|
|
2,
|
|
AttrNumAccess>>;
|
|
|
|
using WarpGemmMfma_f32_32x32x32_fp8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
|
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>,
|
|
2>>;
|
|
|
|
using WarpGemmMfma_f32_16x16x32_fp8_fp8 = WarpGemmImpl<
|
|
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>>>;
|
|
|
|
using WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>>>;
|
|
|
|
using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl<
|
|
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>>>;
|
|
|
|
using WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>>>;
|
|
|
|
using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
|
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>,
|
|
2>>;
|
|
|
|
using WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
|
|
2>>;
|
|
|
|
using WarpGemmMfma_f32_16x16x64_bf8_bf8_CTransposed =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>,
|
|
2>>;
|
|
|
|
template <typename A, typename B, WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfma_f32_16x16x128_f8f6f4 = WarpGemmImpl<
|
|
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<A, B>, AttrNumAccess>>;
|
|
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfma_f32_16x16x128_fp8_fp8_CTransposed =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<fp8_t, fp8_t>,
|
|
AttrNumAccess>>;
|
|
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfma_f32_16x16x128_fp8_bf8_CTransposed =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<fp8_t, bf8_t>,
|
|
AttrNumAccess>>;
|
|
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<bf8_t, fp8_t>,
|
|
AttrNumAccess>>;
|
|
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<bf8_t, bf8_t>,
|
|
AttrNumAccess>>;
|
|
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfma_f32_16x16x128_fp4_fp4_CTransposed =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<pk_fp4_t, pk_fp4_t>,
|
|
AttrNumAccess>>;
|
|
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfma_f32_32x32x64_fp8_fp8 = WarpGemmImpl<
|
|
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8<WGAttrCtlEnum::Default_>,
|
|
AttrNumAccess>>;
|
|
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfma_f32_32x32x64_fp8_bf8 = WarpGemmImpl<
|
|
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8<WGAttrCtlEnum::Default_>,
|
|
AttrNumAccess>>;
|
|
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfma_f32_32x32x64_bf8_fp8 = WarpGemmImpl<
|
|
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8<WGAttrCtlEnum::Default_>,
|
|
AttrNumAccess>>;
|
|
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfma_f32_32x32x64_bf8_bf8 = WarpGemmImpl<
|
|
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8<WGAttrCtlEnum::Default_>,
|
|
AttrNumAccess>>;
|
|
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfma_f32_32x32x64_fp8_fp8_CTransposed =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8<WGAttrCtlEnum::Default_>,
|
|
AttrNumAccess>>;
|
|
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfma_f32_32x32x64_fp8_bf8_CTransposed =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8<WGAttrCtlEnum::Default_>,
|
|
AttrNumAccess>>;
|
|
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfma_f32_32x32x64_bf8_fp8_CTransposed =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8<WGAttrCtlEnum::Default_>,
|
|
AttrNumAccess>>;
|
|
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfma_f32_32x32x64_bf8_bf8_CTransposed =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8<WGAttrCtlEnum::Default_>,
|
|
AttrNumAccess>>;
|
|
|
|
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
|
using WarpGemmMfma_f32_32x32x64_fp4_fp4_CTransposed =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImpl_f32_32x32x64_f8f6f4<pk_fp4_t, pk_fp4_t>,
|
|
AttrNumAccess>>;
|
|
|
|
using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
|
|
|
|
using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
|
|
|
|
using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
|
|
|
|
using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
|
|
|
|
template <index_t swizzle_factor = 2>
|
|
using WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
|
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
|
|
2,
|
|
swizzle_factor>>;
|
|
|
|
// int8
|
|
using WarpGemmMfma_i32_32x32x16_i8_i8 = WarpGemmImpl<
|
|
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>>>;
|
|
|
|
using WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>>>;
|
|
|
|
using WarpGemmMfma_i32_16x16x32_i8_i8 = WarpGemmImpl<
|
|
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
|
|
|
|
using WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed =
|
|
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
|
WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
|
|
|
|
} // namespace ck_tile
|