From 92d989e35d0221e5cabface34cf2acf8812d8609 Mon Sep 17 00:00:00 2001 From: Andres Lugo <108368282+alugorey@users.noreply.github.com> Date: Tue, 11 Feb 2025 12:11:46 -0600 Subject: [PATCH] Add receipt 4 option to codegen (#1875) * Add receipt 4 option to codegen * Remove repeated code * Review comments [ROCm/composable_kernel commit: 8086bbe3a78d931eb96fe12fdc014082e18d18d3] --- example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 10 +++++++++- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 9 ++++++++- example/ck_tile/01_fmha/generate.py | 3 ++- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 83a1e82d6d..c05660c8ab 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -506,6 +506,14 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= deterministic == "f" if not cond: continue + if receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'bias'] + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dpad == dvpad + cond &= deterministic == "f" + if not cond: + continue api_pool.register_dq_dk_dv_traits(k.api_trait()) gen.append(k) @@ -801,4 +809,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im _, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") \ No newline at end of file + f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 1c9d743f3d..ad8daba17e 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -487,13 +487,20 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm if kernel_filter != None: if not fnmatch.fnmatch(k.name, kernel_filter): continue - if receipt == 2: + if receipt in (2, 3): cond = dtype in ['fp16', 'bf16'] cond &= pipeline.F_vlayout == 'row' cond &= pipeline.F_bias in ['no', 'alibi'] cond &= pipeline.F_squant == 'f' if not cond: continue + if receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'bias'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 5b1b6664cc..a0fb42aa11 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -103,7 +103,8 @@ if __name__ == "__main__": required=False, help="codegen receipt. 0: generate only 8xhdim coverage\n" + \ " 1: generate more instance to cover all hdim\n" + \ - " 2: Only generate instance for Flash attention integration" + " 2: Only generate instance for Flash attention integration\n" + \ + " 4: Only generate instance for PyTorch integration" ) args = parser.parse_args()