From 228b1e8d8727544ab465ce3fbcaf96099eb77377 Mon Sep 17 00:00:00 2001 From: rocking Date: Thu, 4 Dec 2025 12:18:25 +0800 Subject: [PATCH] fp8 fmha async pipeline (#3339) * replace qr with async pipeline * Add fp8fp32 to DTYPE_BITS * Add kAlignmentRandVal to avoid compile fail * format --------- Co-authored-by: Thomas Ning [ROCm/composable_kernel commit: eb7f6177136173c8a6af539bffd915fddff293c4] --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 23 +++++++++++++------ ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 2 ++ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 360d6a7c78..17d4f6e1d7 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -29,7 +29,15 @@ from codegen.cpp_symbol_map import ( from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file -DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8} +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8": 8, + "fp8bf16": 8, + "fp8fp32": 8, + "bf8": 8, +} K0_MAX_SUBMAX_MAP = {32: 32, 48: 48, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256} @@ -678,6 +686,7 @@ class KernelComponentFactoryGfx9: return { ( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + (192, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], } # fmt: skip elif dtype in ["fp8fp32"]: @@ -742,8 +751,8 @@ class KernelComponentFactoryGfx9: get_mask_map(mask_impl).keys(), ["no"], ): - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip elif dtype in ["fp8", "fp8fp16", "bf8"]: # TODO None @@ -958,7 +967,7 @@ def get_fwd_blobs( cond &= mode == "batch" cond &= pipeline.F_vlayout == "row" if dtype == "fp8bf16": - cond &= hdim == 128 or hdim == 256 + cond &= hdim == 128 or hdim == 192 if not cond: continue # Aiter(mha_varlen_fwd) integration @@ -967,7 +976,7 @@ def get_fwd_blobs( cond &= mode == "group" cond &= pipeline.F_vlayout == "row" if dtype == "fp8bf16": - cond &= hdim == 128 or hdim == 256 + cond &= hdim == 128 or hdim == 192 if not cond: continue # aiter::mha_fwd C++ api integration @@ -975,13 +984,13 @@ def get_fwd_blobs( cond = dtype in ["fp16", "bf16", "fp8bf16"] cond &= pipeline.F_vlayout == "row" if dtype == "fp8bf16": - cond &= hdim == 128 or hdim == 256 + cond &= hdim == 128 or hdim == 192 if not cond: continue elif receipt == 888: cond = dtype in ["fp8bf16", "fp8fp32"] cond &= pipeline.F_vlayout == "row" - cond &= hdim == 128 or hdim == 256 + cond &= hdim == 128 or hdim == 192 if not cond: continue diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 27776453f6..2102fe768f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -87,6 +87,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr index_t kAlignmentO = Policy::template GetAlignmentO(); static constexpr index_t kAlignmentBias = kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + static constexpr index_t kAlignmentRandVal = + kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal(); #if CK_TILE_FMHA_FWD_FAST_EXP2 static constexpr auto R_LOG2E = 1.0 / log2e_v;