From 0c61d0da8da21dd1d2a4683fdd8babc8f405627e Mon Sep 17 00:00:00 2001 From: Yi DING Date: Tue, 21 Oct 2025 10:15:04 +0800 Subject: [PATCH] [CK_TILE] Add fmt: skip to FMHA codegen scripts for readability (#3057) * fmt: skip for fmha_bwd.py * more fmt: skip * thank you, copilot * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> [ROCm/composable_kernel commit: e20923f384492dab3dafdbace6f2bd2b45186cc2] --- .../01_fmha/codegen/ops/fmha_batch_prefill.py | 92 +- .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 372 +------ .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 909 +----------------- .../01_fmha/codegen/ops/fmha_fwd_appendkv.py | 40 +- .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 187 +--- .../codegen/ops/fmha_pagedkv_prefill.py | 105 +- 6 files changed, 111 insertions(+), 1594 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 3b26e3ab5f..2e3f96e4a6 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -575,30 +575,8 @@ class KernelComponentFactory: def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: if dtype == "fp16" or dtype == "bf16": return { - 128: [ - FmhaFwdTileSize( - 128, - 128, - 32, - 128, - 32, - 128, - 4, - 1, - 1, - 4, - 1, - 1, - 32, - 32, - 16, - 32, - 32, - 16, - -1, - ) - ], - } + 128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + } # fmt: skip else: return None @@ -618,40 +596,10 @@ class KernelComponentFactory: ["t", "f"], ["t", "f"], ): - pipelines.append( - FmhaFwdPipeline( - "qr_async", - "row", - "t", - "f", - "t", - "t", - logits, - bias, - lse, - dropout, - squant, - mask, - ) - ) - pipelines.append( - FmhaFwdPipeline( - "qr_async", - "row", - "t", - "t", - "t", - "t", - logits, - bias, - lse, - dropout, - squant, - mask, - ) - ) - # pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) - # pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip + # pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip + # pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip else: assert False return pipelines @@ -663,33 +611,7 @@ class CustomFactory(KernelComponentFactory): result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) if dtype == "fp16" or dtype == "bf16": if 128 in result.keys(): - result[128].insert( - 0, - FmhaFwdTileSize( - 64, - 128, - 64, - 128, - 64, - 128, - 4, - 1, - 1, - 4, - 1, - 1, - 16, - 16, - 16, - 16, - 16, - 16, - -1, - CppConstraint( - "get_num_blocks(128) < num_cus * min_cu_util_rate" - ), - ), - ) + result[128].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip return result diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 19f5bb2288..d007b4caa3 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -408,369 +408,29 @@ def get_dq_dk_dv_tiles(dtype: str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]: if dtype == "fp32" and tr_load == "f": return [ # bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv, - FmhaBwdDQDKDVTileSize( - 32, - 128, - 32, - 32, - 32, - 32, - 64, - 32, - 32, - 1, - 4, - 1, - 4, - 1, - 1, - 2, - 2, - 1, - 16, - 16, - 16, - 16, - 16, - 16, - 1, - ), - FmhaBwdDQDKDVTileSize( - 16, - 64, - 64, - 16, - 64, - 16, - 16, - 64, - 64, - 1, - 4, - 1, - 4, - 1, - 1, - 1, - 4, - 1, - 16, - 16, - 16, - 16, - 16, - 16, - 1, - ), - FmhaBwdDQDKDVTileSize( - 16, - 64, - 128, - 16, - 128, - 16, - 16, - 128, - 128, - 1, - 4, - 1, - 4, - 1, - 1, - 1, - 4, - 1, - 16, - 16, - 16, - 16, - 16, - 16, - 1, - ), - ] + FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 64, 64, 16, 64, 16, 16, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 16, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1), + ] # fmt: skip elif (dtype == "fp16" or dtype == "bf16") and tr_load == "f": return [ - FmhaBwdDQDKDVTileSize( - 32, - 128, - 32, - 32, - 32, - 32, - 64, - 32, - 32, - 1, - 4, - 1, - 4, - 1, - 1, - 2, - 2, - 1, - 16, - 16, - 32, - 16, - 16, - 16, - 1, - ), - FmhaBwdDQDKDVTileSize( - 32, - 128, - 64, - 32, - 64, - 32, - 32, - 64, - 64, - 1, - 4, - 1, - 4, - 1, - 1, - 1, - 4, - 1, - 16, - 16, - 32, - 16, - 16, - 16, - 1, - ), - FmhaBwdDQDKDVTileSize( - 32, - 128, - 96, - 32, - 96, - 32, - 32, - 96, - 96, - 1, - 4, - 1, - 4, - 1, - 1, - 2, - 2, - 1, - 16, - 16, - 32, - 16, - 16, - 16, - 1, - ), - FmhaBwdDQDKDVTileSize( - 16, - 128, - 128, - 16, - 128, - 16, - 32, - 128, - 128, - 1, - 4, - 1, - 4, - 1, - 1, - 1, - 4, - 1, - 16, - 16, - 32, - 16, - 16, - 16, - 1, - ), + FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 32, 128, 96, 32, 96, 32, 32, 96, 96, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), # FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), - FmhaBwdDQDKDVTileSize( - 16, - 64, - 256, - 16, - 256, - 16, - 32, - 256, - 256, - 1, - 4, - 1, - 4, - 1, - 1, - 1, - 4, - 1, - 16, - 16, - 32, - 16, - 16, - 16, - 1, - ), - ] + FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + ] # fmt: skip elif (dtype == "fp16" or dtype == "bf16") and tr_load == "t": return [ - FmhaBwdDQDKDVTileSize( - 32, - 128, - 64, - 32, - 64, - 32, - 32, - 64, - 64, - 1, - 4, - 1, - 4, - 1, - 1, - 1, - 4, - 1, - 16, - 16, - 32, - 16, - 16, - 32, - 1, - ), - FmhaBwdDQDKDVTileSize( - 32, - 128, - 128, - 32, - 128, - 32, - 32, - 128, - 128, - 1, - 4, - 1, - 4, - 1, - 1, - 1, - 4, - 1, - 16, - 16, - 32, - 16, - 16, - 32, - 1, - ), - FmhaBwdDQDKDVTileSize( - 16, - 192, - 128, - 16, - 128, - 16, - 32, - 128, - 128, - 1, - 4, - 1, - 4, - 1, - 1, - 1, - 4, - 1, - 16, - 16, - 32, - 16, - 16, - 16, - 1, - ), + FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1), + FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1), + FmhaBwdDQDKDVTileSize( 16, 192, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), # FmhaBwdDQDKDVTileSize( 32, 32, 64, 32, 64, 32, 32, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, 1, 32), - FmhaBwdDQDKDVTileSize( - 32, - 16, - 64, - 32, - 64, - 32, - 16, - 64, - 64, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 16, - 16, - 32, - 16, - 16, - 16, - 2, - 32, - ), + FmhaBwdDQDKDVTileSize( 32, 16, 64, 32, 64, 32, 16, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 32), # FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16), - FmhaBwdDQDKDVTileSize( - 16, - 16, - 128, - 16, - 128, - 16, - 16, - 128, - 128, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 16, - 16, - 32, - 16, - 16, - 16, - 2, - 16, - ), - ] + FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16), + ] # fmt: skip else: return [] diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index cc77718c88..e5254034af 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -635,578 +635,42 @@ class KernelComponentFactory: if dtype == "fp32": return { # bm0, bn0, bk0, bn1, bk1, - (32, 32): [ - FmhaFwdTileSize( - 64, - 64, - 16, - 32, - 32, - 32, - 4, - 1, - 1, - 4, - 1, - 1, - 16, - 16, - 16, - 16, - 16, - 16, - -1, - ) - ], - (48, 48): [ - FmhaFwdTileSize( - 32, - 128, - 16, - 48, - 16, - 48, - 2, - 1, - 1, - 2, - 1, - 1, - 16, - 16, - 16, - 16, - 16, - 16, - -1, - ), - FmhaFwdTileSize( - 128, - 64, - 16, - 48, - 32, - 48, - 4, - 1, - 1, - 4, - 1, - 1, - 16, - 16, - 16, - 16, - 16, - 16, - -1, - ), - ], - (64, 64): [ - FmhaFwdTileSize( - 64, - 64, - 32, - 64, - 32, - 64, - 4, - 1, - 1, - 4, - 1, - 1, - 16, - 16, - 16, - 16, - 16, - 16, - -1, - ) - ], - (96, 128): [ - FmhaFwdTileSize( - 128, - 64, - 32, - 128, - 32, - 96, - 4, - 1, - 1, - 4, - 1, - 1, - 16, - 16, - 16, - 16, - 16, - 16, - -1, - ) - ], - (128, 128): [ - FmhaFwdTileSize( - 32, - 128, - 32, - 128, - 16, - 128, - 2, - 1, - 1, - 2, - 1, - 1, - 16, - 16, - 16, - 16, - 16, - 16, - -1, - ), - FmhaFwdTileSize( - 128, - 64, - 32, - 128, - 32, - 128, - 4, - 1, - 1, - 4, - 1, - 1, - 16, - 16, - 16, - 16, - 16, - 16, - -1, - ), - ], - (192, 192): [ - FmhaFwdTileSize( - 64, - 64, - 32, - 192, - 32, - 192, - 4, - 1, - 1, - 4, - 1, - 1, - 16, - 16, - 16, - 16, - 16, - 16, - -1, - ) - ], - (256, 256): [ - FmhaFwdTileSize( - 64, - 64, - 32, - 256, - 32, - 256, - 4, - 1, - 1, - 4, - 1, - 1, - 16, - 16, - 16, - 16, - 16, - 16, - -1, - ) - ], - } + ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 48, 48) : [FmhaFwdTileSize( 32, 128, 16, 48, 16, 48, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize(128, 64, 16, 48, 32, 48, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 96, 128) : [FmhaFwdTileSize(128, 64, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (128, 128) : [FmhaFwdTileSize( 32, 128, 32, 128, 16, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (192, 192) : [FmhaFwdTileSize( 64, 64, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + } # fmt: skip elif dtype == "fp16" or dtype == "bf16": return { - (32, 32): [ - FmhaFwdTileSize( - 128, - 64, - 16, - 32, - 32, - 32, - 4, - 1, - 1, - 4, - 1, - 1, - 32, - 32, - 16, - 32, - 32, - 16, - -1, - ) - ], - (64, 64): [ - FmhaFwdTileSize( - 16, - 32, - 64, - 64, - 32, - 64, - 1, - 1, - 1, - 1, - 1, - 1, - 16, - 16, - 32, - 16, - 16, - 32, - -1, - ), - FmhaFwdTileSize( - 32, - 32, - 64, - 64, - 32, - 64, - 1, - 1, - 1, - 1, - 1, - 1, - 32, - 32, - 16, - 32, - 32, - 16, - -1, - ), - FmhaFwdTileSize( - 128, - 64, - 32, - 64, - 32, - 64, - 4, - 1, - 1, - 4, - 1, - 1, - 32, - 32, - 16, - 32, - 32, - 16, - -1, - ), - ], - (96, 128): [ - FmhaFwdTileSize( - 128, - 128, - 32, - 128, - 32, - 96, - 4, - 1, - 1, - 4, - 1, - 1, - 32, - 32, - 16, - 32, - 32, - 16, - -1, - ) - ], - (128, 128): [ - FmhaFwdTileSize( - 16, - 32, - 64, - 128, - 32, - 128, - 1, - 1, - 1, - 1, - 1, - 1, - 16, - 16, - 32, - 16, - 16, - 32, - -1, - ), - FmhaFwdTileSize( - 32, - 32, - 128, - 128, - 32, - 128, - 1, - 1, - 1, - 1, - 1, - 1, - 32, - 32, - 16, - 32, - 32, - 16, - -1, - ), - FmhaFwdTileSize( - 128, - 64, - 32, - 128, - 16, - 128, - 4, - 1, - 1, - 4, - 1, - 1, - 32, - 32, - 16, - 32, - 32, - 16, - -1, - ), - FmhaFwdTileSize( - 128, - 128, - 32, - 128, - 32, - 128, - 4, - 1, - 1, - 4, - 1, - 1, - 32, - 32, - 16, - 32, - 32, - 16, - -1, - ), - ], - # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], - (192, 128): [ - FmhaFwdTileSize( - 128, - 128, - 32, - 128, - 32, - 192, - 4, - 1, - 1, - 4, - 1, - 1, - 32, - 32, - 16, - 32, - 32, - 16, - -1, - ) - ], - (192, 192): [ - FmhaFwdTileSize( - 128, - 128, - 32, - 192, - 32, - 192, - 4, - 1, - 1, - 4, - 1, - 1, - 32, - 32, - 16, - 32, - 32, - 16, - 1, - ) - ], - (256, 256): [ - FmhaFwdTileSize( - 128, - 128, - 32, - 256, - 32, - 256, - 4, - 1, - 1, - 4, - 1, - 1, - 32, - 32, - 16, - 32, - 32, - 16, - -1, - ) - ], - } + ( 32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + ( 64, 64) : [FmhaFwdTileSize( 16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + FmhaFwdTileSize( 32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + ( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + (192, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (192, 192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + (256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + } # fmt: skip elif dtype == "fp8" or dtype == "fp8bf16": return { - (64, 64): [ - FmhaFwdTileSize( - 128, - 64, - 32, - 64, - 32, - 64, - 2, - 1, - 1, - 2, - 1, - 1, - 32, - 32, - 32, - 32, - 32, - 32, - -1, - ) - ], - (128, 128): [ - FmhaFwdTileSize( - 128, - 128, - 32, - 128, - 32, - 128, - 4, - 1, - 1, - 4, - 1, - 1, - 32, - 32, - 32, - 32, - 32, - 32, - -1, - ) - ], - (256, 256): [ - FmhaFwdTileSize( - 128, - 128, - 32, - 256, - 32, - 256, - 4, - 1, - 1, - 4, - 1, - 1, - 32, - 32, - 32, - 32, - 32, - 32, - -1, - ) - ], - } + ( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + (128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + (256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + } # fmt: skip elif dtype == "fp8fp32": return { - (128, 128): [ - FmhaFwdTileSize( - 128, - 128, - 32, - 128, - 32, - 128, - 4, - 1, - 1, - 4, - 1, - 1, - 32, - 32, - 32, - 32, - 32, - 32, - -1, - ) - ], - } + (128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + } # fmt: skip else: return None @@ -1229,60 +693,9 @@ class KernelComponentFactory: ["t", "f"], ["t", "f"], ): - pipelines.append( - FmhaFwdPipeline( - "qr", - "row", - "f", - "f", - "f", - "f", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "f", - ) - ) - pipelines.append( - FmhaFwdPipeline( - "qr", - "row", - "f", - "t", - "f", - "f", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "f", - ) - ) - pipelines.append( - FmhaFwdPipeline( - "qr", - "row", - "t", - "t", - "t", - "t", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "f", - ) - ) + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip elif dtype in ["fp16", "bf16"]: squant = "f" for logits, mask, bias, lse, dropout, skip in itertools.product( @@ -1294,137 +707,18 @@ class KernelComponentFactory: ["t", "f"], ): if hdim == 256 and hdim_v == 256: - pipelines.append( - FmhaFwdPipeline( - "qr", - "row", - "f", - "f", - "f", - "f", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "f", - ) - ) + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip # the below two is used for hdim vectorize load - pipelines.append( - FmhaFwdPipeline( - "qr", - "row", - "t", - "t", - "f", - "f", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "f", - ) - ) - pipelines.append( - FmhaFwdPipeline( - "qr", - "row", - "t", - "t", - "t", - "t", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "f", - ) - ) + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip else: if bias == "bias": # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append( - FmhaFwdPipeline( - "qr", - "row", - "f", - "f", - "f", - "f", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "f", - ) - ) - pipelines.append( - FmhaFwdPipeline( - "qr", - "row", - "t", - "t", - "t", - "t", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "f", - ) - ) + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip else: - pipelines.append( - FmhaFwdPipeline( - "qr_async", - "row", - "t", - "f", - "t", - "t", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "f", - ) - ) - pipelines.append( - FmhaFwdPipeline( - "qr_async", - "row", - "t", - "t", - "t", - "t", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "f", - ) - ) + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip if ( (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" @@ -1433,103 +727,18 @@ class KernelComponentFactory: and lse == "f" and skip == "f" ): - pipelines.append( - FmhaFwdPipeline( - "qr_async_trload", - "row", - "f", - "f", - "f", - "f", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "t", - ) - ) - pipelines.append( - FmhaFwdPipeline( - "qr_async_trload", - "row", - "f", - "f", - "t", - "t", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "t", - ) - ) + pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip if receipt == 1 and bias != "bias": - pipelines.append( - FmhaFwdPipeline( - "qr", - "row", - "t", - "t", - "t", - "t", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "f", - ) - ) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline( "qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim elif dtype in ["fp8", "fp8bf16", "fp8fp32"]: # no need lse/dropout kernels for logits, squant, mask, bias in itertools.product( ["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() ): - pipelines.append( - FmhaFwdPipeline( - "qr", - "row", - "f", - "f", - "f", - "f", - logits, - bias, - "f", - "f", - squant, - mask, - "f", - "f", - ) - ) - pipelines.append( - FmhaFwdPipeline( - "qr", - "row", - "t", - "t", - "f", - "f", - logits, - bias, - "f", - "f", - squant, - mask, - "f", - "f", - ) - ) + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip elif dtype in ["fp8fp16", "bf8"]: # TODO None @@ -1544,33 +753,7 @@ class CustomFactory(KernelComponentFactory): result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) if dtype == "fp16" or dtype == "bf16": if (128, 128) in result.keys(): - result[(128, 128)].insert( - 0, - FmhaFwdTileSize( - 64, - 128, - 64, - 128, - 64, - 128, - 4, - 1, - 1, - 4, - 1, - 1, - 16, - 16, - 16, - 16, - 16, - 16, - -1, - CppConstraint( - "get_num_blocks(128) < num_cus * min_cu_util_rate" - ), - ), - ) + result[(128, 128)].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip return result diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 9e107062e1..fcbf22fb18 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -349,43 +349,17 @@ def get_fwd_appendkv_blobs( # applying rotary embedding, so I just use 't' in inter/half pipelines for vlayout in ["row", "col"]: for pagedkv in ["t", "f"]: - pipelines.append( - FmhaFwdAppendKVPipeline( - vlayout, "f", "t", "f", "f", "no", pagedkv - ) - ) - pipelines.append( - FmhaFwdAppendKVPipeline( - vlayout, "t", "t", "t", "t", "no", pagedkv - ) - ) + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "f", "f", "no", pagedkv)) # fmt: skip + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "no", pagedkv)) # fmt: skip - pipelines.append( - FmhaFwdAppendKVPipeline( - vlayout, "f", "t", "t", "f", "inter", pagedkv - ) - ) - pipelines.append( - FmhaFwdAppendKVPipeline( - vlayout, "t", "t", "t", "t", "inter", pagedkv - ) - ) + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "t", "f", "inter", pagedkv)) # fmt: skip + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "inter", pagedkv)) # fmt: skip - pipelines.append( - FmhaFwdAppendKVPipeline( - vlayout, "f", "t", "t", "f", "half", pagedkv - ) - ) - pipelines.append( - FmhaFwdAppendKVPipeline( - vlayout, "t", "t", "t", "t", "half", pagedkv - ) - ) + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "t", "f", "half", pagedkv)) # fmt: skip + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "half", pagedkv)) # fmt: skip elif dtype in ["fp8", "bf8"]: # rope/paged-kv is not supported - pipelines.append( - FmhaFwdAppendKVPipeline("col", "t", "t", "t", "t", "no", "f") - ) + pipelines.append(FmhaFwdAppendKVPipeline("col", "t", "t", "t", "t", "no", "f")) # fmt: skip elif dtype in ["fp8fp16", "fp8bf16"]: # TODO None diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 9a77bc8e94..31a35ecb97 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -738,32 +738,18 @@ class FmhaFwdSplitKVCombineKernel: def get_fmha_fwd_tile_dict_from_dtype(dtype: str) -> Optional[dict]: if dtype == "fp16" or dtype == "bf16": return { - "32": FmhaFwdTileSize( - 32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1 - ), - "64": FmhaFwdTileSize( - 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1 - ), - "96": FmhaFwdTileSize( - 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1 - ), - "128": FmhaFwdTileSize( - 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1 - ), - # '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - "256": FmhaFwdTileSize( - 64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1 - ), - } + "32" : FmhaFwdTileSize( 32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "96" : FmhaFwdTileSize( 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "128": FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + # "160" : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "256": FmhaFwdTileSize( 64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + } # fmt: skip elif dtype == "fp8" or dtype == "bf8": return { - "64": FmhaFwdTileSize( - 128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1 - ), - "128": FmhaFwdTileSize( - 128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1 - ), - } + "64" : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), + "128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + } # fmt: skip else: return None @@ -807,157 +793,22 @@ def get_fwd_splitkv_blobs( for logits, mask, bias, pagedkv in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"] ): - pipelines.append( - Pipeline( - "qr", - "row", - "f", - "t", - "f", - "f", - logits, - bias, - "t", - squant, - pagedkv, - mask, - ) - ) - pipelines.append( - Pipeline( - "qr", - "col", - "f", - "t", - "f", - "f", - logits, - bias, - "t", - squant, - pagedkv, - mask, - ) - ) + pipelines.append(Pipeline( "qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip + pipelines.append(Pipeline( "qr", "col", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip - pipelines.append( - Pipeline( - "qr", - "row", - "t", - "f", - "f", - "f", - logits, - bias, - "t", - squant, - pagedkv, - mask, - ) - ) - pipelines.append( - Pipeline( - "qr", - "col", - "t", - "f", - "f", - "f", - logits, - bias, - "t", - squant, - pagedkv, - mask, - ) - ) + pipelines.append(Pipeline( "qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip + pipelines.append(Pipeline( "qr", "col", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip - pipelines.append( - Pipeline( - "qr", - "row", - "t", - "t", - "f", - "f", - logits, - bias, - "t", - squant, - pagedkv, - mask, - ) - ) - pipelines.append( - Pipeline( - "qr", - "col", - "t", - "t", - "f", - "f", - logits, - bias, - "t", - squant, - pagedkv, - mask, - ) - ) + pipelines.append(Pipeline( "qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip + pipelines.append(Pipeline( "qr", "col", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip - pipelines.append( - Pipeline( - "qr", - "row", - "t", - "t", - "t", - "t", - logits, - bias, - "t", - squant, - pagedkv, - mask, - ) - ) - pipelines.append( - Pipeline( - "qr", - "col", - "t", - "t", - "t", - "t", - logits, - bias, - "t", - squant, - pagedkv, - mask, - ) - ) + pipelines.append(Pipeline( "qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip + pipelines.append(Pipeline( "qr", "col", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip elif dtype in ["fp8", "bf8"]: for logits, mask, bias in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() ): - pipelines.append( - Pipeline( - "qr", - "col", - "f", - "f", - "f", - "f", - logits, - bias, - "t", - squant, - "f", - mask, - ) - ) + pipelines.append(Pipeline( "qr", "col", "f", "f", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip elif dtype in ["fp8fp16", "fp8bf16"]: # TODO None diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index 55b0160a71..f22b0fa52f 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -524,27 +524,19 @@ class FmhaFwdKernel: def get_fmha_fwd_tile_dict_from_dtype(dtype: str) -> Optional[dict]: if dtype == "fp16" or dtype == "bf16": return { - # '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), - # '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - ### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - "128": FmhaFwdTileSize( - 128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1 - ), - # '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - # '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - } + # "32": FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # "64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # "96": FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + "128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # "192": FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # "256": FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + } # fmt: skip elif dtype == "fp8" or dtype == "bf8": return { - "64": FmhaFwdTileSize( - 128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1 - ), - "128": FmhaFwdTileSize( - 128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1 - ), - "256": FmhaFwdTileSize( - 128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1 - ), - } + "64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), + "128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + "256": FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + } # fmt: skip else: return None @@ -569,82 +561,17 @@ def get_fwd_blobs( ["t"], ["f"], ): - pipelines.append( - FmhaFwdPipeline( - "qr_pagedkv", - "row", - "t", - "f", - "f", - "f", - logits, - bias, - "f", - pagedkv, - squant, - mask, - skip, - ) - ) - pipelines.append( - FmhaFwdPipeline( - "qr_pagedkv", - "row", - "t", - "t", - "f", - "f", - logits, - bias, - "f", - pagedkv, - squant, - mask, - skip, - ) - ) + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip elif dtype in ["fp8", "bf8"]: # no need lse/dropout kernels for logits, mask, bias in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() ): - pipelines.append( - FmhaFwdPipeline( - "qr_pagedkv", - "row", - "f", - "f", - "f", - "f", - logits, - bias, - "f", - "t", - squant, - mask, - "f", - ) - ) - pipelines.append( - FmhaFwdPipeline( - "qr_pagedkv", - "row", - "t", - "t", - "f", - "f", - logits, - bias, - "f", - "t", - squant, - mask, - "f", - ) - ) + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip elif dtype in ["fp8fp16", "fp8bf16"]: - # TODO - None + pass # TODO else: assert False return pipelines