From 86591de476dbf8d16a3dce849fda0d2887220ffa Mon Sep 17 00:00:00 2001 From: Vidyasagar Ananthan Date: Sun, 17 May 2026 07:30:33 +0000 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#5260 (commit a1834d2) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [CK] [CK_Tile] Add FMHA scaffolding to CK kernel dispatcher (#5260) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation The CK Tile dispatcher currently supports GEMM and Grouped Convolution but has no support for Fused Multi-Head Attention (FMHA). The example/ck_tile/01_fmha folder contains a comprehensive FMHA implementation with forward, backward, split-KV, paged-KV, append-KV, and batch-prefill kernels across multiple GPU architectures — but there is no unified dispatch layer for it. This PR ports the FMHA stack into the dispatcher, following the same architectural patterns established by GEMM and Grouped Convolution, enabling runtime kernel selection, JIT compilation from Python, and a declarative C++ example flow. Autotuning heuristics to follow. ## Technical Details This PR adds FMHA scaffolding to the CK dispatcher framework, mirroring GEMM's layered architecture. Seven new C++ runtime headers provide type definitions (coexisting with upstream headers via __has_include, requiring zero modifications to example/ck_tile/01_fmha/), a problem builder with 18+ setters, Signature + Algorithm kernel key matching, a virtual kernel instance, a DECL_FMHA_KERNEL_SET macro with wildcard support and named tile/wave/warp setters, arch-aware registry with JSON export, and a dispatcher with seqtune-aware selection, configurable timing, and multi-stage execution plans for split-KV (two-stage) and backward (three-stage). The codegen pipeline is driven by a fmha_arch_specs.json capturing per-arch tile tables and pipeline constraints for five architectures (gfx90a/942/950/1100/1201), migrated from hardcoded logic in 01_fmha/codegen/, with supporting modules for C++ symbol mappings, validation rules, and named receipt profiles (ck_default, flash, pytorch, aiter, fp32, fp8). Python integration (fmha_utils.py) mirrors the C++ layer with JIT compilation, parallel multi-kernel builds, HIP memory management via ctypes, tolerance-based validation, and a NumPy CPU reference with GQA support. Twenty-seven C++ and thirty-two Python examples cover the full feature surface — forward, split-KV, masks, bias, dropout, GQA, backward, append-KV, batch prefill, fp8, logits soft cap, sink tokens, and parameter sweeps — all JIT-compiled on the fly. ## Test Plan Seven test files cover the runtime types, codegen, and end-to-end correctness. C++ unit tests validate the problem builder, dispatcher planning (single-stage for forward/paged-KV/append-KV; multi-stage for split-KV and backward), registry operations, and the kernel-set declaration macro. Python unit tests verify codegen emission, profile filtering, and 15 validation rules for masks, hdim constraints, and pipeline requirements. GPU execution validation in 01_basic_fmha --validate reports zero errors across 65,536 elements with max absolute error of 7.29e-05. A gold-standard parity suite (test_fmha_parity.py) runs 14 configurations through both the upstream tile_example_fmha_fwd and the dispatcher, comparing exit codes to confirm behavioral parity — all 14 match. ## Test Result The C++ smoke test builds and passes all 9 compiled examples, and a Python JIT sweep (29_sweep_seqlen.py) passes 7/7 configurations reaching up to 375 TFLOPS at seqlen 2048. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- dispatcher/CMakeLists.txt | 12 +- dispatcher/README.md | 38 +- dispatcher/bindings/README.md | 1 + .../bindings/ctypes/fmha_ctypes_lib.cpp | 1685 +++++++++++ dispatcher/codegen/arch_filter.py | 2 +- dispatcher/codegen/codegen_common.py | 2 +- dispatcher/codegen/fmha/__init__.py | 4 + dispatcher/codegen/fmha/codegen.py | 1385 +++++++++ dispatcher/codegen/fmha/fmha_arch_specs.json | 175 ++ dispatcher/codegen/fmha/generate_fallback.py | 261 ++ dispatcher/codegen/fmha/instance_gen.py | 2692 +++++++++++++++++ dispatcher/codegen/fmha/symbol_map.py | 333 ++ dispatcher/codegen/fmha/validation.py | 921 ++++++ dispatcher/examples/CMakeLists.txt | 102 +- dispatcher/examples/README.md | 10 +- .../examples/fmha/cpp/01_basic_fmha.cpp | 371 +++ .../examples/fmha/cpp/02_splitkv_fmha.cpp | 162 + .../examples/fmha/cpp/03_kvcache_fmha.cpp | 240 ++ dispatcher/examples/fmha/cpp/04_bwd_fmha.cpp | 154 + .../examples/fmha/cpp/05_appendkv_fmha.cpp | 106 + .../fmha/cpp/06_batch_prefill_fmha.cpp | 133 + .../fmha/cpp/07_profile_pytorch_fmha.cpp | 248 ++ .../fmha/cpp/08_profile_flash_fmha.cpp | 165 + .../fmha/cpp/09_profile_aiter_fmha.cpp | 212 ++ .../fmha/cpp/10_profile_fp32_fp8_fmha.cpp | 152 + .../fmha/cpp/11_receipt_aliases_fmha.cpp | 176 ++ .../fmha/cpp/12_registry_json_fmha.cpp | 129 + .../fmha/cpp/13_feature_coverage_fmha.cpp | 499 +++ .../fmha/cpp/14_benchmark_validation_fmha.cpp | 404 +++ .../examples/fmha/cpp/15_multi_shape_fmha.cpp | 282 ++ .../examples/fmha/cpp/16_heuristics_fmha.cpp | 428 +++ .../fmha/cpp/17_autofill_autocorrect_fmha.cpp | 423 +++ .../examples/fmha/cpp/18_gpu_splitkv_fmha.cpp | 466 +++ .../examples/fmha/cpp/19_gpu_masks_fmha.cpp | 456 +++ .../examples/fmha/cpp/20_gpu_bias_fmha.cpp | 584 ++++ .../fmha/cpp/21_gpu_features_fmha.cpp | 697 +++++ .../examples/fmha/cpp/22_gpu_bwd_fmha.cpp | 553 ++++ .../fmha/cpp/23_multi_registry_fmha.cpp | 595 ++++ .../cpp/24_per_receipt_registries_fmha.cpp | 549 ++++ .../cpp/25_gpu_appendkv_batchprefill_fmha.cpp | 530 ++++ .../fmha/cpp/26_dtypes_hdims_fmha.cpp | 526 ++++ .../fmha/cpp/27_padding_permutation_fmha.cpp | 635 ++++ .../examples/fmha/cpp/28_bwd_masks_fmha.cpp | 489 +++ .../fmha/cpp/29_bwd_bias_dropout_fmha.cpp | 615 ++++ .../fmha/cpp/30_bwd_benchmark_fmha.cpp | 449 +++ .../fmha/cpp/31_logits_soft_cap_fmha.cpp | 118 + .../examples/fmha/cpp/32_sink_tokens_fmha.cpp | 119 + .../fmha/cpp/33_bwd_deterministic_fmha.cpp | 256 ++ .../examples/fmha/cpp/34_bwd_gqa_fmha.cpp | 183 ++ .../fmha/cpp/35_generic_mask_fmha.cpp | 121 + .../examples/fmha/python/01_basic_fmha.py | 259 ++ .../examples/fmha/python/02_multi_shape.py | 148 + .../examples/fmha/python/03_benchmark.py | 170 ++ .../examples/fmha/python/04_validation.py | 176 ++ .../fmha/python/05_numpy_integration.py | 219 ++ .../examples/fmha/python/06_json_export.py | 220 ++ .../examples/fmha/python/07_stress_test.py | 256 ++ .../examples/fmha/python/08_heuristics.py | 348 +++ .../examples/fmha/python/09_multi_registry.py | 298 ++ .../fmha/python/10_advanced_benchmark.py | 262 ++ .../examples/fmha/python/11_bf16_fmha.py | 188 ++ .../examples/fmha/python/12_masks_fmha.py | 239 ++ .../examples/fmha/python/13_bias_fmha.py | 235 ++ .../examples/fmha/python/14_dropout_fmha.py | 245 ++ .../examples/fmha/python/15_gqa_fmha.py | 217 ++ .../examples/fmha/python/16_splitkv_fmha.py | 267 ++ .../examples/fmha/python/17_appendkv_fmha.py | 362 +++ .../examples/fmha/python/18_backward_fmha.py | 299 ++ .../examples/fmha/python/19_padding_fmha.py | 344 +++ .../examples/fmha/python/20_fp8_fmha.py | 120 + .../fmha/python/21_logits_soft_cap_fmha.py | 235 ++ .../fmha/python/22_sink_tokens_fmha.py | 315 ++ .../fmha/python/23_batch_prefill_fmha.py | 406 +++ .../fmha/python/24_vlayout_col_fmha.py | 250 ++ .../fmha/python/25_permutation_fmha.py | 262 ++ .../fmha/python/26_hdim_variety_fmha.py | 268 ++ .../fmha/python/27_backward_dropout_fmha.py | 373 +++ .../fmha/python/28_backward_dbias_fmha.py | 360 +++ .../examples/fmha/python/29_sweep_seqlen.py | 146 + .../examples/fmha/python/30_sweep_batch.py | 151 + .../examples/fmha/python/31_sweep_nhead.py | 172 ++ .../examples/fmha/python/32_sweep_hdim.py | 178 ++ .../examples/fmha/python/33_bwd_masks_fmha.py | 271 ++ .../examples/fmha/python/34_bwd_gqa_fmha.py | 277 ++ .../examples/fmha/python/35_bwd_bf16_fmha.py | 270 ++ .../fmha/python/36_bwd_benchmark_fmha.py | 234 ++ .../fmha/python/37_bwd_deterministic_fmha.py | 316 ++ .../fmha/python/38_bwd_sweep_hdim_fmha.py | 244 ++ .../gemm/python/05_numpy_integration.py | 2 - .../examples/gemm/python/06_json_export.py | 2 - .../examples/gemm/python/07_stress_test.py | 3 - .../examples/gemm/python/08_heuristics.py | 2 - .../examples/gemm/python/09_multi_registry.py | 2 - .../gemm/python/10_advanced_benchmark.py | 2 - .../examples/gemm/python/11_json_import.py | 3 - dispatcher/include/ck_tile/dispatcher.hpp | 11 + .../backends/generated_fmha_backend.hpp | 266 ++ .../backends/generated_kernel_backend.hpp | 10 +- .../backends/generated_tile_backend.hpp | 12 +- .../ck_tile/dispatcher/example_args.hpp | 5 +- .../ck_tile/dispatcher/fmha_dispatcher.hpp | 105 + .../ck_tile/dispatcher/fmha_kernel_decl.hpp | 646 ++++ .../dispatcher/fmha_kernel_instance.hpp | 45 + .../ck_tile/dispatcher/fmha_kernel_key.hpp | 216 ++ .../ck_tile/dispatcher/fmha_problem.hpp | 794 +++++ .../ck_tile/dispatcher/fmha_registry.hpp | 63 + .../include/ck_tile/dispatcher/fmha_types.hpp | 605 ++++ .../ck_tile/dispatcher/kernel_instance.hpp | 9 + .../include/ck_tile/dispatcher_fmha.hpp | 17 + .../include/ck_tile/dispatcher_gemm.hpp | 25 +- dispatcher/python/dispatcher_common.py | 16 + dispatcher/python/fmha_utils.py | 1842 +++++++++++ dispatcher/scripts/example_kernel_builder.py | 408 ++- dispatcher/scripts/parallel_kernel_builder.py | 34 +- dispatcher/src/dispatcher.cpp | 2 + dispatcher/src/fmha_dispatcher.cpp | 369 +++ dispatcher/src/fmha_registry.cpp | 302 ++ dispatcher/tests/CMakeLists.txt | 41 + dispatcher/tests/fmha_smoke_matrix.py | 416 +++ dispatcher/tests/full_parity_test.py | 1020 +++++++ .../tests/smoke_test_fmha_dispatcher.sh | 91 + dispatcher/tests/test_fmha_codegen.py | 172 ++ dispatcher/tests/test_fmha_dispatcher.cpp | 491 +++ dispatcher/tests/test_fmha_kernel_decl.cpp | 38 + dispatcher/tests/test_fmha_parity.py | 219 ++ dispatcher/tests/test_fmha_problem.cpp | 144 + dispatcher/tests/test_fmha_registry.cpp | 124 + dispatcher/tests/test_fmha_rules.py | 168 + tile_engine/CMakeLists.txt | 1 + tile_engine/operation_support_matrix.md | 2 +- tile_engine/ops/common/parallel_runner.py | 63 + tile_engine/ops/fmha/.gitignore | 3 + tile_engine/ops/fmha/CMakeLists.txt | 94 + tile_engine/ops/fmha/README.md | 192 ++ .../ops/fmha/ck_fmha_testing_matrix.yaml | 788 +++++ tile_engine/ops/fmha/configs/appendkv.json | 6 + .../ops/fmha/configs/batch_prefill.json | 6 + tile_engine/ops/fmha/configs/bwd.json | 6 + tile_engine/ops/fmha/configs/fwd.json | 9 + tile_engine/ops/fmha/configs/fwd_ci.json | 14 + tile_engine/ops/fmha/configs/pagedkv.json | 6 + .../ops/fmha/configs/receipt0_fwd.json | 6 + tile_engine/ops/fmha/configs/splitkv.json | 6 + .../ops/fmha/filters/h128_no_dropout.py | 14 + tile_engine/ops/fmha/fmha_benchmark.py | 939 ++++++ tile_engine/ops/fmha/fmha_full_benchmark.py | 689 +++++ tile_engine/ops/fmha/run_full_sweep.py | 175 ++ tile_engine/ops/fmha/run_one_kernel.py | 128 + 148 files changed, 41250 insertions(+), 87 deletions(-) create mode 100644 dispatcher/bindings/ctypes/fmha_ctypes_lib.cpp create mode 100644 dispatcher/codegen/fmha/__init__.py create mode 100644 dispatcher/codegen/fmha/codegen.py create mode 100644 dispatcher/codegen/fmha/fmha_arch_specs.json create mode 100644 dispatcher/codegen/fmha/generate_fallback.py create mode 100644 dispatcher/codegen/fmha/instance_gen.py create mode 100644 dispatcher/codegen/fmha/symbol_map.py create mode 100644 dispatcher/codegen/fmha/validation.py create mode 100644 dispatcher/examples/fmha/cpp/01_basic_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/02_splitkv_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/03_kvcache_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/04_bwd_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/05_appendkv_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/06_batch_prefill_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/07_profile_pytorch_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/08_profile_flash_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/09_profile_aiter_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/10_profile_fp32_fp8_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/11_receipt_aliases_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/12_registry_json_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/13_feature_coverage_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/14_benchmark_validation_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/15_multi_shape_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/16_heuristics_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/17_autofill_autocorrect_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/18_gpu_splitkv_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/19_gpu_masks_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/20_gpu_bias_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/21_gpu_features_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/22_gpu_bwd_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/23_multi_registry_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/24_per_receipt_registries_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/25_gpu_appendkv_batchprefill_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/26_dtypes_hdims_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/27_padding_permutation_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/28_bwd_masks_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/29_bwd_bias_dropout_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/30_bwd_benchmark_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/31_logits_soft_cap_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/32_sink_tokens_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/33_bwd_deterministic_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/34_bwd_gqa_fmha.cpp create mode 100644 dispatcher/examples/fmha/cpp/35_generic_mask_fmha.cpp create mode 100644 dispatcher/examples/fmha/python/01_basic_fmha.py create mode 100644 dispatcher/examples/fmha/python/02_multi_shape.py create mode 100644 dispatcher/examples/fmha/python/03_benchmark.py create mode 100644 dispatcher/examples/fmha/python/04_validation.py create mode 100644 dispatcher/examples/fmha/python/05_numpy_integration.py create mode 100644 dispatcher/examples/fmha/python/06_json_export.py create mode 100644 dispatcher/examples/fmha/python/07_stress_test.py create mode 100644 dispatcher/examples/fmha/python/08_heuristics.py create mode 100644 dispatcher/examples/fmha/python/09_multi_registry.py create mode 100644 dispatcher/examples/fmha/python/10_advanced_benchmark.py create mode 100644 dispatcher/examples/fmha/python/11_bf16_fmha.py create mode 100644 dispatcher/examples/fmha/python/12_masks_fmha.py create mode 100644 dispatcher/examples/fmha/python/13_bias_fmha.py create mode 100644 dispatcher/examples/fmha/python/14_dropout_fmha.py create mode 100644 dispatcher/examples/fmha/python/15_gqa_fmha.py create mode 100644 dispatcher/examples/fmha/python/16_splitkv_fmha.py create mode 100644 dispatcher/examples/fmha/python/17_appendkv_fmha.py create mode 100644 dispatcher/examples/fmha/python/18_backward_fmha.py create mode 100644 dispatcher/examples/fmha/python/19_padding_fmha.py create mode 100644 dispatcher/examples/fmha/python/20_fp8_fmha.py create mode 100644 dispatcher/examples/fmha/python/21_logits_soft_cap_fmha.py create mode 100644 dispatcher/examples/fmha/python/22_sink_tokens_fmha.py create mode 100644 dispatcher/examples/fmha/python/23_batch_prefill_fmha.py create mode 100644 dispatcher/examples/fmha/python/24_vlayout_col_fmha.py create mode 100644 dispatcher/examples/fmha/python/25_permutation_fmha.py create mode 100644 dispatcher/examples/fmha/python/26_hdim_variety_fmha.py create mode 100644 dispatcher/examples/fmha/python/27_backward_dropout_fmha.py create mode 100644 dispatcher/examples/fmha/python/28_backward_dbias_fmha.py create mode 100644 dispatcher/examples/fmha/python/29_sweep_seqlen.py create mode 100644 dispatcher/examples/fmha/python/30_sweep_batch.py create mode 100644 dispatcher/examples/fmha/python/31_sweep_nhead.py create mode 100644 dispatcher/examples/fmha/python/32_sweep_hdim.py create mode 100644 dispatcher/examples/fmha/python/33_bwd_masks_fmha.py create mode 100644 dispatcher/examples/fmha/python/34_bwd_gqa_fmha.py create mode 100644 dispatcher/examples/fmha/python/35_bwd_bf16_fmha.py create mode 100644 dispatcher/examples/fmha/python/36_bwd_benchmark_fmha.py create mode 100644 dispatcher/examples/fmha/python/37_bwd_deterministic_fmha.py create mode 100644 dispatcher/examples/fmha/python/38_bwd_sweep_hdim_fmha.py create mode 100644 dispatcher/include/ck_tile/dispatcher/backends/generated_fmha_backend.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/fmha_dispatcher.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/fmha_kernel_decl.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/fmha_kernel_instance.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/fmha_kernel_key.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/fmha_problem.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/fmha_registry.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/fmha_types.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher_fmha.hpp create mode 100644 dispatcher/python/fmha_utils.py create mode 100644 dispatcher/src/fmha_dispatcher.cpp create mode 100644 dispatcher/src/fmha_registry.cpp create mode 100644 dispatcher/tests/fmha_smoke_matrix.py create mode 100644 dispatcher/tests/full_parity_test.py create mode 100755 dispatcher/tests/smoke_test_fmha_dispatcher.sh create mode 100644 dispatcher/tests/test_fmha_codegen.py create mode 100644 dispatcher/tests/test_fmha_dispatcher.cpp create mode 100644 dispatcher/tests/test_fmha_kernel_decl.cpp create mode 100644 dispatcher/tests/test_fmha_parity.py create mode 100644 dispatcher/tests/test_fmha_problem.cpp create mode 100644 dispatcher/tests/test_fmha_registry.cpp create mode 100644 dispatcher/tests/test_fmha_rules.py create mode 100644 tile_engine/ops/common/parallel_runner.py create mode 100644 tile_engine/ops/fmha/.gitignore create mode 100644 tile_engine/ops/fmha/CMakeLists.txt create mode 100644 tile_engine/ops/fmha/README.md create mode 100644 tile_engine/ops/fmha/ck_fmha_testing_matrix.yaml create mode 100644 tile_engine/ops/fmha/configs/appendkv.json create mode 100644 tile_engine/ops/fmha/configs/batch_prefill.json create mode 100644 tile_engine/ops/fmha/configs/bwd.json create mode 100644 tile_engine/ops/fmha/configs/fwd.json create mode 100644 tile_engine/ops/fmha/configs/fwd_ci.json create mode 100644 tile_engine/ops/fmha/configs/pagedkv.json create mode 100644 tile_engine/ops/fmha/configs/receipt0_fwd.json create mode 100644 tile_engine/ops/fmha/configs/splitkv.json create mode 100644 tile_engine/ops/fmha/filters/h128_no_dropout.py create mode 100644 tile_engine/ops/fmha/fmha_benchmark.py create mode 100644 tile_engine/ops/fmha/fmha_full_benchmark.py create mode 100644 tile_engine/ops/fmha/run_full_sweep.py create mode 100644 tile_engine/ops/fmha/run_one_kernel.py diff --git a/dispatcher/CMakeLists.txt b/dispatcher/CMakeLists.txt index 2acc73d1d5..ed9b20d33c 100644 --- a/dispatcher/CMakeLists.txt +++ b/dispatcher/CMakeLists.txt @@ -21,6 +21,8 @@ endif() add_library(ck_tile_dispatcher src/registry.cpp src/dispatcher.cpp + src/fmha_registry.cpp + src/fmha_dispatcher.cpp ) # Enable PIC for Python bindings @@ -34,13 +36,21 @@ target_include_directories(ck_tile_dispatcher $ ) -# Link against CK Tile headers (header-only) +# CK Tile core headers (ck_tile/core, ck_tile/ops, etc.) target_include_directories(ck_tile_dispatcher PUBLIC $ $ ) +# CK project root -- needed only for FMHA generated wrappers that include +# "example/ck_tile/01_fmha/fmha_fwd.hpp". PRIVATE to avoid exposing the +# entire project tree to downstream consumers. +target_include_directories(ck_tile_dispatcher + PRIVATE + $ +) + # Link against HIP headers if available if(hip_FOUND) target_link_libraries(ck_tile_dispatcher PUBLIC hip::host) diff --git a/dispatcher/README.md b/dispatcher/README.md index dc864f7c62..307e612305 100644 --- a/dispatcher/README.md +++ b/dispatcher/README.md @@ -394,6 +394,12 @@ python3 examples/grouped_conv/python/03_bwd_data.py # Backward data + python3 examples/grouped_conv/python/04_bwd_weight.py # Backward weight + CPU ref python3 examples/grouped_conv/python/05_benchmark.py # Multi-problem benchmark python3 examples/grouped_conv/python/06_registry_json.py # Heuristic selection + JSON + +# FMHA Examples (JIT-compiled on the fly) +python3 examples/fmha/python/01_basic_fmha.py # Basic forward attention +python3 examples/fmha/python/12_masks_fmha.py # Causal masks +python3 examples/fmha/python/18_backward_fmha.py # Backward pass +python3 examples/fmha/python/16_splitkv_fmha.py # Split-KV for long sequences ``` ### Example Output @@ -716,7 +722,7 @@ This matrix shows all CK Tile operations with per-data-type, per-layout, and per | GEMM | streamk_gemm
example: `40_streamk_gemm/` | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | Reduce | multi_reduce2d
example: `05_reduce/` | ❌ | | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ | | Reduce | reduce2d
example: `05_reduce/` | ❌ | | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ | -| Attention | fmha
example: `01_fmha/` | ❌ | ❌ | ❌ | ❌ | ❌ | | | | | | | ❌ | ❌ | ❌ | ❌ | +| Attention | fmha
example: `01_fmha/` | ✅ | ✅ | ✅ | ✅ | ❌ | | | | | | | ✅ | ✅ | ✅ | ❌ | | Attention | sparse_attn
example: `50_sparse_attn/` | ❌ | | ❌ | | ❌ | | | | | | | ❌ | ❌ | ❌ | ❌ | | Activation | softmax | ❌ | | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ | | Activation | topk_softmax
example: `09_topk_softmax/` | ❌ | ❌ | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ | @@ -871,7 +877,14 @@ dispatcher/ | |---- grouped_conv_problem.hpp # Grouped conv problem (with builder) | |---- grouped_conv_kernel_decl.hpp # Grouped conv kernel declarations | |---- grouped_conv_registry.hpp # Grouped conv registry (thread-safe) -| +---- grouped_conv_utils.hpp # Grouped conv utilities +| |---- grouped_conv_utils.hpp # Grouped conv utilities +| |---- fmha_types.hpp # FMHA fwd/bwd args and traits structs +| |---- fmha_problem.hpp # FmhaProblem, FmhaProblemBuilder +| |---- fmha_kernel_key.hpp # FmhaKernelKey (Signature + Algorithm) +| |---- fmha_kernel_instance.hpp # FmhaKernelInstance virtual interface +| |---- fmha_kernel_decl.hpp # Declarative FmhaSignature/FmhaAlgorithm +| |---- fmha_registry.hpp # FmhaRegistry (thread-safe) +| +---- fmha_dispatcher.hpp # FmhaDispatcher (plan, select, run) | |---- src/ # C++ implementation | @@ -879,12 +892,17 @@ dispatcher/ | |---- codegen_common.py # Shared: TileConfig, TraitConfigBase, type mappings | |---- unified_gemm_codegen.py # GEMM kernel generator | |---- unified_grouped_conv_codegen.py # Grouped conv kernel generator +| |---- unified_fmha_codegen.py # FMHA kernel generator +| |---- fmha_arch_specs.json # FMHA per-arch tile/pipeline specs +| |---- fmha_rules.py # FMHA validation rules +| |---- fmha_profiles.py # FMHA named profiles/receipts | +---- arch_specs.json # GPU specifications | |---- python/ # Python utilities | |---- dispatcher_common.py # Shared: paths, validation, Colors, phased output | |---- ctypes_utils.py # GEMM ctypes utilities -| +---- grouped_conv_utils.py # Grouped conv utilities +| |---- grouped_conv_utils.py # Grouped conv utilities +| +---- fmha_utils.py # FMHA: JIT compile, FmhaRunner, FmhaKernelConfig | |---- scripts/ # Build scripts | |---- compile_gemm_examples.py # GEMM build script @@ -892,15 +910,19 @@ dispatcher/ | |---- bindings/ctypes/ # Python ctypes interface | |---- gemm_ctypes_lib.cpp # GEMM Python library -| +---- conv_ctypes_lib.cpp # Grouped conv Python library +| |---- conv_ctypes_lib.cpp # Grouped conv Python library +| +---- fmha_ctypes_lib.cpp # FMHA Python library | |---- examples/ # Examples | |---- gemm/ | | |---- cpp/ # C++ GEMM examples (01-07) | | +---- python/ # Python GEMM examples (01-11) -| +---- grouped_conv/ -| |---- cpp/ # C++ Grouped Conv examples (01-07) -| +---- python/ # Python Grouped Conv examples (01-06) +| |---- grouped_conv/ +| | |---- cpp/ # C++ Grouped Conv examples (01-07) +| | +---- python/ # Python Grouped Conv examples (01-06) +| +---- fmha/ +| |---- cpp/ # C++ FMHA examples (01-35) +| +---- python/ # Python FMHA examples (01-38) | +---- tests/ # Unit tests (C++ and Python) ``` @@ -913,6 +935,8 @@ dispatcher/ |-----------|--------| | GEMM C++ | [examples/gemm/cpp/README.md](examples/gemm/cpp/README.md) | | GEMM Python | [examples/gemm/python/README.md](examples/gemm/python/README.md) | +| FMHA C++ | examples/fmha/cpp/ (35 examples covering all FMHA variants) | +| FMHA Python | examples/fmha/python/ (38 examples with JIT compilation) | | Codegen | [codegen/README.md](codegen/README.md) | | Python Utils | [python/README.md](python/README.md) | | C++ Headers | [include/ck_tile/dispatcher/README.md](include/ck_tile/dispatcher/README.md) | diff --git a/dispatcher/bindings/README.md b/dispatcher/bindings/README.md index 04029d32a9..e460b38b5b 100644 --- a/dispatcher/bindings/README.md +++ b/dispatcher/bindings/README.md @@ -10,6 +10,7 @@ bindings/ | |---- gemm_ctypes_lib.cpp # GEMM dispatcher C API | |---- conv_ctypes_lib.cpp # Grouped conv dispatcher C API (fwd + bwd_data) | |---- conv_bwdw_ctypes_lib.cpp # Grouped conv backward weight C API (separate library) +| |---- fmha_ctypes_lib.cpp # FMHA dispatcher C API (fwd + bwd) | |---- gpu_helper.cpp # CLI helper for Python | +---- CMakeLists.txt +---- README.md diff --git a/dispatcher/bindings/ctypes/fmha_ctypes_lib.cpp b/dispatcher/bindings/ctypes/fmha_ctypes_lib.cpp new file mode 100644 index 0000000000..43dbb571d8 --- /dev/null +++ b/dispatcher/bindings/ctypes/fmha_ctypes_lib.cpp @@ -0,0 +1,1685 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// FMHA Dispatcher ctypes library. +// Provides a C API for Python ctypes integration. +// Kernel header included via -include at compile time. +// +// Thread safety: NOT thread-safe. Python ctypes releases the GIL during +// foreign calls, so single-threaded usage must be enforced by the caller. + +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" + +#ifndef GFX_ARCH +#error "GFX_ARCH must be defined at compile time (e.g. -DGFX_ARCH=\"gfx950\")" +#endif + +using namespace ck_tile::dispatcher; + +static std::unique_ptr g_registry; +static std::unique_ptr g_dispatcher; +static bool g_initialized = false; + +#define HIP_CHECK(call) \ + do \ + { \ + hipError_t err_ = (call); \ + if(err_ != hipSuccess) \ + { \ + rc = -1; \ + goto cleanup; \ + } \ + } while(0) + +static inline void safe_hip_free(void*& ptr) +{ + if(ptr) + { + hipFree(ptr); + ptr = nullptr; + } +} + +static int dtype_input_bytes(const char* dtype) +{ + if(!dtype) + return 2; + if(std::strcmp(dtype, "fp32") == 0) + return 4; + if(std::strcmp(dtype, "fp8bf16") == 0 || std::strcmp(dtype, "fp8fp32") == 0 || + std::strcmp(dtype, "bf8") == 0 || std::strcmp(dtype, "fp8") == 0) + return 1; + return 2; // fp16, bf16 +} + +static int dtype_output_bytes(const char* dtype) +{ + if(!dtype) + return 2; + if(std::strcmp(dtype, "fp32") == 0 || std::strcmp(dtype, "fp8fp32") == 0) + return 4; + if(std::strcmp(dtype, "fp8") == 0 || std::strcmp(dtype, "bf8") == 0) + return 1; + return 2; // fp16, bf16, fp8bf16 (output is bf16) +} + +// Run the single registered kernel directly, bypassing the multi-stage plan() +// that requires split+combine for splitkv or dot+dq+convert for bwd. +// Used for single-kernel .so benchmarking. +static float run_single_kernel(const FmhaInvocation& invocation) +{ + auto kernels = g_registry->get_all(); + if(kernels.empty()) + { + throw std::runtime_error("No FMHA kernels registered"); + } + ck_tile::stream_config sc; + sc.log_level_ = 0; + if(g_dispatcher) + { + sc.time_kernel_ = true; + sc.cold_niters_ = 10; + sc.nrepeat_ = 50; + } + return kernels.front()->run(invocation, sc); +} + +extern "C" { + +int fmha_dispatcher_initialize(const char* arch) +{ + if(g_initialized) + return 0; + + const std::string gfx_arch = arch ? arch : GFX_ARCH; + + g_registry = std::make_unique(); + g_registry->set_name("fmha_ctypes"); + REGISTER_GENERATED_KERNELS(*g_registry, gfx_arch); + + if(g_registry->size() == 0) + return -1; + + g_dispatcher = std::make_unique(g_registry.get()); + g_dispatcher->set_benchmarking(true); + g_dispatcher->set_timing(1, 3); + g_initialized = true; + return 0; +} + +int fmha_dispatcher_run_fwd(const void* q_host, + const void* k_host, + const void* v_host, + void* o_host, + int batch, + int nhead_q, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + float scale, + int mask_type_int, + int bias_type_int, + int has_lse, + int has_dropout, + int traits_hdim_q, + int traits_hdim_v, + int is_v_rowmajor, + int perm, + const char* data_type_str, + int is_group_mode, + int window_left, + int window_right, + int has_logits, + int has_sink, + int has_skip, + float* time_ms_out) +{ + if(!g_initialized) + return -1; + + const int in_bytes = dtype_input_bytes(data_type_str); + const int out_bytes = dtype_output_bytes(data_type_str); + + int rc = 0; + const int64_t q_bytes = static_cast(batch) * nhead_q * seqlen_q * hdim_q * in_bytes; + const int64_t k_bytes = static_cast(batch) * nhead_k * seqlen_k * hdim_q * in_bytes; + const int64_t v_bytes = static_cast(batch) * nhead_k * seqlen_k * hdim_v * in_bytes; + const int64_t o_bytes = static_cast(batch) * nhead_q * seqlen_q * hdim_v * out_bytes; + const int64_t bias_bytes = + static_cast(batch) * nhead_q * seqlen_q * seqlen_k * out_bytes; + const int64_t lse_bytes = static_cast(batch) * nhead_q * seqlen_q * sizeof(float); + float elapsed = 0.0f; + + void *q_dev = nullptr, *k_dev = nullptr, *v_dev = nullptr, *o_dev = nullptr; + void *bias_dev = nullptr, *lse_dev_buf = nullptr, *sink_dev_fwd = nullptr; + void *seqstart_q_dev = nullptr, *seqstart_k_dev = nullptr, *seqlen_k_dev = nullptr; + + fmha_fwd_traits traits{}; + traits.hdim_q = (traits_hdim_q > 0) ? traits_hdim_q : hdim_q; + traits.hdim_v = (traits_hdim_v > 0) ? traits_hdim_v : hdim_v; + traits.data_type = data_type_str ? data_type_str : "fp16"; + traits.is_group_mode = (is_group_mode != 0); + traits.is_v_rowmajor = (is_v_rowmajor != 0); + traits.mask_type = static_cast(mask_type_int); + traits.bias_type = static_cast(bias_type_int); + traits.has_lse = (has_lse != 0); + traits.has_dropout = (has_dropout != 0); + traits.qscale_type = quant_scale_enum::no_scale; + traits.has_logits_soft_cap = (has_logits != 0); + traits.skip_min_seqlen_q = (has_skip != 0); + traits.has_sink = (has_sink != 0); + + fmha_fwd_args args{}; + + HIP_CHECK(hipMalloc(&q_dev, q_bytes)); + HIP_CHECK(hipMalloc(&k_dev, k_bytes)); + HIP_CHECK(hipMalloc(&v_dev, v_bytes)); + HIP_CHECK(hipMalloc(&o_dev, o_bytes)); + + if(is_group_mode) + { + std::vector sq_starts(batch + 1), sk_starts(batch + 1), sk_lens(batch); + for(int b = 0; b <= batch; ++b) + { + sq_starts[b] = b * seqlen_q; + sk_starts[b] = b * seqlen_k; + } + for(int b = 0; b < batch; ++b) + sk_lens[b] = seqlen_k; + + HIP_CHECK(hipMalloc(&seqstart_q_dev, (batch + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&seqstart_k_dev, (batch + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&seqlen_k_dev, batch * sizeof(int))); + HIP_CHECK(hipMemcpy( + seqstart_q_dev, sq_starts.data(), (batch + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy( + seqstart_k_dev, sk_starts.data(), (batch + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK( + hipMemcpy(seqlen_k_dev, sk_lens.data(), batch * sizeof(int), hipMemcpyHostToDevice)); + } + + HIP_CHECK(hipMemcpy(q_dev, q_host, q_bytes, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(k_dev, k_host, k_bytes, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(v_dev, v_host, v_bytes, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(o_dev, 0, o_bytes)); + + if(bias_type_int > 0) + { + HIP_CHECK(hipMalloc(&bias_dev, bias_bytes)); + HIP_CHECK(hipMemset(bias_dev, 0, bias_bytes)); + } + if(has_lse) + { + HIP_CHECK(hipMalloc(&lse_dev_buf, lse_bytes)); + HIP_CHECK(hipMemset(lse_dev_buf, 0, lse_bytes)); + } + if(has_sink) + { + HIP_CHECK(hipMalloc(&sink_dev_fwd, nhead_q * sizeof(float))); + HIP_CHECK(hipMemset(sink_dev_fwd, 0, nhead_q * sizeof(float))); + } + + args.q_ptr = q_dev; + args.k_ptr = k_dev; + args.v_ptr = v_dev; + args.o_ptr = o_dev; + args.bias_ptr = bias_dev; + args.q_descale_ptr = nullptr; + args.k_descale_ptr = nullptr; + args.v_descale_ptr = nullptr; + args.rand_val_ptr = nullptr; + args.lse_ptr = lse_dev_buf; + args.seqstart_q_ptr = seqstart_q_dev; + args.seqstart_k_ptr = seqstart_k_dev; + args.seqlen_q_ptr = nullptr; + args.seqlen_k_ptr = seqlen_k_dev; + args.sink_ptr = sink_dev_fwd; + args.block_scale_seqstart_q_ptr = nullptr; + args.block_scale_seqstart_k_ptr = nullptr; + + args.seqlen_q = seqlen_q; + args.seqlen_k = seqlen_k; + args.batch = batch; + args.max_seqlen_q = seqlen_q; + args.hdim_q = hdim_q; + args.hdim_v = hdim_v; + args.nhead_q = nhead_q; + args.nhead_k = nhead_k; + args.scale_s = scale; + args.logits_soft_cap = 0.0f; + + if(is_group_mode) + { + if(perm == 1) + { + // BHSD group: [1, head, total_tokens, dim] + args.stride_q = hdim_q; + args.stride_k = hdim_q; + args.stride_v = hdim_v; + args.stride_o = hdim_v; + args.nhead_stride_q = static_cast(seqlen_q) * hdim_q; + args.nhead_stride_k = static_cast(seqlen_k) * hdim_q; + args.nhead_stride_v = static_cast(seqlen_k) * hdim_v; + args.nhead_stride_o = static_cast(seqlen_q) * hdim_v; + } + else + { + // BSHD group: [total_tokens, head, dim] + args.stride_q = nhead_q * hdim_q; + args.stride_k = nhead_k * hdim_q; + args.stride_v = nhead_k * hdim_v; + args.stride_o = nhead_q * hdim_v; + args.nhead_stride_q = hdim_q; + args.nhead_stride_k = hdim_q; + args.nhead_stride_v = hdim_v; + args.nhead_stride_o = hdim_v; + } + args.batch_stride_q = 0; + args.batch_stride_k = 0; + args.batch_stride_v = 0; + args.batch_stride_o = 0; + } + else if(perm == 1) + { + // BHSD: [batch, head, seq, dim] + args.stride_q = hdim_q; + args.stride_k = hdim_q; + args.stride_v = hdim_v; + args.stride_o = hdim_v; + args.nhead_stride_q = static_cast(seqlen_q) * hdim_q; + args.nhead_stride_k = static_cast(seqlen_k) * hdim_q; + args.nhead_stride_v = static_cast(seqlen_k) * hdim_v; + args.nhead_stride_o = static_cast(seqlen_q) * hdim_v; + args.batch_stride_q = static_cast(nhead_q) * seqlen_q * hdim_q; + args.batch_stride_k = static_cast(nhead_k) * seqlen_k * hdim_q; + args.batch_stride_v = static_cast(nhead_k) * seqlen_k * hdim_v; + args.batch_stride_o = static_cast(nhead_q) * seqlen_q * hdim_v; + } + else + { + // BSHD: [batch, seq, head, dim] + args.stride_q = nhead_q * hdim_q; + args.stride_k = nhead_k * hdim_q; + args.stride_v = nhead_k * hdim_v; + args.stride_o = nhead_q * hdim_v; + args.nhead_stride_q = hdim_q; + args.nhead_stride_k = hdim_q; + args.nhead_stride_v = hdim_v; + args.nhead_stride_o = hdim_v; + args.batch_stride_q = static_cast(seqlen_q) * nhead_q * hdim_q; + args.batch_stride_k = static_cast(seqlen_k) * nhead_k * hdim_q; + args.batch_stride_v = static_cast(seqlen_k) * nhead_k * hdim_v; + args.batch_stride_o = static_cast(seqlen_q) * nhead_q * hdim_v; + } + args.stride_bias = (bias_type_int > 0) ? seqlen_k : 0; + args.stride_randval = 0; + args.nhead_stride_bias = (bias_type_int > 0) ? static_cast(seqlen_q) * seqlen_k : 0; + args.nhead_stride_randval = 0; + args.nhead_stride_lse = has_lse ? seqlen_q : 0; + args.nhead_stride_q_descale = 0; + args.nhead_stride_k_descale = 0; + args.nhead_stride_v_descale = 0; + args.batch_stride_bias = + (bias_type_int > 0) ? static_cast(nhead_q) * seqlen_q * seqlen_k : 0; + args.batch_stride_randval = 0; + args.batch_stride_lse = has_lse ? static_cast(nhead_q) * seqlen_q : 0; + args.batch_stride_q_descale = 0; + args.batch_stride_k_descale = 0; + args.batch_stride_v_descale = 0; + + args.window_size_left = window_left; + args.window_size_right = window_right; + args.sink_size = 0; + args.mask_type = mask_type_int; + args.min_seqlen_q = 0; + args.p_drop = has_dropout ? 0.2f : 0.0f; + args.s_randval = false; + args.drop_seed_offset = has_dropout ? std::make_pair(uint64_t(1), uint64_t(0)) + : std::make_pair(uint64_t(0), uint64_t(0)); + args.block_scale_size_q = 0; + args.block_scale_size_kv = 0; + + try + { + auto invocation = FmhaInvocation::make(std::move(traits), std::move(args)); + if(g_registry->size() == 1) + elapsed = run_single_kernel(invocation); + else + elapsed = g_dispatcher->run_fwd(std::get(invocation.traits), + std::get(invocation.args), + nullptr); + } + catch(const std::exception& e) + { + fprintf(stderr, "FMHA_FWD_ERR: %s\n", e.what()); + rc = -2; + goto cleanup; + } + catch(...) + { + fprintf(stderr, "FMHA_ERR: unknown\n"); + rc = -2; + goto cleanup; + } + + { + hipError_t cpy_err = hipMemcpy(o_host, o_dev, o_bytes, hipMemcpyDeviceToHost); + if(cpy_err != hipSuccess) + rc = -1; + } + + if(time_ms_out) + *time_ms_out = elapsed; + +cleanup: + safe_hip_free(q_dev); + safe_hip_free(k_dev); + safe_hip_free(v_dev); + safe_hip_free(o_dev); + safe_hip_free(bias_dev); + safe_hip_free(lse_dev_buf); + safe_hip_free(sink_dev_fwd); + safe_hip_free(seqstart_q_dev); + safe_hip_free(seqstart_k_dev); + safe_hip_free(seqlen_k_dev); + + return rc; +} + +int fmha_dispatcher_run_bwd(const void* q_host, + const void* k_host, + const void* v_host, + const void* o_host, + const void* lse_host, + const void* do_host, + void* dq_host, + void* dk_host, + void* dv_host, + int batch, + int nhead_q, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + float scale, + const char* data_type_str, + int mask_type_int, + int bias_type_int, + int has_dropout, + int has_dbias, + int is_deterministic, + int is_group_mode, + int is_store_randval, + int tile_n0, + float* time_ms_out) +{ + if(!g_initialized) + return -1; + + const int in_bytes = dtype_input_bytes(data_type_str); + const int out_bytes = dtype_output_bytes(data_type_str); + + int rc = 0; + const int64_t q_bytes = static_cast(batch) * nhead_q * seqlen_q * hdim_q * in_bytes; + const int64_t k_bytes = static_cast(batch) * nhead_k * seqlen_k * hdim_q * in_bytes; + const int64_t v_bytes = static_cast(batch) * nhead_k * seqlen_k * hdim_v * in_bytes; + const int64_t o_bytes = static_cast(batch) * nhead_q * seqlen_q * hdim_v * out_bytes; + const int64_t do_bytes = o_bytes; + const int64_t dq_bytes = q_bytes; + const int64_t dk_bytes = k_bytes; + const int64_t dv_bytes = v_bytes; + const int64_t lse_bytes = static_cast(batch) * nhead_q * seqlen_q * sizeof(float); + const int64_t d_bytes = static_cast(batch) * nhead_q * seqlen_q * sizeof(float); + const bool bwd_grp = (is_group_mode != 0); + const int kN0 = (tile_n0 > 0) ? tile_n0 : 128; + const int bwd_nsplits = is_deterministic + ? ((seqlen_k + kN0 - 1) / kN0) // ceil(max_seqlen_k / kN0) + : 1; + const int64_t bwd_shape_sq = bwd_grp ? static_cast(batch) * seqlen_q : seqlen_q; + const int64_t bwd_shape_sk = bwd_grp ? static_cast(batch) * seqlen_k : seqlen_k; + const int64_t bwd_shape_batch = bwd_grp ? 1 : batch; + const int64_t dq_acc_bytes = + bwd_shape_batch * nhead_q * bwd_nsplits * bwd_shape_sq * hdim_q * sizeof(float); + const int64_t split_stride_dq_acc_val = bwd_shape_sq * hdim_q; + float elapsed = 0.0f; + + void *q_dev = nullptr, *k_dev = nullptr, *v_dev = nullptr, *o_dev = nullptr; + void *lse_dev = nullptr, *do_dev = nullptr, *d_dev = nullptr; + void *dq_dev = nullptr, *dk_dev = nullptr, *dv_dev = nullptr, *dq_acc_dev = nullptr; + void *bwd_seqstart_q_dev = nullptr, *bwd_seqstart_k_dev = nullptr; + void *bwd_seqlen_k_dev = nullptr, *bwd_seqlen_q_dev = nullptr; + void *bwd_bias_dev = nullptr, *bwd_randval_dev = nullptr, *bwd_dbias_dev = nullptr; + + std::vector bwd_sq(batch + 1), bwd_sk(batch + 1), bwd_skl(batch, seqlen_k), + bwd_sql(batch, seqlen_q); + if(bwd_grp) + { + for(int b = 0; b <= batch; ++b) + { + bwd_sq[b] = b * seqlen_q; + bwd_sk[b] = b * seqlen_k; + } + } + + fmha_bwd_traits traits{}; + traits.seqlen_q = bwd_shape_sq; + traits.seqlen_k = bwd_shape_sk; + traits.batch = batch; + traits.max_seqlen_q = seqlen_q; + traits.max_seqlen_k = seqlen_k; + traits.hdim_q = hdim_q; + traits.hdim_v = hdim_v; + traits.nhead_q = nhead_q; + traits.nhead_k = nhead_k; + traits.data_type = data_type_str ? data_type_str : "fp16"; + traits.is_group_mode = (is_group_mode != 0); + traits.mask_type = static_cast(mask_type_int); + traits.bias_type = static_cast(bias_type_int); + traits.has_dbias = (has_dbias != 0); + traits.has_dropout = (has_dropout != 0); + traits.is_store_randval = (is_store_randval != 0); + traits.is_deterministic = (is_deterministic != 0); + + fmha_bwd_args args{}; + + HIP_CHECK(hipMalloc(&q_dev, q_bytes)); + HIP_CHECK(hipMalloc(&k_dev, k_bytes)); + HIP_CHECK(hipMalloc(&v_dev, v_bytes)); + HIP_CHECK(hipMalloc(&o_dev, o_bytes)); + HIP_CHECK(hipMalloc(&lse_dev, lse_bytes)); + HIP_CHECK(hipMalloc(&do_dev, do_bytes)); + HIP_CHECK(hipMalloc(&d_dev, d_bytes)); + HIP_CHECK(hipMalloc(&dq_dev, dq_bytes)); + HIP_CHECK(hipMalloc(&dk_dev, dk_bytes)); + HIP_CHECK(hipMalloc(&dv_dev, dv_bytes)); + HIP_CHECK(hipMalloc(&dq_acc_dev, dq_acc_bytes)); + + if(bias_type_int > 0) + { + const int64_t bias_bytes = + (bias_type_int == 2) + ? static_cast(batch) * nhead_q * sizeof(float) + : static_cast(batch) * nhead_q * seqlen_q * seqlen_k * out_bytes; + HIP_CHECK(hipMalloc(&bwd_bias_dev, bias_bytes)); + HIP_CHECK(hipMemset(bwd_bias_dev, 0, bias_bytes)); + } + if(has_dropout) + { + const int64_t rv_bytes = + static_cast(batch) * nhead_q * seqlen_q * seqlen_k * sizeof(int8_t); + HIP_CHECK(hipMalloc(&bwd_randval_dev, rv_bytes)); + HIP_CHECK(hipMemset(bwd_randval_dev, 0, rv_bytes)); + } + if(has_dbias) + { + const int64_t dbias_bytes = + static_cast(batch) * nhead_q * seqlen_q * seqlen_k * out_bytes; + HIP_CHECK(hipMalloc(&bwd_dbias_dev, dbias_bytes)); + HIP_CHECK(hipMemset(bwd_dbias_dev, 0, dbias_bytes)); + } + + if(bwd_grp) + { + HIP_CHECK(hipMalloc(&bwd_seqstart_q_dev, (batch + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&bwd_seqstart_k_dev, (batch + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&bwd_seqlen_k_dev, batch * sizeof(int))); + HIP_CHECK(hipMalloc(&bwd_seqlen_q_dev, batch * sizeof(int))); + HIP_CHECK(hipMemcpy( + bwd_seqstart_q_dev, bwd_sq.data(), (batch + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy( + bwd_seqstart_k_dev, bwd_sk.data(), (batch + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy( + bwd_seqlen_k_dev, bwd_skl.data(), batch * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy( + bwd_seqlen_q_dev, bwd_sql.data(), batch * sizeof(int), hipMemcpyHostToDevice)); + } + + if(bwd_grp) + { + // Group mode: kernel uses [1, nhead, total_tokens, hdim] layout. + // Zero all buffers (data content doesn't affect benchmarking timing). + HIP_CHECK(hipMemset(q_dev, 0, q_bytes)); + HIP_CHECK(hipMemset(k_dev, 0, k_bytes)); + HIP_CHECK(hipMemset(v_dev, 0, v_bytes)); + HIP_CHECK(hipMemset(o_dev, 0, o_bytes)); + HIP_CHECK(hipMemset(lse_dev, 0, lse_bytes)); + HIP_CHECK(hipMemset(do_dev, 0, do_bytes)); + } + else + { + HIP_CHECK(hipMemcpy(q_dev, q_host, q_bytes, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(k_dev, k_host, k_bytes, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(v_dev, v_host, v_bytes, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(o_dev, o_host, o_bytes, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(lse_dev, lse_host, lse_bytes, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(do_dev, do_host, do_bytes, hipMemcpyHostToDevice)); + } + // d_ptr is computed by dot_do_o GPU kernel (stage 1 of BWD pipeline). + // Zero-initialize; dot_do_o will fill it before dq_dk_dv reads it. + HIP_CHECK(hipMemset(d_dev, 0, d_bytes)); + HIP_CHECK(hipMemset(dq_dev, 0, dq_bytes)); + HIP_CHECK(hipMemset(dk_dev, 0, dk_bytes)); + HIP_CHECK(hipMemset(dv_dev, 0, dv_bytes)); + HIP_CHECK(hipMemset(dq_acc_dev, 0, dq_acc_bytes)); + + args.q_ptr = q_dev; + args.k_ptr = k_dev; + args.v_ptr = v_dev; + args.bias_ptr = bwd_bias_dev; + args.o_ptr = o_dev; + args.lse_ptr = lse_dev; + args.do_ptr = do_dev; + args.d_ptr = d_dev; + args.rand_val_ptr = bwd_randval_dev; + args.dq_ptr = dq_dev; + args.dk_ptr = dk_dev; + args.dv_ptr = dv_dev; + args.dbias_ptr = bwd_dbias_dev; + args.dq_acc_ptr = dq_acc_dev; + + args.seqlen_q = bwd_shape_sq; + args.seqlen_k = bwd_shape_sk; + args.batch = batch; + args.max_seqlen_q = seqlen_q; + args.max_seqlen_k = seqlen_k; + args.hdim_q = hdim_q; + args.hdim_v = hdim_v; + args.nhead_q = nhead_q; + args.nhead_k = nhead_k; + args.scale = scale; + + // BHSD strides -- unified for both group and batch mode. + // CK uses shape_seqlen_q/k (= total_tokens for group, = per-seq for batch) + // for ALL stride computations, including batch_stride. + args.stride_q = hdim_q; + args.stride_k = hdim_q; + args.stride_v = hdim_v; + args.stride_bias = 0; + args.stride_o = hdim_v; + args.stride_randval = 0; + args.stride_do = hdim_v; + args.stride_dq_acc = hdim_q; + args.stride_dq = hdim_q; + args.stride_dk = hdim_q; + args.stride_dv = hdim_v; + args.stride_dbias = 0; + args.nhead_stride_q = bwd_shape_sq * hdim_q; + args.nhead_stride_k = bwd_shape_sk * hdim_q; + args.nhead_stride_v = bwd_shape_sk * hdim_v; + args.nhead_stride_bias = 0; + args.nhead_stride_o = bwd_shape_sq * hdim_v; + args.nhead_stride_randval = 0; + args.nhead_stride_do = bwd_shape_sq * hdim_v; + args.nhead_stride_lsed = bwd_shape_sq; + args.nhead_stride_dq_acc = + static_cast(split_stride_dq_acc_val) * bwd_nsplits; + args.nhead_stride_dq = bwd_shape_sq * hdim_q; + args.nhead_stride_dk = bwd_shape_sk * hdim_q; + args.nhead_stride_dv = bwd_shape_sk * hdim_v; + args.nhead_stride_dbias = 0; + args.batch_stride_q = static_cast(nhead_q) * bwd_shape_sq * hdim_q; + args.batch_stride_k = static_cast(nhead_k) * bwd_shape_sk * hdim_q; + args.batch_stride_v = static_cast(nhead_k) * bwd_shape_sk * hdim_v; + args.batch_stride_bias = 0; + args.batch_stride_o = static_cast(nhead_q) * bwd_shape_sq * hdim_v; + args.batch_stride_randval = 0; + args.batch_stride_do = static_cast(nhead_q) * bwd_shape_sq * hdim_v; + args.batch_stride_lsed = static_cast(nhead_q) * bwd_shape_sq; + args.batch_stride_dq_acc = + static_cast(nhead_q) * split_stride_dq_acc_val * bwd_nsplits; + args.batch_stride_dq = static_cast(nhead_q) * bwd_shape_sq * hdim_q; + args.batch_stride_dk = static_cast(nhead_k) * bwd_shape_sk * hdim_q; + args.batch_stride_dv = static_cast(nhead_k) * bwd_shape_sk * hdim_v; + args.batch_stride_dbias = 0; + args.split_stride_dq_acc = split_stride_dq_acc_val; + + args.seqstart_q_ptr = bwd_seqstart_q_dev; + args.seqstart_k_ptr = bwd_seqstart_k_dev; + args.seqlen_q_ptr = bwd_seqlen_q_dev; + args.seqlen_k_ptr = bwd_seqlen_k_dev; + args.cu_seqlen_q_ptr = nullptr; + args.cu_seqlen_k_ptr = nullptr; + + args.window_size_left = -1; + args.window_size_right = -1; + args.mask_type = mask_type_int; + args.p_drop = has_dropout ? 0.2f : 0.0f; + args.p_undrop = has_dropout ? (1.0f / (1.0f - 0.2f)) : 1.0f; + args.drop_seed_offset = has_dropout ? std::make_pair(uint64_t(1), uint64_t(0)) + : std::make_pair(uint64_t(0), uint64_t(0)); + + try + { + auto invocation = FmhaInvocation::make(std::move(traits), std::move(args)); + if(g_registry->size() == 1) + elapsed = run_single_kernel(invocation); + else + elapsed = g_dispatcher->run_bwd(std::get(invocation.traits), + std::get(invocation.args), + nullptr); + } + catch(const std::exception& e) + { + fprintf(stderr, "FMHA_BWD_ERR: %s\n", e.what()); + rc = -2; + goto cleanup; + } + catch(...) + { + fprintf(stderr, "FMHA_BWD_ERR: unknown\n"); + rc = -2; + goto cleanup; + } + + { + hipError_t e1 = hipMemcpy(dq_host, dq_dev, dq_bytes, hipMemcpyDeviceToHost); + hipError_t e2 = hipMemcpy(dk_host, dk_dev, dk_bytes, hipMemcpyDeviceToHost); + hipError_t e3 = hipMemcpy(dv_host, dv_dev, dv_bytes, hipMemcpyDeviceToHost); + if(e1 != hipSuccess || e2 != hipSuccess || e3 != hipSuccess) + rc = -1; + } + + if(time_ms_out) + *time_ms_out = elapsed; + +cleanup: + safe_hip_free(q_dev); + safe_hip_free(k_dev); + safe_hip_free(v_dev); + safe_hip_free(o_dev); + safe_hip_free(lse_dev); + safe_hip_free(do_dev); + safe_hip_free(d_dev); + safe_hip_free(dq_dev); + safe_hip_free(dk_dev); + safe_hip_free(dv_dev); + safe_hip_free(dq_acc_dev); + safe_hip_free(bwd_seqstart_q_dev); + safe_hip_free(bwd_seqstart_k_dev); + safe_hip_free(bwd_seqlen_k_dev); + safe_hip_free(bwd_seqlen_q_dev); + safe_hip_free(bwd_bias_dev); + safe_hip_free(bwd_randval_dev); + safe_hip_free(bwd_dbias_dev); + + return rc; +} + +// --------------------------------------------------------------------------- +// Split-KV forward: 2-stage (split + combine) +// Allocates o_acc / lse_acc internally for the split stage. +// --------------------------------------------------------------------------- +int fmha_dispatcher_run_splitkv(const void* q_host, + const void* k_host, + const void* v_host, + void* o_host, + int batch, + int nhead_q, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + float scale, + int mask_type_int, + int num_splits, + int is_v_rowmajor, + const char* data_type_str, + int has_lse, + int is_group_mode, + int perm, + int has_logits, + int bias_type_int, + int has_sink, + int paged_kv, + int page_block_size, + int window_left, + int window_right, + float* time_ms_out) +{ + if(!g_initialized) + return -1; + + const int in_bytes = dtype_input_bytes(data_type_str); + const int out_bytes = dtype_output_bytes(data_type_str); + + int rc = 0; + const int64_t q_bytes = static_cast(batch) * nhead_q * seqlen_q * hdim_q * in_bytes; + const int64_t k_bytes = static_cast(batch) * nhead_k * seqlen_k * hdim_q * in_bytes; + const int64_t v_bytes = static_cast(batch) * nhead_k * seqlen_k * hdim_v * in_bytes; + const int64_t o_bytes = static_cast(batch) * nhead_q * seqlen_q * hdim_v * out_bytes; + const int64_t o_acc_bytes = + static_cast(num_splits) * batch * nhead_q * seqlen_q * hdim_v * sizeof(float); + const int64_t lse_bytes = static_cast(batch) * nhead_q * seqlen_q * sizeof(float); + const int64_t lse_acc_bytes = + static_cast(num_splits) * batch * nhead_q * seqlen_q * sizeof(float); + float elapsed = 0.0f; + + const bool grp = (is_group_mode != 0); + + const bool is_paged = (paged_kv != 0); + if(is_paged && page_block_size <= 0) + page_block_size = 64; + const int pages_per_seq = is_paged ? (seqlen_k + page_block_size - 1) / page_block_size : 0; + const int total_pages = is_paged ? batch * pages_per_seq : 0; + + void *q_dev = nullptr, *k_dev = nullptr, *v_dev = nullptr, *o_dev = nullptr; + void *o_acc_dev = nullptr, *lse_dev = nullptr, *lse_acc_dev = nullptr; + void *seqstart_q_dev = nullptr, *seqstart_k_dev = nullptr, *seqlen_k_dev = nullptr; + void *block_table_dev = nullptr, *bias_dev = nullptr, *sink_dev = nullptr; + + // Declare vectors before any HIP_CHECK to avoid goto-over-init + std::vector sq_starts(batch + 1), sk_starts(batch + 1), sk_lens(batch, seqlen_k); + std::vector block_table(total_pages); + for(int i = 0; i < total_pages; ++i) + block_table[i] = i; + if(grp) + { + for(int b = 0; b <= batch; ++b) + { + sq_starts[b] = b * seqlen_q; + sk_starts[b] = b * seqlen_k; + } + } + + fmha_fwd_splitkv_traits traits{}; + traits.hdim_q = hdim_q; + traits.hdim_v = hdim_v; + traits.data_type = data_type_str ? data_type_str : "fp16"; + traits.is_group_mode = grp; + traits.is_v_rowmajor = (is_v_rowmajor != 0); + traits.has_logits_soft_cap = (has_logits != 0); + traits.mask_type = static_cast(mask_type_int); + traits.bias_type = static_cast(bias_type_int); + traits.has_lse = (has_lse != 0); + traits.has_sink = (has_sink != 0); + + fmha_fwd_splitkv_args args{}; + + HIP_CHECK(hipMalloc(&q_dev, q_bytes)); + HIP_CHECK(hipMalloc(&k_dev, k_bytes)); + HIP_CHECK(hipMalloc(&v_dev, v_bytes)); + HIP_CHECK(hipMalloc(&o_dev, o_bytes)); + HIP_CHECK(hipMalloc(&o_acc_dev, o_acc_bytes)); + HIP_CHECK(hipMalloc(&lse_dev, lse_bytes)); + HIP_CHECK(hipMalloc(&lse_acc_dev, lse_acc_bytes)); + + if(is_paged) + { + HIP_CHECK(hipMalloc(&block_table_dev, total_pages * sizeof(int))); + HIP_CHECK(hipMemcpy( + block_table_dev, block_table.data(), total_pages * sizeof(int), hipMemcpyHostToDevice)); + } + + if(grp || is_paged) + { + HIP_CHECK(hipMalloc(&seqstart_q_dev, (batch + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&seqstart_k_dev, (batch + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&seqlen_k_dev, batch * sizeof(int))); + HIP_CHECK(hipMemcpy( + seqstart_q_dev, sq_starts.data(), (batch + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy( + seqstart_k_dev, sk_starts.data(), (batch + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK( + hipMemcpy(seqlen_k_dev, sk_lens.data(), batch * sizeof(int), hipMemcpyHostToDevice)); + } + + HIP_CHECK(hipMemcpy(q_dev, q_host, q_bytes, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(k_dev, k_host, k_bytes, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(v_dev, v_host, v_bytes, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(o_dev, 0, o_bytes)); + HIP_CHECK(hipMemset(o_acc_dev, 0, o_acc_bytes)); + HIP_CHECK(hipMemset(lse_dev, 0, lse_bytes)); + HIP_CHECK(hipMemset(lse_acc_dev, 0, lse_acc_bytes)); + + if(bias_type_int > 0) + { + const int64_t bias_bytes = + (bias_type_int == 2) // alibi: [batch, nhead] slope values + ? static_cast(batch) * nhead_q * sizeof(float) + : static_cast(batch) * nhead_q * seqlen_q * seqlen_k * out_bytes; + HIP_CHECK(hipMalloc(&bias_dev, bias_bytes)); + HIP_CHECK(hipMemset(bias_dev, 0, bias_bytes)); + } + if(has_sink) + { + HIP_CHECK(hipMalloc(&sink_dev, nhead_q * sizeof(float))); + HIP_CHECK(hipMemset(sink_dev, 0, nhead_q * sizeof(float))); + } + + args.q_ptr = q_dev; + args.k_ptr = k_dev; + args.v_ptr = v_dev; + args.bias_ptr = bias_dev; + args.lse_acc_ptr = lse_acc_dev; + args.o_acc_ptr = o_acc_dev; + args.lse_ptr = lse_dev; + args.o_ptr = o_dev; + args.block_table_ptr = block_table_dev; + args.batch_stride_block_table = is_paged ? pages_per_seq : 0; + args.page_block_size = is_paged ? page_block_size : 0; + args.is_gappy = false; + args.cache_batch_idx = nullptr; + args.seqstart_q_ptr = seqstart_q_dev; + args.seqstart_k_ptr = seqstart_k_dev; + args.seqlen_k_ptr = seqlen_k_dev; + args.sink_ptr = sink_dev; + args.seqlen_q = seqlen_q; + args.seqlen_k = seqlen_k; + args.batch = batch; + args.max_seqlen_q = seqlen_q; + args.hdim_q = hdim_q; + args.hdim_v = hdim_v; + args.nhead_q = nhead_q; + args.nhead_k = nhead_k; + args.num_splits = num_splits; + args.scale_s = scale; + args.scale_p = 1.0f; + args.scale_o = 1.0f; + args.logits_soft_cap = 0.0f; + + if(grp) + { + if(perm == 1) + { + // BHSD group: [1, head, total_tokens, dim] + args.stride_q = hdim_q; + args.stride_k = hdim_q; + args.stride_v = hdim_v; + args.stride_o = hdim_v; + args.nhead_stride_q = static_cast(seqlen_q) * hdim_q; + args.nhead_stride_k = static_cast(seqlen_k) * hdim_q; + args.nhead_stride_v = static_cast(seqlen_k) * hdim_v; + args.nhead_stride_o = static_cast(seqlen_q) * hdim_v; + } + else + { + // BSHD group: [total_tokens, nhead, hdim] + args.stride_q = nhead_q * hdim_q; + args.stride_k = nhead_k * hdim_q; + args.stride_v = nhead_k * hdim_v; + args.stride_o = nhead_q * hdim_v; + args.nhead_stride_q = hdim_q; + args.nhead_stride_k = hdim_q; + args.nhead_stride_v = hdim_v; + args.nhead_stride_o = hdim_v; + } + args.stride_bias = 0; + args.stride_o_acc = hdim_v; + args.nhead_stride_bias = 0; + args.nhead_stride_lse = seqlen_q; + args.nhead_stride_lse_acc = static_cast(num_splits) * seqlen_q; + args.nhead_stride_o_acc = static_cast(num_splits) * seqlen_q * hdim_v; + args.batch_stride_q = 0; + args.batch_stride_k = 0; + args.batch_stride_v = 0; + args.batch_stride_bias = 0; + args.batch_stride_lse = static_cast(nhead_q) * seqlen_q; + args.batch_stride_lse_acc = static_cast(nhead_q) * num_splits * seqlen_q; + args.batch_stride_o_acc = static_cast(nhead_q) * num_splits * seqlen_q * hdim_v; + args.batch_stride_o = 0; + } + else + { + // BHSD strides (with paged K/V if applicable) + const int kv_seq = is_paged ? page_block_size : seqlen_k; + args.stride_q = hdim_q; + args.stride_k = hdim_q; + args.stride_v = hdim_v; + args.stride_bias = 0; + args.stride_o_acc = hdim_v; + args.stride_o = hdim_v; + args.nhead_stride_q = static_cast(seqlen_q) * hdim_q; + args.nhead_stride_k = static_cast(kv_seq) * hdim_q; + args.nhead_stride_v = static_cast(kv_seq) * hdim_v; + args.nhead_stride_bias = 0; + args.nhead_stride_lse = seqlen_q; + args.nhead_stride_lse_acc = static_cast(num_splits) * seqlen_q; + args.nhead_stride_o_acc = static_cast(num_splits) * seqlen_q * hdim_v; + args.nhead_stride_o = static_cast(seqlen_q) * hdim_v; + args.batch_stride_q = static_cast(nhead_q) * seqlen_q * hdim_q; + args.batch_stride_k = static_cast(nhead_k) * kv_seq * hdim_q; + args.batch_stride_v = static_cast(nhead_k) * kv_seq * hdim_v; + args.batch_stride_bias = 0; + args.batch_stride_lse = static_cast(nhead_q) * seqlen_q; + args.batch_stride_lse_acc = static_cast(nhead_q) * num_splits * seqlen_q; + args.batch_stride_o_acc = static_cast(nhead_q) * num_splits * seqlen_q * hdim_v; + args.batch_stride_o = static_cast(nhead_q) * seqlen_q * hdim_v; + } + args.split_stride_lse_acc = seqlen_q; + args.split_stride_o_acc = static_cast(seqlen_q) * hdim_v; + args.window_size_left = window_left; + args.window_size_right = window_right; + args.sink_size = 0; + args.mask_type = mask_type_int; + + try + { + auto invocation = FmhaInvocation::make(std::move(traits), std::move(args)); + if(g_registry->size() == 1) + elapsed = run_single_kernel(invocation); + else + elapsed = + g_dispatcher->run_fwd_splitkv(std::get(invocation.traits), + std::get(invocation.args), + nullptr); + } + catch(const std::exception& e) + { + fprintf(stderr, "FMHA_SPLITKV_ERR: %s\n", e.what()); + rc = -2; + goto cleanup; + } + catch(...) + { + fprintf(stderr, "FMHA_SPLITKV_ERR: unknown\n"); + rc = -2; + goto cleanup; + } + + { + hipError_t cpy = hipMemcpy(o_host, o_dev, o_bytes, hipMemcpyDeviceToHost); + if(cpy != hipSuccess) + rc = -1; + } + if(time_ms_out) + *time_ms_out = elapsed; + +cleanup: + safe_hip_free(q_dev); + safe_hip_free(k_dev); + safe_hip_free(v_dev); + safe_hip_free(o_dev); + safe_hip_free(o_acc_dev); + safe_hip_free(lse_dev); + safe_hip_free(lse_acc_dev); + safe_hip_free(seqstart_q_dev); + safe_hip_free(seqstart_k_dev); + safe_hip_free(seqlen_k_dev); + safe_hip_free(block_table_dev); + safe_hip_free(bias_dev); + safe_hip_free(sink_dev); + return rc; +} + +// --------------------------------------------------------------------------- +// Paged-KV forward: Q in standard layout, K/V in paged blocks +// Creates a trivial contiguous page table for benchmarking. +// --------------------------------------------------------------------------- +int fmha_dispatcher_run_pagedkv(const void* q_host, + const void* k_host, + const void* v_host, + void* o_host, + int batch, + int nhead_q, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + float scale, + int mask_type_int, + int page_block_size, + int is_v_rowmajor, + const char* data_type_str, + int has_lse, + int has_logits, + int has_sink, + int skip_min_seqlen_q, + int bias_type_int, + float* time_ms_out) +{ + if(!g_initialized) + return -1; + + const int in_bytes = dtype_input_bytes(data_type_str); + const int out_bytes = dtype_output_bytes(data_type_str); + + int rc = 0; + const int64_t q_bytes = static_cast(batch) * nhead_q * seqlen_q * hdim_q * in_bytes; + const int64_t k_bytes = static_cast(batch) * nhead_k * seqlen_k * hdim_q * in_bytes; + const int64_t v_bytes = static_cast(batch) * nhead_k * seqlen_k * hdim_v * in_bytes; + const int64_t o_bytes = static_cast(batch) * nhead_q * seqlen_q * hdim_v * out_bytes; + const int64_t lse_bytes = static_cast(batch) * nhead_q * seqlen_q * sizeof(float); + float elapsed = 0.0f; + + if(page_block_size <= 0) + page_block_size = 64; + const int pages_per_seq = (seqlen_k + page_block_size - 1) / page_block_size; + const int total_pages = batch * pages_per_seq; + + void *q_dev = nullptr, *k_dev = nullptr, *v_dev = nullptr, *o_dev = nullptr; + void *lse_dev = nullptr, *block_table_dev = nullptr; + void *seqlen_k_dev = nullptr, *seqstart_q_dev = nullptr, *seqstart_k_dev = nullptr; + void *sink_dev = nullptr, *bias_dev_pkv = nullptr; + + // Declare vectors before any HIP_CHECK to avoid goto-over-init + std::vector block_table(total_pages); + for(int i = 0; i < total_pages; ++i) + block_table[i] = i; + std::vector seqlen_k_vec(batch, seqlen_k); + std::vector sq_starts(batch + 1), sk_starts(batch + 1); + for(int b = 0; b <= batch; ++b) + { + sq_starts[b] = b * seqlen_q; + sk_starts[b] = b * seqlen_k; + } + + fmha_fwd_pagedkv_traits traits{}; + traits.hdim_q = hdim_q; + traits.hdim_v = hdim_v; + traits.data_type = data_type_str ? data_type_str : "fp16"; + traits.is_group_mode = true; + traits.is_v_rowmajor = (is_v_rowmajor != 0); + traits.has_logits_soft_cap = (has_logits != 0); + traits.mask_type = static_cast(mask_type_int); + traits.bias_type = static_cast(bias_type_int); + traits.has_lse = (has_lse != 0); + traits.use_pagedkv = true; + traits.has_sink = (has_sink != 0); + traits.skip_min_seqlen_q = (skip_min_seqlen_q != 0); + + fmha_fwd_pagedkv_args args{}; + + HIP_CHECK(hipMalloc(&q_dev, q_bytes)); + HIP_CHECK(hipMalloc(&k_dev, k_bytes)); + HIP_CHECK(hipMalloc(&v_dev, v_bytes)); + HIP_CHECK(hipMalloc(&o_dev, o_bytes)); + + HIP_CHECK(hipMalloc(&block_table_dev, total_pages * sizeof(int))); + HIP_CHECK(hipMemcpy( + block_table_dev, block_table.data(), total_pages * sizeof(int), hipMemcpyHostToDevice)); + + HIP_CHECK(hipMalloc(&seqlen_k_dev, batch * sizeof(int))); + HIP_CHECK( + hipMemcpy(seqlen_k_dev, seqlen_k_vec.data(), batch * sizeof(int), hipMemcpyHostToDevice)); + + // Group mode needs seqstart pointers + HIP_CHECK(hipMalloc(&seqstart_q_dev, (batch + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&seqstart_k_dev, (batch + 1) * sizeof(int))); + HIP_CHECK(hipMemcpy( + seqstart_q_dev, sq_starts.data(), (batch + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy( + seqstart_k_dev, sk_starts.data(), (batch + 1) * sizeof(int), hipMemcpyHostToDevice)); + + if(has_lse) + { + HIP_CHECK(hipMalloc(&lse_dev, lse_bytes)); + HIP_CHECK(hipMemset(lse_dev, 0, lse_bytes)); + } + if(has_sink) + { + HIP_CHECK(hipMalloc(&sink_dev, nhead_q * sizeof(float))); + HIP_CHECK(hipMemset(sink_dev, 0, nhead_q * sizeof(float))); + } + + if(bias_type_int > 0) + { + const int64_t bias_bytes = + (bias_type_int == 2) + ? static_cast(batch) * nhead_q * sizeof(float) + : static_cast(batch) * nhead_q * seqlen_q * seqlen_k * out_bytes; + HIP_CHECK(hipMalloc(&bias_dev_pkv, bias_bytes)); + HIP_CHECK(hipMemset(bias_dev_pkv, 0, bias_bytes)); + } + + HIP_CHECK(hipMemcpy(q_dev, q_host, q_bytes, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(k_dev, k_host, k_bytes, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(v_dev, v_host, v_bytes, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(o_dev, 0, o_bytes)); + + args.q_ptr = q_dev; + args.k_ptr = k_dev; + args.v_ptr = v_dev; + args.bias_ptr = bias_dev_pkv; + args.lse_ptr = lse_dev; + args.o_ptr = o_dev; + args.block_table_ptr = block_table_dev; + args.batch_stride_block_table = pages_per_seq; + args.page_block_size = page_block_size; + args.is_gappy = false; + args.cache_batch_idx = nullptr; + args.seqstart_q_ptr = seqstart_q_dev; + args.seqstart_k_ptr = seqstart_k_dev; + args.seqlen_k_ptr = seqlen_k_dev; + args.sink_ptr = sink_dev; + args.seqlen_q = seqlen_q; + args.seqlen_k = seqlen_k; + args.batch = batch; + args.max_seqlen_q = seqlen_q; + args.hdim_q = hdim_q; + args.hdim_v = hdim_v; + args.nhead_q = nhead_q; + args.nhead_k = nhead_k; + args.scale_s = scale; + args.scale_p = 1.0f; + args.scale_o = 1.0f; + args.logits_soft_cap = 0.0f; + + // Pagedkv is always group mode: Q=[total_tokens, nhead, hdim], K/V=[pages, nhead, pbs, hdim] + args.stride_q = nhead_q * hdim_q; + args.stride_k = hdim_q; + args.stride_v = hdim_v; + args.stride_bias = 0; + args.stride_o = nhead_q * hdim_v; + args.nhead_stride_q = hdim_q; + args.nhead_stride_k = static_cast(page_block_size) * hdim_q; + args.nhead_stride_v = static_cast(page_block_size) * hdim_v; + args.nhead_stride_bias = 0; + args.nhead_stride_lse = seqlen_q; + args.nhead_stride_o = hdim_v; + args.batch_stride_q = 0; + args.batch_stride_k = static_cast(nhead_k) * page_block_size * hdim_q; + args.batch_stride_v = static_cast(nhead_k) * page_block_size * hdim_v; + args.batch_stride_bias = 0; + args.batch_stride_lse = static_cast(nhead_q) * seqlen_q; + args.batch_stride_o = 0; + args.window_size_left = -1; + args.window_size_right = -1; + args.sink_size = 0; + args.mask_type = mask_type_int; + args.min_seqlen_q = 0; + + try + { + auto invocation = FmhaInvocation::make(std::move(traits), std::move(args)); + if(g_registry->size() == 1) + elapsed = run_single_kernel(invocation); + else + elapsed = + g_dispatcher->run_fwd_pagedkv(std::get(invocation.traits), + std::get(invocation.args), + nullptr); + } + catch(const std::exception& e) + { + fprintf(stderr, "FMHA_PAGEDKV_ERR: %s\n", e.what()); + rc = -2; + goto cleanup; + } + catch(...) + { + fprintf(stderr, "FMHA_PAGEDKV_ERR: unknown\n"); + rc = -2; + goto cleanup; + } + + { + hipError_t cpy = hipMemcpy(o_host, o_dev, o_bytes, hipMemcpyDeviceToHost); + if(cpy != hipSuccess) + rc = -1; + } + if(time_ms_out) + *time_ms_out = elapsed; + +cleanup: + safe_hip_free(q_dev); + safe_hip_free(k_dev); + safe_hip_free(v_dev); + safe_hip_free(o_dev); + safe_hip_free(lse_dev); + safe_hip_free(block_table_dev); + safe_hip_free(seqlen_k_dev); + safe_hip_free(seqstart_q_dev); + safe_hip_free(seqstart_k_dev); + safe_hip_free(sink_dev); + safe_hip_free(bias_dev_pkv); + return rc; +} + +// --------------------------------------------------------------------------- +// Append-KV: appends knew/vnew into K/V cache, optionally with RoPE +// --------------------------------------------------------------------------- +int fmha_dispatcher_run_appendkv(const void* q_host, + const void* knew_host, + const void* vnew_host, + int batch, + int nhead_q, + int nhead_k, + int seqlen_q, + int seqlen_knew, + int hdim_q, + int hdim_v, + int is_v_rowmajor, + int rope_type_int, + int paged_kv, + int page_block_size, + const char* data_type_str, + float* time_ms_out) +{ + if(!g_initialized) + return -1; + + const int in_bytes = dtype_input_bytes(data_type_str); + int rc = 0; + + const int seqlen_k = seqlen_q + seqlen_knew; + const bool has_rope = (rope_type_int != 0); + const int rotary_dim = has_rope ? hdim_q : 0; + const bool akv_paged = (paged_kv != 0); + if(akv_paged && page_block_size <= 0) + page_block_size = 64; + const int akv_pps = akv_paged ? (seqlen_k + page_block_size - 1) / page_block_size : 0; + const int akv_tp = akv_paged ? batch * akv_pps : 0; + const int kv_s = akv_paged ? page_block_size : seqlen_k; + + const int64_t q_bytes = static_cast(batch) * nhead_q * seqlen_q * hdim_q * in_bytes; + const int64_t knew_bytes = + static_cast(batch) * nhead_k * seqlen_knew * hdim_q * in_bytes; + const int64_t vnew_bytes = + static_cast(batch) * nhead_k * seqlen_knew * hdim_v * in_bytes; + const int64_t k_bytes = + akv_paged ? static_cast(akv_tp) * nhead_k * page_block_size * hdim_q * in_bytes + : static_cast(batch) * nhead_k * seqlen_k * hdim_q * in_bytes; + const int64_t v_bytes = + akv_paged ? static_cast(akv_tp) * nhead_k * page_block_size * hdim_v * in_bytes + : static_cast(batch) * nhead_k * seqlen_k * hdim_v * in_bytes; + float elapsed = 0.0f; + + void *q_dev = nullptr, *k_dev = nullptr, *v_dev = nullptr; + void *knew_dev = nullptr, *vnew_dev = nullptr; + void *seqlen_k_dev = nullptr, *rotary_cos_dev = nullptr, *rotary_sin_dev = nullptr; + void* akv_block_table_dev = nullptr; + + fmha_fwd_appendkv_traits traits{}; + traits.hdim_q = hdim_q; + traits.hdim_v = hdim_v; + traits.data_type = data_type_str ? data_type_str : "fp16"; + traits.is_v_rowmajor = (is_v_rowmajor != 0); + traits.rope_type = static_cast(rope_type_int); + + std::vector sk_vec(batch, seqlen_k - seqlen_knew); + std::vector akv_bt(akv_tp); + for(int i = 0; i < akv_tp; ++i) + akv_bt[i] = i; + + fmha_fwd_appendkv_args args{}; + + HIP_CHECK(hipMalloc(&q_dev, q_bytes)); + HIP_CHECK(hipMalloc(&k_dev, k_bytes)); + HIP_CHECK(hipMalloc(&v_dev, v_bytes)); + HIP_CHECK(hipMalloc(&knew_dev, knew_bytes)); + HIP_CHECK(hipMalloc(&vnew_dev, vnew_bytes)); + + HIP_CHECK(hipMalloc(&seqlen_k_dev, batch * sizeof(int))); + HIP_CHECK(hipMemcpy(seqlen_k_dev, sk_vec.data(), batch * sizeof(int), hipMemcpyHostToDevice)); + + if(akv_paged) + { + HIP_CHECK(hipMalloc(&akv_block_table_dev, akv_tp * sizeof(int))); + HIP_CHECK(hipMemcpy( + akv_block_table_dev, akv_bt.data(), akv_tp * sizeof(int), hipMemcpyHostToDevice)); + } + + if(has_rope) + { + const int64_t rot_bytes = static_cast(seqlen_k) * (rotary_dim / 2) * sizeof(float); + HIP_CHECK(hipMalloc(&rotary_cos_dev, rot_bytes)); + HIP_CHECK(hipMalloc(&rotary_sin_dev, rot_bytes)); + HIP_CHECK(hipMemset(rotary_cos_dev, 0, rot_bytes)); + HIP_CHECK(hipMemset(rotary_sin_dev, 0, rot_bytes)); + } + + HIP_CHECK(hipMemcpy(q_dev, q_host, q_bytes, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(knew_dev, knew_host, knew_bytes, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(vnew_dev, vnew_host, vnew_bytes, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(k_dev, 0, k_bytes)); + HIP_CHECK(hipMemset(v_dev, 0, v_bytes)); + + args.q_ptr = q_dev; + args.k_ptr = k_dev; + args.knew_ptr = knew_dev; + args.v_ptr = v_dev; + args.vnew_ptr = vnew_dev; + args.seqlen_k_ptr = seqlen_k_dev; + args.seqlen_q = seqlen_q; + args.seqlen_knew = seqlen_knew; + args.batch = batch; + args.hdim_q = hdim_q; + args.hdim_v = hdim_v; + args.nhead_q = nhead_q; + args.nhead_k = nhead_k; + args.rotary_cos_ptr = rotary_cos_dev; + args.rotary_sin_ptr = rotary_sin_dev; + args.rotary_dim = rotary_dim; + args.has_mask = false; + args.block_table_ptr = akv_block_table_dev; + args.batch_stride_block_table = akv_paged ? akv_pps : 0; + args.page_block_size = akv_paged ? page_block_size : 0; + args.cache_batch_idx = nullptr; + args.sink_ptr = nullptr; + + // BHSD strides (paged K/V uses page_block_size instead of seqlen_k) + args.stride_q = hdim_q; + args.stride_k = hdim_q; + args.stride_knew = hdim_q; + args.stride_v = hdim_v; + args.stride_vnew = hdim_v; + args.nhead_stride_q = static_cast(seqlen_q) * hdim_q; + args.nhead_stride_k = static_cast(kv_s) * hdim_q; + args.nhead_stride_knew = static_cast(seqlen_knew) * hdim_q; + args.nhead_stride_v = static_cast(kv_s) * hdim_v; + args.nhead_stride_vnew = static_cast(seqlen_knew) * hdim_v; + args.batch_stride_q = static_cast(nhead_q) * seqlen_q * hdim_q; + args.batch_stride_k = static_cast(nhead_k) * kv_s * hdim_q; + args.batch_stride_knew = static_cast(nhead_k) * seqlen_knew * hdim_q; + args.batch_stride_v = static_cast(nhead_k) * kv_s * hdim_v; + args.batch_stride_vnew = static_cast(nhead_k) * seqlen_knew * hdim_v; + + try + { + auto invocation = FmhaInvocation::make(std::move(traits), std::move(args)); + if(g_registry->size() == 1) + elapsed = run_single_kernel(invocation); + else + elapsed = g_dispatcher->run_fwd_appendkv( + std::get(invocation.traits), + std::get(invocation.args), + nullptr); + } + catch(const std::exception& e) + { + fprintf(stderr, "FMHA_APPENDKV_ERR: %s\n", e.what()); + rc = -2; + goto cleanup; + } + catch(...) + { + fprintf(stderr, "FMHA_APPENDKV_ERR: unknown\n"); + rc = -2; + goto cleanup; + } + + if(time_ms_out) + *time_ms_out = elapsed; + +cleanup: + safe_hip_free(q_dev); + safe_hip_free(k_dev); + safe_hip_free(v_dev); + safe_hip_free(knew_dev); + safe_hip_free(vnew_dev); + safe_hip_free(seqlen_k_dev); + safe_hip_free(rotary_cos_dev); + safe_hip_free(rotary_sin_dev); + safe_hip_free(akv_block_table_dev); + return rc; +} + +// --------------------------------------------------------------------------- +// Batch Prefill: group-mode forward with paged KV cache +// --------------------------------------------------------------------------- +int fmha_dispatcher_run_batch_prefill(const void* q_host, + const void* k_host, + const void* v_host, + void* o_host, + int batch, + int nhead_q, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + float scale, + int mask_type_int, + int bias_type_int, + int page_block_size, + int kv_layout_int, + int kv_lookup_int, + int is_v_rowmajor, + const char* data_type_str, + int has_lse, + int has_dropout, + int has_logits, + int has_sink, + int skip_min_seqlen_q, + float* time_ms_out) +{ + if(!g_initialized) + return -1; + + const int in_bytes = dtype_input_bytes(data_type_str); + const int out_bytes = dtype_output_bytes(data_type_str); + + int rc = 0; + const int64_t q_bytes = static_cast(batch) * nhead_q * seqlen_q * hdim_q * in_bytes; + const int64_t o_bytes = static_cast(batch) * nhead_q * seqlen_q * hdim_v * out_bytes; + const int64_t lse_bytes = static_cast(batch) * nhead_q * seqlen_q * sizeof(float); + float elapsed = 0.0f; + + if(page_block_size <= 0) + page_block_size = 64; + const int pages_per_seq = (seqlen_k + page_block_size - 1) / page_block_size; + const int total_pages = batch * pages_per_seq; + const int64_t kv_page_bytes = static_cast(total_pages) * nhead_k * page_block_size * + std::max(hdim_q, hdim_v) * in_bytes; + + void *q_dev = nullptr, *k_dev = nullptr, *v_dev = nullptr, *o_dev = nullptr; + void *lse_dev = nullptr, *seqstart_q_dev = nullptr; + void *kv_indptr_dev = nullptr, *kv_page_indices_dev = nullptr, *kv_last_page_dev = nullptr; + void *seqlen_k_dev = nullptr, *bias_dev = nullptr, *sink_dev = nullptr; + + fmha_batch_prefill_traits traits{}; + traits.hdim_q = hdim_q; + traits.hdim_v = hdim_v; + traits.data_type = data_type_str ? data_type_str : "fp16"; + traits.is_group_mode = true; + traits.is_v_rowmajor = (is_v_rowmajor != 0); + traits.mask_type = static_cast(mask_type_int); + traits.bias_type = static_cast(bias_type_int); + traits.has_lse = (has_lse != 0); + traits.has_dropout = (has_dropout != 0); + traits.has_logits_soft_cap = (has_logits != 0); + traits.skip_min_seqlen_q = (skip_min_seqlen_q != 0); + traits.has_sink = (has_sink != 0); + traits.qscale_type = quant_scale_enum::no_scale; + traits.kv_memory_layout = + static_cast(kv_layout_int); + traits.kv_lookup_table = + static_cast(kv_lookup_int); + traits.page_size = page_block_size; + + // Declare all vectors before HIP_CHECK to avoid goto-over-init + std::vector seqstart_q(batch + 1); + for(int b = 0; b <= batch; ++b) + seqstart_q[b] = b * seqlen_q; + std::vector kv_indptr(batch + 1); + for(int b = 0; b <= batch; ++b) + kv_indptr[b] = b * pages_per_seq; + std::vector kv_page_indices(total_pages); + for(int i = 0; i < total_pages; ++i) + kv_page_indices[i] = i; + std::vector last_page(batch); + for(int b = 0; b < batch; ++b) + last_page[b] = seqlen_k - (pages_per_seq - 1) * page_block_size; + std::vector sk_vec(batch, seqlen_k); + + fmha_batch_prefill_args args{}; + + HIP_CHECK(hipMalloc(&q_dev, q_bytes)); + HIP_CHECK(hipMalloc(&k_dev, kv_page_bytes)); + HIP_CHECK(hipMalloc(&v_dev, kv_page_bytes)); + HIP_CHECK(hipMalloc(&o_dev, o_bytes)); + + HIP_CHECK(hipMalloc(&seqstart_q_dev, (batch + 1) * sizeof(int))); + HIP_CHECK(hipMemcpy( + seqstart_q_dev, seqstart_q.data(), (batch + 1) * sizeof(int), hipMemcpyHostToDevice)); + + HIP_CHECK(hipMalloc(&kv_indptr_dev, (batch + 1) * sizeof(int))); + HIP_CHECK(hipMemcpy( + kv_indptr_dev, kv_indptr.data(), (batch + 1) * sizeof(int), hipMemcpyHostToDevice)); + + HIP_CHECK(hipMalloc(&kv_page_indices_dev, total_pages * sizeof(int))); + HIP_CHECK(hipMemcpy(kv_page_indices_dev, + kv_page_indices.data(), + total_pages * sizeof(int), + hipMemcpyHostToDevice)); + + HIP_CHECK(hipMalloc(&kv_last_page_dev, batch * sizeof(int))); + HIP_CHECK( + hipMemcpy(kv_last_page_dev, last_page.data(), batch * sizeof(int), hipMemcpyHostToDevice)); + + HIP_CHECK(hipMalloc(&seqlen_k_dev, batch * sizeof(int))); + HIP_CHECK(hipMemcpy(seqlen_k_dev, sk_vec.data(), batch * sizeof(int), hipMemcpyHostToDevice)); + + if(has_lse) + { + HIP_CHECK(hipMalloc(&lse_dev, lse_bytes)); + HIP_CHECK(hipMemset(lse_dev, 0, lse_bytes)); + } + if(bias_type_int > 0) + { + const int64_t bias_bytes = + (bias_type_int == 2) + ? static_cast(batch) * nhead_q * sizeof(float) + : static_cast(batch) * nhead_q * seqlen_q * seqlen_k * out_bytes; + HIP_CHECK(hipMalloc(&bias_dev, bias_bytes)); + HIP_CHECK(hipMemset(bias_dev, 0, bias_bytes)); + } + if(has_sink) + { + HIP_CHECK(hipMalloc(&sink_dev, nhead_q * sizeof(float))); + HIP_CHECK(hipMemset(sink_dev, 0, nhead_q * sizeof(float))); + } + + HIP_CHECK(hipMemcpy(q_dev, q_host, q_bytes, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(k_dev, 0, kv_page_bytes)); + HIP_CHECK(hipMemset(v_dev, 0, kv_page_bytes)); + HIP_CHECK(hipMemset(o_dev, 0, o_bytes)); + + args.q_ptr = q_dev; + args.k_ptr = k_dev; + args.v_ptr = v_dev; + args.bias_ptr = bias_dev; + args.q_descale_ptr = nullptr; + args.k_descale_ptr = nullptr; + args.v_descale_ptr = nullptr; + args.rand_val_ptr = nullptr; + args.lse_ptr = lse_dev; + args.o_ptr = o_dev; + args.seqstart_q_ptr = seqstart_q_dev; + args.sink_ptr = sink_dev; + args.seqlen_q = seqlen_q; + args.seqlen_k = seqlen_k; + args.batch = batch; + args.max_seqlen_q = seqlen_q; + args.hdim_q = hdim_q; + args.hdim_v = hdim_v; + args.nhead_q = nhead_q; + args.nhead_k = nhead_k; + args.num_total_pages = total_pages; + args.page_block_size = page_block_size; + args.kv_memory_layout = + static_cast(kv_layout_int); + args.kv_lookup_table = + static_cast(kv_lookup_int); + args.kv_indptr = kv_indptr_dev; + args.kv_page_indices = kv_page_indices_dev; + args.kv_last_page_lens = kv_last_page_dev; + args.seqlen_k_ptr = seqlen_k_dev; + args.batch_stride_block_table = pages_per_seq; + args.scale_s = scale; + args.scale_p = 1.0f; + args.scale_o = 1.0f; + args.logits_soft_cap = 0.0f; + + // Group-mode strides: [total_tokens, nhead, hdim] + args.stride_q = nhead_q * hdim_q; + args.stride_k = hdim_q; + args.stride_v = hdim_v; + args.stride_bias = 0; + args.stride_randval = 0; + args.stride_o = nhead_q * hdim_v; + args.nhead_stride_q = hdim_q; + args.nhead_stride_k = static_cast(page_block_size) * hdim_q; + args.nhead_stride_v = static_cast(page_block_size) * hdim_v; + args.nhead_stride_bias = 0; + args.nhead_stride_randval = 0; + args.nhead_stride_lse = seqlen_q; + args.nhead_stride_o = hdim_v; + args.batch_stride_q = 0; + args.batch_stride_k = static_cast(nhead_k) * page_block_size * hdim_q; + args.batch_stride_v = static_cast(nhead_k) * page_block_size * hdim_v; + args.batch_stride_bias = 0; + args.batch_stride_randval = 0; + args.batch_stride_lse = static_cast(nhead_q) * seqlen_q; + args.batch_stride_o = 0; + args.window_size_left = -1; + args.window_size_right = -1; + args.sink_size = 0; + args.mask_type = mask_type_int; + args.p_drop = has_dropout ? 0.2f : 0.0f; + args.s_randval = false; + args.drop_seed_offset = has_dropout ? std::make_pair(uint64_t(1), uint64_t(0)) + : std::make_pair(uint64_t(0), uint64_t(0)); + + try + { + auto invocation = FmhaInvocation::make(std::move(traits), std::move(args)); + if(g_registry->size() == 1) + elapsed = run_single_kernel(invocation); + else + elapsed = g_dispatcher->run_batch_prefill( + std::get(invocation.traits), + std::get(invocation.args), + nullptr); + } + catch(const std::exception& e) + { + fprintf(stderr, "FMHA_PREFILL_ERR: %s\n", e.what()); + rc = -2; + goto cleanup; + } + catch(...) + { + fprintf(stderr, "FMHA_PREFILL_ERR: unknown\n"); + rc = -2; + goto cleanup; + } + + { + hipError_t cpy = hipMemcpy(o_host, o_dev, o_bytes, hipMemcpyDeviceToHost); + if(cpy != hipSuccess) + rc = -1; + } + if(time_ms_out) + *time_ms_out = elapsed; + +cleanup: + safe_hip_free(q_dev); + safe_hip_free(k_dev); + safe_hip_free(v_dev); + safe_hip_free(o_dev); + safe_hip_free(lse_dev); + safe_hip_free(seqstart_q_dev); + safe_hip_free(kv_indptr_dev); + safe_hip_free(kv_page_indices_dev); + safe_hip_free(kv_last_page_dev); + safe_hip_free(seqlen_k_dev); + safe_hip_free(bias_dev); + safe_hip_free(sink_dev); + return rc; +} + +int fmha_dispatcher_kernel_count() +{ + return g_initialized ? static_cast(g_registry->size()) : 0; +} + +void fmha_dispatcher_cleanup() +{ + g_dispatcher.reset(); + g_registry.reset(); + g_initialized = false; +} + +} // extern "C" diff --git a/dispatcher/codegen/arch_filter.py b/dispatcher/codegen/arch_filter.py index 67f146045b..63dbee2dd7 100644 --- a/dispatcher/codegen/arch_filter.py +++ b/dispatcher/codegen/arch_filter.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT diff --git a/dispatcher/codegen/codegen_common.py b/dispatcher/codegen/codegen_common.py index 4e9e8de1b3..a0486da66d 100644 --- a/dispatcher/codegen/codegen_common.py +++ b/dispatcher/codegen/codegen_common.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: MIT """ -Shared codegen infrastructure for GEMM and grouped convolution code generators. +Shared codegen infrastructure for GEMM, grouped convolution, and FMHA code generators. Extracted from unified_gemm_codegen.py + arch-aware expansion helpers from conv. Both unified_gemm_codegen.py and unified_grouped_conv_codegen.py import from here diff --git a/dispatcher/codegen/fmha/__init__.py b/dispatcher/codegen/fmha/__init__.py new file mode 100644 index 0000000000..813f6c8af1 --- /dev/null +++ b/dispatcher/codegen/fmha/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +"""FMHA codegen subpackage — tile specs, instance generation, symbol mapping, and C++ codegen.""" diff --git a/dispatcher/codegen/fmha/codegen.py b/dispatcher/codegen/fmha/codegen.py new file mode 100644 index 0000000000..a063948981 --- /dev/null +++ b/dispatcher/codegen/fmha/codegen.py @@ -0,0 +1,1385 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Unified FMHA code generator for the dispatcher. + +This generator intentionally sits between the hand-maintained FMHA example codegen +and the dispatcher's runtime-registry model: + +- it consumes explicit kernel configurations or profile-filtered config lists +- it emits one header per FMHA kernel specialization +- it emits dispatcher wrapper headers that create FmhaKernelInstance objects +- it emits one .cpp translation unit per generated kernel header +""" + +import argparse +import json +import logging +import sys +from pathlib import Path +from typing import Iterable, Union + +# Ensure parent (codegen/) is on path for codegen_common and sibling modules +_CODEGEN_DIR = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(_CODEGEN_DIR)) +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from codegen_common import parallel_generate # noqa: E402 +from validation import load_arch_specs, profile_allows, validate_config # noqa: E402 +from symbol_map import ( # noqa: E402 + ARCH_PREPROC_MAP, + ARCH_TAG_MAP, + BIAS_TO_CPP, + BIAS_TO_INT, + BOOL_MAP, + BWD_DTYPE_MAP, + FWD_DTYPE_MAP, + KERNEL_FAMILY_TO_ENUM, + KV_LOOKUP_TO_INT, + KV_LOOKUP_TO_CPP, + KV_MEMORY_LAYOUT_TO_CPP, + KV_MEMORY_LAYOUT_TO_INT, + LAYOUT_TO_BOOL, + MASK_TO_CPP, + MASK_TO_CPP_GENERIC, + MASK_TO_INT, + PIPELINE_ENUM_TO_CPP, + QSCALE_TO_CPP, + QSCALE_TO_INT, + ROPE_TO_CPP, + ROPE_TO_INT, + canonical_bias, + canonical_kv_lookup, + canonical_kv_memory_layout, + canonical_mask, + canonical_qscale, + canonical_rope, + kernel_name_from_config, +) + +log = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +def _bool_cpp(value) -> str: + return BOOL_MAP[bool(value)] + + +def _mask_cpp(value: str) -> str: + return MASK_TO_CPP[canonical_mask(value)] + + +def _bias_cpp(value: str) -> str: + return BIAS_TO_CPP[canonical_bias(value)] + + +def _qscale_cpp(value: str) -> str: + return QSCALE_TO_CPP[canonical_qscale(value)] + + +def _rope_cpp(value: str) -> str: + return ROPE_TO_CPP[canonical_rope(value)] + + +def _kv_memory_cpp(value: str) -> str: + return KV_MEMORY_LAYOUT_TO_CPP[canonical_kv_memory_layout(value)] + + +def _kv_lookup_cpp(value: str) -> str: + return KV_LOOKUP_TO_CPP[canonical_kv_lookup(value)] + + +def _bwd_block_tile(tile: list, sig: dict) -> str: + """Format the bwd block tile sequence. + + Source: fmha_bwd.hpp FmhaBwdDQDKDVTileSize — 9 elements: + (bm0, bn0, bk0, bn1, bk1, bk0max, tile6, tile7, tile8). + If tile has only 6 elements (forward-style), maps to BWD format using the + forward-to-backward heuristic from fmha_bwd.py. + """ + if len(tile) >= 9: + return ", ".join(str(t) for t in tile[:9]) + return ( + f"{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, " + f"{tile[3]}, {tile[5]}, {sig['hdim_q']}, {sig['hdim_v']}" + ) + + +def _canonicalize_config(raw_config: dict, target_arch: str, arch_specs: dict) -> dict: + defaults = arch_specs["defaults"] + + if "signature" not in raw_config or "algorithm" not in raw_config: + raise ValueError( + "FMHA config-json must contain 'signature' and 'algorithm' objects" + ) + + sig = dict(raw_config["signature"]) + alg = dict(raw_config["algorithm"]) + + sig.setdefault("family", "fwd") + sig.setdefault("data_type", "fp16") + sig.setdefault("mode", "batch") + sig.setdefault("vlayout", "r") + sig.setdefault("hdim_q", 128) + sig.setdefault("hdim_v", sig["hdim_q"]) + sig.setdefault("mask", "no") + sig.setdefault("bias", "no") + sig.setdefault("lse", False) + sig.setdefault("dropout", False) + sig.setdefault("qscale", "no") + sig.setdefault("rope", "none") + sig.setdefault("logits", False) + sig.setdefault("paged_kv", False) + sig.setdefault("fp8_static_quant", False) + sig.setdefault("skip_min_seqlen_q", False) + sig.setdefault("sink", False) + sig.setdefault("dbias", False) + sig.setdefault("store_randval", False) + sig.setdefault("deterministic", False) + sig.setdefault("kv_memory_layout", "vectorized") + sig.setdefault("kv_lookup_table", "sglang") + sig.setdefault("page_size", 1) + + sig["mask"] = canonical_mask(sig["mask"]) + sig["bias"] = canonical_bias(sig["bias"]) + sig["qscale"] = canonical_qscale(sig["qscale"]) + sig["rope"] = canonical_rope(sig["rope"]) + sig["kv_memory_layout"] = canonical_kv_memory_layout(sig["kv_memory_layout"]) + sig["kv_lookup_table"] = canonical_kv_lookup(sig["kv_lookup_table"]) + + alg.setdefault("pipeline", "qr") + alg.setdefault("tile", list(defaults["tile"])) + alg.setdefault("wave", list(defaults["wave"])) + alg.setdefault("warp", list(defaults["warp"])) + alg.setdefault("padding", list(defaults["padding"])) + alg.setdefault("use_trload", False) + alg.setdefault("hdim_q_alignment", sig["hdim_q"]) + alg.setdefault("hdim_v_alignment", sig["hdim_v"]) + alg.setdefault("block_per_cu", defaults["block_per_cu"]) + alg.setdefault("num_wave_groups", defaults["num_wave_groups"]) + alg.setdefault("max_splits_log2", 0) + alg.setdefault("max_seq_len_q", 0) + alg.setdefault("selection_rank", defaults["selection_rank"]) + alg.setdefault("constraint_tag", "") + + return { + "arch": raw_config.get("arch", target_arch), + "signature": sig, + "algorithm": alg, + "profile": raw_config.get("profile"), + "receipt": raw_config.get("receipt"), + } + + +def _fwd_kernel_body(name: str, config: dict) -> str: + sig = config["signature"] + alg = config["algorithm"] + arch_tag = ARCH_TAG_MAP[config["arch"]] + dtype_cpp = FWD_DTYPE_MAP[sig["data_type"]] + mode_cpp = "true" if sig["mode"] == "group" else "false" + vlayout_cpp = LAYOUT_TO_BOOL[sig["vlayout"]] + tile = alg["tile"] + wave = alg["wave"] + warp = alg["warp"] + pad = alg["padding"] + use_trload = _bool_cpp(alg["use_trload"]) + pipeline_name = alg["pipeline"] + pipeline_cpp = { + "qr": "ck_tile::BlockFmhaPipelineQRKSVS", + "qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync", + "qs": "ck_tile::BlockFmhaPipelineQSKSVS", + "qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload", + "qr_async_trload_v3": "ck_tile::BlockFmhaFwdV3Pipeline", + "v3": "ck_tile::BlockFmhaFwdV3Pipeline", + }[pipeline_name] + + ns = f"ns_{name}" + arch_check = ARCH_PREPROC_MAP.get(config["arch"], "1") + return f"""// SPDX-License-Identifier: MIT +// Auto-generated FMHA forward kernel specialization +#pragma once + +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "example/ck_tile/01_fmha/fmha_fwd.hpp" + +#if !defined(__HIP_DEVICE_COMPILE__) || ({arch_check}) + +namespace {ns} {{ + +using fmha_dtype = {dtype_cpp}; +using fmha_block_tile = ck_tile::sequence<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}>; + +using fmha_shape = ck_tile::TileFmhaShape, + ck_tile::sequence<{warp[0]}, {warp[1]}, {warp[2]}>, + ck_tile::sequence<{wave[3]}, {wave[4]}, {wave[5]}>, + ck_tile::sequence<{warp[3]}, {warp[4]}, {warp[5]}>, + {vlayout_cpp}>; + +using fmha_traits = ck_tile::TileFmhaTraits<{_bool_cpp(pad[0])}, + {_bool_cpp(pad[1])}, + {_bool_cpp(pad[2])}, + {_bool_cpp(pad[3])}, + {_bool_cpp(sig["logits"])}, + {_bias_cpp(sig["bias"])}, + false, + {_bool_cpp(sig["lse"])}, + {_bool_cpp(sig["dropout"])}, + {_qscale_cpp(sig["qscale"])}, + {alg["block_per_cu"]}, + {_bool_cpp(sig["skip_min_seqlen_q"])}, + {_bool_cpp(sig["sink"])}>; + +using fmha_variant = ck_tile::ComposedAttention<{_bool_cpp(sig["logits"])} * ck_tile::LOGITS_SOFT_CAP, + CK_TILE_FMHA_FWD_FAST_EXP2>; +using fmha_mask = {MASK_TO_CPP_GENERIC.get(canonical_mask(sig["mask"]), _mask_cpp(sig["mask"])) if pipeline_name in ("v3", "qr_async_trload_v3") else _mask_cpp(sig["mask"])}; + +using fmha_pipeline_problem = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape, + {mode_cpp}, + fmha_variant, + fmha_mask, + {use_trload}, + fmha_traits>; + +using fmha_pipeline = {pipeline_cpp}; +using fmha_epilogue = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + {_bool_cpp(pad[0])}, + {_bool_cpp(pad[3])}>>; +using fmha_kernel = {"ck_tile::FmhaFwdV3Kernel" if pipeline_name in ("v3", "qr_async_trload_v3") else "ck_tile::FmhaFwdKernel"}; + +using trait = fmha_fwd_traits_<{sig["hdim_q"]}, + {dtype_cpp}, + {mode_cpp}, + {tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, + {vlayout_cpp}, + {PIPELINE_ENUM_TO_CPP[pipeline_name]}, + {_bool_cpp(sig["logits"])}, + fmha_mask, + {_bias_cpp(sig["bias"])}, + {_bool_cpp(sig["lse"])}, + {_bool_cpp(sig["dropout"])}, + {_qscale_cpp(sig["qscale"])}, + {_bool_cpp(pad[0])}, + {_bool_cpp(pad[1])}, + {_bool_cpp(pad[2])}, + {_bool_cpp(pad[3])}, + {use_trload}, + {_bool_cpp(sig["skip_min_seqlen_q"])}, + {_bool_cpp(sig["sink"])}>; +}} // namespace {ns} + +template <> +inline float fmha_fwd_<{ns}::trait, {arch_tag}>(const ck_tile::stream_config& s, fmha_fwd_args a) +{{ + using k_ = {ns}::fmha_kernel; + auto [kargs, grids] = {"fmha_fwd_v3_create_kargs_and_grids" if pipeline_name in ("v3", "qr_async_trload_v3") else "fmha_fwd_create_kargs_and_grids"}(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +namespace {ns} {{ +inline float run(const ck_tile::stream_config& s, fmha_fwd_args a) +{{ + return fmha_fwd_(s, a); +}} + +inline void launch(const ck_tile::stream_config& s, fmha_fwd_args a) +{{ + auto sc = s; + sc.time_kernel_ = false; + (void)fmha_fwd_(sc, a); +}} + +}} // namespace {ns} + +#endif // arch guard +""" + + +def _pagedkv_kernel_body(name: str, config: dict) -> str: + sig = config["signature"] + alg = config["algorithm"] + arch_tag = ARCH_TAG_MAP[config["arch"]] + arch_check = ARCH_PREPROC_MAP.get(config["arch"], "1") + dtype_cpp = FWD_DTYPE_MAP[sig["data_type"]] + mode_cpp = "true" if sig["mode"] == "group" else "false" + vlayout_cpp = LAYOUT_TO_BOOL[sig["vlayout"]] + tile = alg["tile"] + wave = alg["wave"] + warp = alg["warp"] + pad = alg["padding"] + ns = f"ns_{name}" + return f"""// SPDX-License-Identifier: MIT +#pragma once + +#include "example/ck_tile/01_fmha/fmha_fwd.hpp" + +#if !defined(__HIP_DEVICE_COMPILE__) || ({arch_check}) + +namespace {ns} {{ + +using fmha_dtype = {dtype_cpp}; +using fmha_block_tile = ck_tile::sequence<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}>; +using fmha_shape = ck_tile::TileFmhaShape, + ck_tile::sequence<{warp[0]}, {warp[1]}, {warp[2]}>, + ck_tile::sequence<{wave[3]}, {wave[4]}, {wave[5]}>, + ck_tile::sequence<{warp[3]}, {warp[4]}, {warp[5]}>, + {vlayout_cpp}>; + +using fmha_trait = ck_tile::TileFmhaFwdPagedKVTraits<{_bool_cpp(pad[0])}, + {_bool_cpp(pad[1])}, + {_bool_cpp(pad[2])}, + {_bool_cpp(pad[3])}, + {_bool_cpp(sig["logits"])}, + {_bias_cpp(sig["bias"])}, + false, + {_bool_cpp(sig["lse"])}, + {_bool_cpp(sig["paged_kv"])}, + {_bool_cpp(sig["fp8_static_quant"])}, + {alg["block_per_cu"]}, + {_bool_cpp(sig["skip_min_seqlen_q"])}, + {_bool_cpp(sig["sink"])}>; + +using fmha_variant = ck_tile::ComposedAttention<{_bool_cpp(sig["logits"])} * ck_tile::LOGITS_SOFT_CAP, + CK_TILE_FMHA_FWD_FAST_EXP2>; +using fmha_mask = {_mask_cpp(sig["mask"])}; + +using fmha_pipeline_problem = ck_tile::BlockFmhaFwdPagedKVPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape, + {mode_cpp}, + fmha_variant, + fmha_mask, + fmha_trait>; + +using fmha_pipeline = ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS; +using fmha_epilogue = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + {_bool_cpp(pad[0])}, + {_bool_cpp(pad[3])}>>; +using fmha_kernel = ck_tile::FmhaFwdPagedKVKernel; + +using trait = fmha_fwd_pagedkv_traits_<{sig["hdim_q"]}, + {dtype_cpp}, + {mode_cpp}, + {tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, + {vlayout_cpp}, + {PIPELINE_ENUM_TO_CPP["qr_pagedkv"]}, + {_bool_cpp(sig["logits"])}, + fmha_mask, + {_bias_cpp(sig["bias"])}, + {_bool_cpp(sig["lse"])}, + {_bool_cpp(sig["paged_kv"])}, + {_bool_cpp(sig["fp8_static_quant"])}, + {_bool_cpp(pad[0])}, + {_bool_cpp(pad[1])}, + {_bool_cpp(pad[2])}, + {_bool_cpp(pad[3])}, + {_bool_cpp(sig["skip_min_seqlen_q"])}, + {_bool_cpp(sig["sink"])}>; +}} // namespace {ns} + +template <> +inline float fmha_fwd_pagedkv_<{ns}::trait, {arch_tag}>(const ck_tile::stream_config& s, + fmha_fwd_pagedkv_args a) +{{ + using k_ = {ns}::fmha_kernel; + auto [kargs, grids] = fmha_fwd_pagedkv_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +namespace {ns} {{ +inline float run(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a) +{{ + return fmha_fwd_pagedkv_(s, a); +}} + +inline void launch(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a) +{{ + auto sc = s; + sc.time_kernel_ = false; + (void)fmha_fwd_pagedkv_(sc, a); +}} + +}} // namespace {ns} + +#endif // arch guard +""" + + +def _splitkv_kernel_body(name: str, config: dict) -> str: + sig = config["signature"] + alg = config["algorithm"] + arch_tag = ARCH_TAG_MAP[config["arch"]] + arch_check = ARCH_PREPROC_MAP.get(config["arch"], "1") + dtype_cpp = FWD_DTYPE_MAP[sig["data_type"]] + mode_cpp = "true" if sig["mode"] == "group" else "false" + vlayout_cpp = LAYOUT_TO_BOOL[sig["vlayout"]] + tile = alg["tile"] + wave = alg["wave"] + warp = alg["warp"] + pad = alg["padding"] + pipeline_cpp = { + "qr": "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS", + "qr_nwarp_sshuffle": "ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS", + }[alg["pipeline"]] + ns = f"ns_{name}" + return f"""// SPDX-License-Identifier: MIT +#pragma once + +#include "example/ck_tile/01_fmha/fmha_fwd.hpp" + +#if !defined(__HIP_DEVICE_COMPILE__) || ({arch_check}) + +namespace {ns} {{ + +using fmha_dtype = {dtype_cpp}; +using fmha_variant = ck_tile::ComposedAttention<{_bool_cpp(sig["logits"])} * ck_tile::LOGITS_SOFT_CAP, + CK_TILE_FMHA_FWD_FAST_EXP2>; +using fmha_mask = {_mask_cpp(sig["mask"])}; +using fmha_block_tile = ck_tile::sequence<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}>; +using fmha_shape = ck_tile::TileFmhaShape, + ck_tile::sequence<{warp[0]}, {warp[1]}, {warp[2]}>, + ck_tile::sequence<{wave[3]}, {wave[4]}, {wave[5]}>, + ck_tile::sequence<{warp[3]}, {warp[4]}, {warp[5]}>, + {vlayout_cpp}>; +using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{_bool_cpp(pad[0])}, + {_bool_cpp(pad[1])}, + {_bool_cpp(pad[2])}, + {_bool_cpp(pad[3])}, + {_bool_cpp(sig["logits"])}, + {_bias_cpp(sig["bias"])}, + false, + {_bool_cpp(sig["lse"])}, + {_bool_cpp(sig["fp8_static_quant"])}, + {_bool_cpp(sig["paged_kv"])}, + true, + false, + {alg["block_per_cu"]}, + {_bool_cpp(sig["sink"])}>; +using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::OaccDataType, + fmha_shape, + {mode_cpp}, + fmha_variant, + fmha_mask, + fmha_trait>; +using fmha_pipeline = {pipeline_cpp}; +using fmha_epilogue = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::OaccDataType, + typename FmhaFwdTypeConfig::OaccDataType, + false, + false>>; +using fmha_kernel = ck_tile::FmhaFwdSplitKVKernel; + +using trait = fmha_fwd_splitkv_traits_<{sig["hdim_q"]}, + {dtype_cpp}, + {mode_cpp}, + {tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, + {vlayout_cpp}, + {PIPELINE_ENUM_TO_CPP[alg["pipeline"]]}, + {_bool_cpp(sig["logits"])}, + fmha_mask, + {_bias_cpp(sig["bias"])}, + {_bool_cpp(sig["lse"])}, + {_bool_cpp(sig["fp8_static_quant"])}, + {_bool_cpp(sig["paged_kv"])}, + {_bool_cpp(sig["sink"])}, + {_bool_cpp(pad[0])}, + {_bool_cpp(pad[1])}, + {_bool_cpp(pad[2])}, + {_bool_cpp(pad[3])}>; +}} // namespace {ns} + +template <> +inline void fmha_fwd_splitkv_oneshot_<{ns}::trait, {arch_tag}>(const ck_tile::stream_config& s, + fmha_fwd_splitkv_args a) +{{ + using k_ = {ns}::fmha_kernel; + auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} + +namespace {ns} {{ +inline void launch(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) +{{ + fmha_fwd_splitkv_oneshot_(s, a); +}} + +}} // namespace {ns} + +#endif // arch guard +""" + + +def _splitkv_combine_kernel_body(name: str, config: dict) -> str: + sig = config["signature"] + alg = config["algorithm"] + arch_tag = ARCH_TAG_MAP[config["arch"]] + arch_check = ARCH_PREPROC_MAP.get(config["arch"], "1") + dtype_cpp = FWD_DTYPE_MAP[sig["data_type"]] + mode_cpp = "true" if sig["mode"] == "group" else "false" + tile = alg["tile"] + pad = alg["padding"] + ns = f"ns_{name}" + return f"""// SPDX-License-Identifier: MIT +#pragma once + +#include "example/ck_tile/01_fmha/fmha_fwd.hpp" + +#if !defined(__HIP_DEVICE_COMPILE__) || ({arch_check}) + +using fmha_dtype = {dtype_cpp}; +namespace {{ +template +struct {ns}_instance {{ +using fmha_trait = ck_tile::TileFmhaFwdSplitKVCombineTraits<{_bool_cpp(pad[0])}, + {_bool_cpp(pad[3])}, + {_bool_cpp(sig["lse"])}, + {_bool_cpp(sig["fp8_static_quant"])}, + kLogMaxSplits, + {alg["block_per_cu"]}>; + +using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem< + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + {sig["hdim_v"]}, + {mode_cpp}, + {tile[3]}, + fmha_trait>; + +using fmha_pipeline = ck_tile::BlockFmhaFwdSplitKVCombinePipeline; +using fmha_epilogue = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, + false>>; +using fmha_kernel = ck_tile::FmhaFwdSplitKVCombineKernel; + +static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) +{{ + using k_ = fmha_kernel; + auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} +}}; // struct {ns}_instance +}} // anonymous namespace + +namespace {ns} {{ +using trait = fmha_fwd_splitkv_combine_traits_<{sig["hdim_v"]}, + {dtype_cpp}, + {mode_cpp}, + {tile[3]}, + {_bool_cpp(sig["lse"])}, + {_bool_cpp(sig["fp8_static_quant"])}, + {_bool_cpp(pad[0])}, + {_bool_cpp(pad[3])}>; +}} // namespace {ns} + +template <> +inline void fmha_fwd_splitkv_combine_oneshot_<{ns}::trait, {arch_tag}>( + const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) +{{ + if (a.num_splits <= 8) {{ + {ns}_instance<3>::run(s, a); + }} else if (a.num_splits <= 16) {{ + {ns}_instance<4>::run(s, a); + }} else if (a.num_splits <= 32) {{ + {ns}_instance<5>::run(s, a); + }} else if (a.num_splits <= 64) {{ + {ns}_instance<6>::run(s, a); + }} else if (a.num_splits <= 128) {{ + {ns}_instance<7>::run(s, a); + }} +}} + +namespace {ns} {{ +inline void launch(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) +{{ + fmha_fwd_splitkv_combine_oneshot_(s, a); +}} + +}} // namespace {ns} + +#endif // arch guard +""" + + +def _appendkv_kernel_body(name: str, config: dict) -> str: + sig = config["signature"] + alg = config["algorithm"] + arch_tag = ARCH_TAG_MAP[config["arch"]] + arch_check = ARCH_PREPROC_MAP.get(config["arch"], "1") + dtype_cpp = FWD_DTYPE_MAP[sig["data_type"]] + vlayout_cpp = LAYOUT_TO_BOOL[sig["vlayout"]] + tile = alg["tile"] + pad = alg["padding"] + ns = f"ns_{name}" + return f"""// SPDX-License-Identifier: MIT +#pragma once + +#include "example/ck_tile/01_fmha/fmha_fwd.hpp" + +#if !defined(__HIP_DEVICE_COMPILE__) || ({arch_check}) + +namespace {ns} {{ + +using fmha_dtype = {dtype_cpp}; +using fmha_trait = ck_tile::TileFmhaFwdAppendKVTraits<{_bool_cpp(pad[0])}, + {_bool_cpp(pad[1])}, + {_bool_cpp(pad[2])}, + {_bool_cpp(pad[3])}, + {alg["block_per_cu"]}>; +using fmha_pipeline_problem = ck_tile::BlockFmhaFwdAppendKVPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + {tile[0]}, + {tile[1]}, + {tile[2]}, + {tile[3]}, + {vlayout_cpp}, + {_rope_cpp(sig["rope"])}, + {_bool_cpp(sig["paged_kv"])}, + fmha_trait>; +using fmha_pipeline = ck_tile::BlockFmhaFwdAppendKVPipeline; +using fmha_kernel = ck_tile::FmhaFwdAppendKVKernel; + +using trait = fmha_fwd_appendkv_traits_<{sig["hdim_q"]}, + {dtype_cpp}, + {tile[0]}, + {tile[1]}, + {tile[2]}, + {tile[3]}, + {vlayout_cpp}, + {_bool_cpp(pad[0])}, + {_bool_cpp(pad[1])}, + {_bool_cpp(pad[2])}, + {_bool_cpp(pad[3])}, + {_rope_cpp(sig["rope"])}, + {_bool_cpp(sig["paged_kv"])}>; +}} // namespace {ns} + +template <> +inline float fmha_fwd_appendkv_<{ns}::trait, {arch_tag}>(const ck_tile::stream_config& s, + fmha_fwd_appendkv_args a) +{{ + using k_ = {ns}::fmha_kernel; + auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +namespace {ns} {{ +inline float run(const ck_tile::stream_config& s, fmha_fwd_appendkv_args a) +{{ + return fmha_fwd_appendkv_(s, a); +}} + +inline void launch(const ck_tile::stream_config& s, fmha_fwd_appendkv_args a) +{{ + auto sc = s; + sc.time_kernel_ = false; + (void)fmha_fwd_appendkv_(sc, a); +}} + +}} // namespace {ns} + +#endif // arch guard +""" + + +def _batch_prefill_kernel_body(name: str, config: dict) -> str: + sig = config["signature"] + alg = config["algorithm"] + arch_check = ARCH_PREPROC_MAP.get(config["arch"], "1") + dtype_cpp = FWD_DTYPE_MAP[sig["data_type"]] + mode_cpp = "true" if sig["mode"] == "group" else "false" + vlayout_cpp = LAYOUT_TO_BOOL[sig["vlayout"]] + tile = alg["tile"] + wave = alg["wave"] + warp = alg["warp"] + pad = alg["padding"] + ns = f"ns_{name}" + return f"""// SPDX-License-Identifier: MIT +#pragma once + +#include "example/ck_tile/01_fmha/fmha_fwd.hpp" + +#if !defined(__HIP_DEVICE_COMPILE__) || ({arch_check}) + +namespace {ns} {{ + +using fmha_dtype = {dtype_cpp}; +using fmha_block_tile = ck_tile::sequence<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}>; +using fmha_shape = ck_tile::TileFmhaShape, + ck_tile::sequence<{warp[0]}, {warp[1]}, {warp[2]}>, + ck_tile::sequence<{wave[3]}, {wave[4]}, {wave[5]}>, + ck_tile::sequence<{warp[3]}, {warp[4]}, {warp[5]}>, + {vlayout_cpp}>; +using fmha_trait = ck_tile::TileFmhaBatchPrefillTraits<{_bool_cpp(pad[0])}, + {_bool_cpp(pad[1])}, + {_bool_cpp(pad[2])}, + {_bool_cpp(pad[3])}, + {_bool_cpp(sig["logits"])}, + {_bias_cpp(sig["bias"])}, + false, + {_bool_cpp(sig["lse"])}, + {_bool_cpp(sig["dropout"])}, + {_qscale_cpp(sig["qscale"])}, + {alg["block_per_cu"]}, + false, + {sig["page_size"]}, + {_kv_memory_cpp(sig["kv_memory_layout"])}, + {_kv_lookup_cpp(sig["kv_lookup_table"])}>; +using fmha_variant = ck_tile::ComposedAttention<{_bool_cpp(sig["logits"])} * ck_tile::LOGITS_SOFT_CAP, + CK_TILE_FMHA_FWD_FAST_EXP2>; +using fmha_mask = {_mask_cpp(sig["mask"])}; +using fmha_pipeline_problem = ck_tile::BlockFmhaBatchPrefillPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape, + {mode_cpp}, + fmha_variant, + fmha_mask, + false, + {sig["page_size"]}, + fmha_trait>; +using fmha_pipeline = ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync; +using fmha_epilogue = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + {_bool_cpp(pad[0])}, + {_bool_cpp(pad[3])}>>; +using fmha_kernel = ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel; + +using trait = fmha_fwd_batch_prefill_traits_<{sig["hdim_q"]}, + {dtype_cpp}, + {mode_cpp}, + {tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, + {vlayout_cpp}, + {PIPELINE_ENUM_TO_CPP["batch_prefill_async"]}, + {_bool_cpp(sig["logits"])}, + fmha_mask, + {_bias_cpp(sig["bias"])}, + {_bool_cpp(sig["lse"])}, + {_bool_cpp(sig["dropout"])}, + {_qscale_cpp(sig["qscale"])}, + {_bool_cpp(pad[0])}, + {_bool_cpp(pad[1])}, + {_bool_cpp(pad[2])}, + {_bool_cpp(pad[3])}, + false, + false, + {sig["page_size"]}, + {_kv_memory_cpp(sig["kv_memory_layout"])}, + {_kv_lookup_cpp(sig["kv_lookup_table"])}>; +}} // namespace {ns} + +template <> +inline float fmha_batch_prefill_<{ns}::trait>(const ck_tile::stream_config& s, fmha_batch_prefill_args a) +{{ + using k_ = {ns}::fmha_kernel; + auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +namespace {ns} {{ +inline float run(const ck_tile::stream_config& s, fmha_batch_prefill_args a) +{{ + return fmha_batch_prefill_(s, a); +}} + +inline void launch(const ck_tile::stream_config& s, fmha_batch_prefill_args a) +{{ + auto sc = s; + sc.time_kernel_ = false; + (void)fmha_batch_prefill_(sc, a); +}} + +}} // namespace {ns} + +#endif // arch guard +""" + + +def _bwd_dot_do_o_kernel_body(name: str, config: dict) -> str: + sig = config["signature"] + alg = config["algorithm"] + arch_tag = ARCH_TAG_MAP[config["arch"]] + arch_check = ARCH_PREPROC_MAP.get(config["arch"], "1") + dtype_cpp = BWD_DTYPE_MAP[sig["data_type"]] + mode_cpp = "true" if sig["mode"] == "group" else "false" + tile = alg["tile"] + pad = alg["padding"] + ns = f"ns_{name}" + return f"""// SPDX-License-Identifier: MIT +#pragma once + +#include "example/ck_tile/01_fmha/fmha_bwd.hpp" + +#if !defined(__HIP_DEVICE_COMPILE__) || ({arch_check}) + +namespace {ns} {{ + +using fmha_dtype = {dtype_cpp}; +using fmha_trait = ck_tile::TileFmhaBwdOGradDotOTraits<{_bool_cpp(pad[0])}, + {_bool_cpp(pad[3])}, + {alg["block_per_cu"]}>; +using fmha_pipeline_problem = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + {tile[0]}, + {sig["hdim_v"]}, + {mode_cpp}, + fmha_trait>; +using fmha_pipeline = typename ck_tile::BlockFmhaBwdOGradDotO; +using fmha_kernel = ck_tile::FmhaBwdOGradDotOKernel; + +using trait = fmha_bwd_dot_do_o_traits_<{sig["hdim_v"]}, + {dtype_cpp}, + {mode_cpp}, + {_bool_cpp(pad[0])}, + {_bool_cpp(pad[3])}>; +}} // namespace {ns} + +template <> +inline void fmha_bwd_dot_do_o_oneshot_<{ns}::trait, {arch_tag}>(const ck_tile::stream_config& s, + fmha_bwd_args a) +{{ + using k_ = {ns}::fmha_kernel; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} + +namespace {ns} {{ +inline void launch(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + fmha_bwd_dot_do_o_oneshot_(s, a); +}} + +}} // namespace {ns} + +#endif // arch guard +""" + + +def _bwd_dq_dk_dv_kernel_body(name: str, config: dict) -> str: + sig = config["signature"] + alg = config["algorithm"] + arch_tag = ARCH_TAG_MAP[config["arch"]] + arch_check = ARCH_PREPROC_MAP.get(config["arch"], "1") + dtype_cpp = BWD_DTYPE_MAP[sig["data_type"]] + mode_cpp = "true" if sig["mode"] == "group" else "false" + tile = alg["tile"] + wave = alg["wave"] + warp = alg["warp"] + pad = alg["padding"] + ns = f"ns_{name}" + # BlockDropoutBwd + # wg16 variants use kIsWG32=false; wg32 variants use kIsWG32=true + dropout_variant = sig.get("dropout_variant", "") + is_wg32 = "wg32" in dropout_variant if dropout_variant else True + is_store = "storerandval" in dropout_variant if dropout_variant else False + has_dropout = bool(sig["dropout"]) + dropout_cpp = ( + f"ck_tile::BlockDropoutBwd<{_bool_cpp(has_dropout)}, " + f"{_bool_cpp(is_wg32 if has_dropout else True)}, " + f"{_bool_cpp(is_store)}>" + ) + return f"""// SPDX-License-Identifier: MIT +#pragma once + +#include "example/ck_tile/01_fmha/fmha_bwd.hpp" + +#if !defined(__HIP_DEVICE_COMPILE__) || ({arch_check}) + +namespace {ns} {{ + +using fmha_dtype = {dtype_cpp}; +using fmha_block_tile = ck_tile::sequence<{_bwd_block_tile(tile, sig)}>; +using fmha_block_warps0 = ck_tile::sequence<{wave[0]}, {wave[1]}, {wave[2]}>; +using fmha_block_warps1 = ck_tile::sequence<{wave[3]}, {wave[4]}, {wave[5]}>; +using fmha_block_warps2 = ck_tile::sequence<{wave[6]}, {wave[7]}, {wave[8]}>; +using fmha_warp_tile0 = ck_tile::sequence<{warp[0]}, {warp[1]}, {warp[2]}>; +using fmha_warp_tile1 = ck_tile::sequence<{warp[3]}, {warp[4]}, {warp[5]}>; +using fmha_warp_tile2 = ck_tile::sequence<{warp[0]}, {warp[1]}, ck_tile::min({warp[2]}, {tile[6] if len(tile) >= 7 else warp[2]})>; +using fmha_shape = ck_tile::TileFmhaBwdShape; +using fmha_trait = ck_tile::TileFmhaBwdTraits<{int(pad[2])}, + {int(pad[3])}, + {_bias_cpp(sig["bias"])}, + {_bool_cpp(sig["dbias"])}, + {alg["block_per_cu"]}>; +using fmha_mask = {_mask_cpp(sig["mask"])}; +using fmha_dropout = {dropout_cpp}; +using fmha_problem = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_shape, + {mode_cpp}, + {_bool_cpp(sig["deterministic"])}, + fmha_mask, + fmha_dropout, + {_bool_cpp(alg["use_trload"])}, + fmha_trait>; +using fmha_pipeline = ck_tile::BlockFmhaBwdDQDKDVPipeline; +using dk_epi = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + ({int(pad[2])} > 0)>>; +using dv_epi = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + ({int(pad[3])} > 0)>>; +using dq_epi = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + false, + ({int(pad[2])} > 0)>>; +using fmha_kernel = ck_tile::FmhaBwdDQDKDVKernel; + +using trait = fmha_bwd_dq_dk_dv_traits_<{sig["hdim_q"]}, + {dtype_cpp}, + {mode_cpp}, + fmha_mask, + fmha_dropout, + {_bias_cpp(sig["bias"])}, + {_bool_cpp(sig["dbias"])}, + {int(pad[2])}, + {int(pad[3])}, + {_bool_cpp(sig["deterministic"])}, + {_bool_cpp(alg["use_trload"])}, + {alg["max_seq_len_q"]}, + {tile[1]}>; +}} // namespace {ns} + +template <> +inline void fmha_bwd_dq_dk_dv_oneshot_<{ns}::trait, {arch_tag}>( + const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = {ns}::fmha_kernel; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} + +namespace {ns} {{ +inline void launch(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + fmha_bwd_dq_dk_dv_oneshot_(s, a); +}} + +}} // namespace {ns} + +#endif // arch guard +""" + + +def _bwd_convert_dq_kernel_body(name: str, config: dict) -> str: + sig = config["signature"] + alg = config["algorithm"] + arch_tag = ARCH_TAG_MAP[config["arch"]] + arch_check = ARCH_PREPROC_MAP.get(config["arch"], "1") + dtype_cpp = BWD_DTYPE_MAP[sig["data_type"]] + mode_cpp = "true" if sig["mode"] == "group" else "false" + tile = alg["tile"] + pad = alg["padding"] + ns = f"ns_{name}" + return f"""// SPDX-License-Identifier: MIT +#pragma once + +#include "example/ck_tile/01_fmha/fmha_bwd.hpp" + +#if !defined(__HIP_DEVICE_COMPILE__) || ({arch_check}) + +namespace {ns} {{ + +using fmha_dtype = {dtype_cpp}; +using fmha_trait = ck_tile::TileFmhaBwdConvertQGradTraits<{_bool_cpp(pad[0])}, + {_bool_cpp(pad[2])}, + {alg["block_per_cu"]}>; +using fmha_problem = ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + 256, + {tile[0]}, + {tile[1]}, + {sig["hdim_q"]}, + {mode_cpp}, + {_bool_cpp(sig["deterministic"])}, + fmha_trait>; +using fmha_pipeline = typename ck_tile::BlockFmhaBwdConvertQGrad; +using fmha_kernel = ck_tile::FmhaBwdConvertQGradKernel; + +using trait = fmha_bwd_convert_dq_traits_<{sig["hdim_q"]}, + {dtype_cpp}, + {mode_cpp}, + {_bool_cpp(pad[0])}, + {_bool_cpp(pad[2])}, + {_bool_cpp(sig["deterministic"])}, + {tile[1]}>; +}} // namespace {ns} + +template <> +inline void fmha_bwd_convert_dq_oneshot_<{ns}::trait, {arch_tag}>( + const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = {ns}::fmha_kernel; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} + +namespace {ns} {{ +inline void launch(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + fmha_bwd_convert_dq_oneshot_(s, a); +}} + +}} // namespace {ns} + +#endif // arch guard +""" + + +def render_kernel_header(name: str, config: dict) -> str: + family = config["signature"]["family"] + if family == "fwd": + return _fwd_kernel_body(name, config) + if family == "fwd_pagedkv": + return _pagedkv_kernel_body(name, config) + if family == "fwd_splitkv": + return _splitkv_kernel_body(name, config) + if family == "fwd_splitkv_combine": + return _splitkv_combine_kernel_body(name, config) + if family == "fwd_appendkv": + return _appendkv_kernel_body(name, config) + if family == "batch_prefill": + return _batch_prefill_kernel_body(name, config) + if family == "bwd_dot_do_o": + return _bwd_dot_do_o_kernel_body(name, config) + if family == "bwd_dq_dk_dv": + return _bwd_dq_dk_dv_kernel_body(name, config) + if family == "bwd_convert_dq": + return _bwd_convert_dq_kernel_body(name, config) + raise KeyError(f"Unsupported FMHA family: {family}") + + +def render_wrapper_header( + name: str, config: dict, kernel_path: Path, output_dir: Path +) -> str: + sig = config["signature"] + alg = config["algorithm"] + family = sig["family"] + rel_path = kernel_path.relative_to(output_dir) + ns = f"ns_{name}" + + if family in {"fwd", "fwd_pagedkv", "fwd_appendkv", "batch_prefill"}: + backend_factory = "make_timed_fmha_kernel" + else: + backend_factory = "make_oneshot_fmha_kernel" + + args_type_map = { + "fwd": "fmha_fwd_args", + "fwd_pagedkv": "fmha_fwd_pagedkv_args", + "fwd_splitkv": "fmha_fwd_splitkv_args", + "fwd_splitkv_combine": "fmha_fwd_splitkv_args", + "fwd_appendkv": "fmha_fwd_appendkv_args", + "batch_prefill": "fmha_batch_prefill_args", + "bwd_dot_do_o": "fmha_bwd_args", + "bwd_dq_dk_dv": "fmha_bwd_args", + "bwd_convert_dq": "fmha_bwd_args", + } + + run_symbol = "run" if backend_factory == "make_timed_fmha_kernel" else "launch" + + tile = alg["tile"] + wave = alg["wave"] + warp = alg["warp"] + pad = alg["padding"] + + return f"""// SPDX-License-Identifier: MIT +#pragma once + +// Kernel header first: includes example fmha_fwd.hpp or fmha_bwd.hpp +// which defines all necessary types (enums, args, traits). +#include "{rel_path}" +// Signal to fmha_types.hpp which types are already defined. +#define CK_TILE_FMHA_{"BWD" if family.startswith("bwd") else "FWD"}_TYPES_FROM_EXAMPLE 1 +#include "ck_tile/dispatcher/fmha_dispatcher.hpp" +#include "ck_tile/dispatcher/backends/generated_fmha_backend.hpp" + +namespace ck_tile {{ +namespace dispatcher {{ +namespace generated {{ + +inline FmhaKernelInstancePtr make_{name}(const std::string& gfx_arch = "{config["arch"]}") +{{ + FmhaKernelKey key; + key.signature.family = {KERNEL_FAMILY_TO_ENUM[family]}; + key.signature.data_type = "{sig["data_type"]}"; + key.signature.is_group_mode = {str(sig["mode"] == "group").lower()}; + key.signature.is_v_rowmajor = {str(sig["vlayout"] == "r").lower()}; + key.signature.has_logits_soft_cap = {str(sig["logits"]).lower()}; + key.signature.mask_type = {MASK_TO_INT[sig["mask"]]}; + key.signature.bias_type = {BIAS_TO_INT[sig["bias"]]}; + key.signature.has_lse = {str(sig["lse"]).lower()}; + key.signature.has_dropout = {str(sig["dropout"]).lower()}; + key.signature.qscale_type = {QSCALE_TO_INT[sig["qscale"]]}; + key.signature.rope_type = {ROPE_TO_INT[sig["rope"]]}; + key.signature.use_paged_kv = {str(sig["paged_kv"]).lower()}; + key.signature.do_fp8_static_quant = {str(sig["fp8_static_quant"]).lower()}; + key.signature.skip_min_seqlen_q = {str(sig["skip_min_seqlen_q"]).lower()}; + key.signature.has_sink = {str(sig["sink"]).lower()}; + key.signature.has_dbias = {str(sig["dbias"]).lower()}; + key.signature.is_store_randval = {str(sig["store_randval"]).lower()}; + key.signature.is_deterministic = {str(sig["deterministic"]).lower()}; + key.signature.kv_memory_layout = {KV_MEMORY_LAYOUT_TO_INT[sig["kv_memory_layout"]]}; + key.signature.kv_lookup_table = {KV_LOOKUP_TO_INT[sig["kv_lookup_table"]]}; + key.signature.page_size = {sig["page_size"]}; + key.signature.hdim_q = {sig["hdim_q"]}; + key.signature.hdim_v = {sig["hdim_v"]}; + + key.algorithm.tile_shape = {{{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}}}; + key.algorithm.wave_shape = {{{wave[0]}, {wave[1]}, {wave[2]}, {wave[3]}, {wave[4]}, {wave[5]}, {wave[6]}, {wave[7]}, {wave[8]}}}; + key.algorithm.warp_tile_shape = {{{warp[0]}, {warp[1]}, {warp[2]}, {warp[3]}, {warp[4]}, {warp[5]}, {warp[6]}, {warp[7]}, {warp[8]}}}; + key.algorithm.pipeline = "{alg["pipeline"]}"; + key.algorithm.pad_s = {str(pad[0]).lower()}; + key.algorithm.pad_sk = {str(pad[1]).lower()}; + key.algorithm.pad_d = {str(pad[2]).lower()}; + key.algorithm.pad_dv = {str(pad[3]).lower()}; + key.algorithm.use_trload = {str(alg["use_trload"]).lower()}; + key.algorithm.block_per_cu = {alg["block_per_cu"]}; + key.algorithm.num_wave_groups = {alg["num_wave_groups"]}; + key.algorithm.max_splits_log2 = {alg["max_splits_log2"]}; + key.algorithm.max_seq_len_q = {alg["max_seq_len_q"]}; + key.algorithm.hdim_q_alignment = {alg["hdim_q_alignment"]}; + key.algorithm.hdim_v_alignment = {alg["hdim_v_alignment"]}; + key.algorithm.selection_rank = {alg["selection_rank"]}; + key.algorithm.constraint_tag = "{alg["constraint_tag"]}"; + key.gfx_arch = gfx_arch; + + return backends::{backend_factory}<{args_type_map[family]}>(key, "{name}", {ns}::{run_symbol}); +}} + +}} // namespace generated +}} // namespace dispatcher +}} // namespace ck_tile +""" + + +def generate_cpp_compilation_unit(name: str) -> str: + return f"""// SPDX-License-Identifier: MIT +// Auto-generated compilation unit for {name} + +#include "{name}.hpp" + +namespace ck_tile {{ namespace generated {{ +volatile bool _{name}_loaded = true; +}} }} +""" + + +class _GenItem: + def __init__(self, output_dir: Path, config: dict): + self.output_dir = output_dir + self.config = config + self.name = kernel_name_from_config(config) + + def __str__(self) -> str: + return self.name + + +def _generate_one(item: _GenItem): + name = item.name + output_dir = item.output_dir + output_dir.mkdir(parents=True, exist_ok=True) + wrapper_dir = output_dir / "dispatcher_wrappers" + wrapper_dir.mkdir(parents=True, exist_ok=True) + + kernel_path = output_dir / f"{name}.hpp" + kernel_path.write_text(render_kernel_header(name, item.config)) + + wrapper_path = wrapper_dir / f"dispatcher_wrapper_{name}.hpp" + wrapper_path.write_text( + render_wrapper_header(name, item.config, kernel_path, output_dir) + ) + + cpp_path = output_dir / f"{name}.cpp" + cpp_path.write_text(generate_cpp_compilation_unit(name)) + + return str(kernel_path), str(wrapper_path), str(cpp_path) + + +def _iter_configs(config_blob: Union[dict, list]) -> Iterable[dict]: + if isinstance(config_blob, list): + return config_blob + if "configs" in config_blob: + return config_blob["configs"] + return [config_blob] + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Unified FMHA dispatcher code generator" + ) + parser.add_argument( + "--output", "--output-dir", dest="output_dir", type=Path, required=True + ) + parser.add_argument( + "--gpu-target", "--arch", dest="gpu_target", type=str, default="gfx942" + ) + parser.add_argument("--config-json", type=str, required=True) + parser.add_argument("--profile", type=str) + parser.add_argument("--receipt", type=str) + parser.add_argument("--no-parallel", action="store_true") + args = parser.parse_args() + + arch_specs = load_arch_specs() + raw = json.loads(args.config_json) + configs = [] + failures = [] + + for entry in _iter_configs(raw): + cfg = _canonicalize_config(entry, args.gpu_target, arch_specs) + profile_name = cfg.get("profile") or args.profile + receipt_name = cfg.get("receipt") or args.receipt + + validation = validate_config(cfg, arch_specs) + if not validation.valid: + failures.append((cfg, validation.errors)) + continue + + if not profile_allows(cfg, profile=profile_name, receipt=receipt_name): + failures.append( + ( + cfg, + [ + f"profile filter rejected config ({profile_name or receipt_name})" + ], + ) + ) + continue + + configs.append(cfg) + + if failures: + for cfg, errors in failures: + log.error( + "Rejected FMHA config %s", + cfg.get("signature", {}).get("family", "unknown"), + ) + for error in errors: + log.error(" %s", error) + if not configs: + return 1 + + items = [_GenItem(args.output_dir, config) for config in configs] + parallel_generate( + _generate_one, items, parallel=not args.no_parallel and len(items) > 1 + ) + + log.info("Generated %d FMHA kernel specialization(s)", len(items)) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/dispatcher/codegen/fmha/fmha_arch_specs.json b/dispatcher/codegen/fmha/fmha_arch_specs.json new file mode 100644 index 0000000000..b0019a6a71 --- /dev/null +++ b/dispatcher/codegen/fmha/fmha_arch_specs.json @@ -0,0 +1,175 @@ +{ + "_comment": "FMHA-specific architecture specs. Edit this file to add new GPU/dtype/pipeline support for FMHA.", + "_note": "Common GPU hardware data (element_sizes, warp_size, warp_configs, lds_capacity_kb) lives in ../arch_specs.json. This file holds FMHA-specific capabilities, tile tables, and validation rules.", + + "architectures": { + "gfx90a": { + "family": "cdna2", + "arch_tag": "ck_tile::gfx9_t", + "supported_dtypes": ["fp16", "bf16", "fp32"], + "supported_pipelines": ["qr", "qr_async", "qs", "qr_pagedkv", "qr_nwarp_sshuffle", "appendkv", "bwd"], + "supports_trload": false, + "supports_v3": false + }, + "gfx942": { + "family": "cdna3", + "arch_tag": "ck_tile::gfx9_t", + "supported_dtypes": ["fp16", "bf16", "fp32", "fp8", "fp8fp16", "fp8bf16", "fp8fp32", "bf8"], + "supported_pipelines": ["qr", "qr_async", "qs", "qr_pagedkv", "qr_nwarp_sshuffle", "appendkv", "bwd"], + "supports_trload": false, + "supports_v3": false + }, + "gfx950": { + "family": "cdna4", + "arch_tag": "ck_tile::gfx9_t", + "supported_dtypes": ["fp16", "bf16", "fp32", "fp8", "fp8fp16", "fp8bf16", "fp8fp32", "bf8"], + "supported_pipelines": ["qr", "qr_async", "qs", "qr_async_trload", "qr_async_trload_v3", "v3", "qr_pagedkv", "qr_nwarp_sshuffle", "appendkv", "bwd"], + "supports_trload": true, + "supports_v3": true + }, + "gfx1100": { + "family": "rdna3", + "arch_tag": "ck_tile::gfx11_t", + "supported_dtypes": ["fp16", "bf16"], + "supported_pipelines": ["qr", "qr_pagedkv", "qr_nwarp_sshuffle", "appendkv", "bwd"], + "supports_trload": false, + "supports_v3": false + }, + "gfx1201": { + "family": "rdna4", + "arch_tag": "ck_tile::gfx12_t", + "supported_dtypes": ["fp16", "bf16", "fp8", "fp8bf16"], + "supported_pipelines": ["qr", "qr_pagedkv", "qr_nwarp_sshuffle", "appendkv", "bwd"], + "supports_trload": false, + "supports_v3": false + } + }, + + "supported_hdims": { + "_comment": "hdim_q must satisfy ceil_to_qualified_tile_length() in tile_fmha_shape.hpp. Each entry is [hdim_q, hdim_v].", + "fp16": [[32,32], [64,64], [80,96], [96,128], [128,128], [160,160], [192,128], [192,192], [256,256]], + "bf16": [[32,32], [64,64], [80,96], [96,128], [128,128], [160,160], [192,128], [192,192], [256,256]], + "fp32": [[32,32], [48,48], [64,64], [96,128], [128,128], [192,192], [256,256]], + "fp8": [[64,64], [128,128], [192,128], [256,256]], + "fp8bf16": [[64,64], [128,128], [192,128], [256,256]], + "fp8fp32": [[128,128]], + "bf8": [[64,64], [128,128], [192,128], [256,256]], + "mxfp8": [[128,128], [256,256]], + "mxfp4": [[128,128], [256,256]] + }, + + "fmha_warp_tiles": { + "_comment": "FMHA warp tile sizes [wm0, wn0, wk0] per FMHA dtype. Subset of MFMA/WMMA instructions relevant to attention.", + "fp16": [[32,32,16], [16,16,32]], + "bf16": [[32,32,16], [16,16,32]], + "fp32": [[16,16,16]], + "fp8": [[32,32,32]], + "fp8bf16": [[32,32,32]], + "fp8fp32": [[32,32,32]], + "bf8": [[32,32,32]], + "mxfp8": [[32,32,64]], + "mxfp4": [[16,16,128]] + }, + + "fmha_element_sizes": { + "_comment": "FMHA-specific element sizes for composite dtypes not in parent arch_specs.json. Common dtypes (fp16, bf16, fp32, fp8, bf8) use ../arch_specs.json element_sizes.", + "fp8bf16": 1, + "fp8fp32": 1, + "mxfp8": 1, + "mxfp4": 1 + }, + + "tile_sweep_ranges": { + "_comment": "Block tile dimensions to sweep. Must be multiples of warp tile sizes.", + "valid_bm0": [16, 32, 64, 128, 192, 256], + "valid_bn0": [16, 32, 64, 96, 128, 192, 256, 384], + "valid_bk0": [16, 32, 64, 96, 128, 256] + }, + + "k0max_map": { + "_comment": "Maps hdim_q -> padded K-tile length. Source: tile_fmha_shape.hpp ceil_to_qualified_tile_length().", + "32": 32, "48": 48, "64": 64, "80": 96, "96": 128, + "128": 128, "160": 256, "192": 192, "256": 256 + }, + + "lds_limits": { + "_comment": "LDS budget in bytes per non-async FMHA pipeline. Async pipelines compute LDS dynamically.", + "qr": 65536, + "qs": 65536 + }, + + "global_rules": { + "hdim_192_128_no_bias_dropout": true, + "logits_requires_no_bias": true, + "group_mode_requires_padding": true, + "hdim_divisible_by": 8 + }, + + "defaults": { + "tile": [128, 64, 32, 128, 32, 128], + "wave": [2, 2, 1, 2, 2, 1, 1, 1, 1], + "warp": [32, 32, 16, 32, 32, 16, 16, 16, 16], + "padding": [true, true, true, true], + "block_per_cu": 1, + "num_wave_groups": 1, + "selection_rank": 0 + }, + + "splitkv_combine": { + "combine_bn1": 32, + "hdims_fp16": [32, 64, 96, 128, 256], + "hdims_fp8": [64, 128, 256] + }, + + "batch_prefill": { + "supported_page_sizes": [1, 16, 1024], + "supported_kv_memory_layouts": ["vectorized", "linear"], + "supported_kv_lookup_tables": ["vllm", "sglang"] + }, + + "bwd_tiles": { + "_comment": "BWD dq_dk_dv tile tables. Format: [bm0, bn0, bk0, bn1, bk1, bk0max, tile6, tile7, tile8].", + "dq_dk_dv_fp16": { + "32_32": [32, 128, 32, 32, 32, 32, 64, 32, 32], + "64_64": [32, 128, 64, 32, 64, 32, 32, 64, 64], + "96_128": [32, 128, 96, 32, 96, 32, 32, 96, 96], + "128_128": [16, 128, 128, 16, 128, 16, 32, 128, 128], + "256_256": [16, 64, 256, 16, 256, 16, 32, 256, 256] + }, + "dq_dk_dv_extra": { + "64_64": [ + {"tile": [32, 128, 64, 32, 64, 32, 32, 64, 64], "tag": "trload", "batch_only": false}, + {"tile": [32, 16, 64, 32, 64, 32, 16, 64, 64], "tag": "small", "batch_only": true} + ], + "128_128": [ + {"tile": [16, 16, 128, 16, 128, 16, 16, 128, 128], "tag": "small", "batch_only": true}, + {"tile": [16, 192, 128, 16, 128, 16, 32, 128, 128], "tag": "bn192", "batch_only": false}, + {"tile": [32, 128, 128, 32, 128, 32, 32, 128, 128], "tag": "trload", "batch_only": false} + ] + }, + "dot_do_o_hdims": [32, 64, 96, 128, 256], + "convert_dq_hdims": [32, 64, 96, 128, 256], + "convert_dq_tile_groups": {"32": 1, "64": 1, "96": 1, "128": 1, "256": 1}, + "pad_combos": [["f","f"], ["f","t"], ["f","8"], ["t","f"], ["t","t"], ["t","8"], ["8","8"]], + "extra_pad_combos": [["f","f"], ["8","8"]], + "dropouts": ["no", "dropout_wg16", "dropout_wg16_storerandval"], + "small_dropouts": ["no"] + }, + + "bwd_wave_warp": { + "_comment": "BWD wave/warp lookup. Key: 'bm0_bn0_bk0_trload'. Value: {wave: 9-tuple, warp_k1: int}.", + "16_16_128_t": {"wave": [1,1,1,1,1,1,1,1,1], "warp_k1": 16}, + "16_64_256_f": {"wave": [1,4,1,4,1,1,1,4,1], "warp_k1": 16}, + "16_128_128_f": {"wave": [1,4,1,4,1,1,1,4,1], "warp_k1": 16}, + "16_192_128_t": {"wave": [1,4,1,4,1,1,1,4,1], "warp_k1": 16}, + "32_16_64_t": {"wave": [1,1,1,1,1,1,1,1,1], "warp_k1": 16}, + "32_128_32_f": {"wave": [1,4,1,4,1,1,2,2,1], "warp_k1": 16}, + "32_128_64_f": {"wave": [1,4,1,4,1,1,1,4,1], "warp_k1": 16}, + "32_128_64_t": {"wave": [1,4,1,4,1,1,1,4,1], "warp_k1": 32}, + "32_128_96_f": {"wave": [1,4,1,4,1,1,2,2,1], "warp_k1": 16}, + "32_128_128_t": {"wave": [1,4,1,4,1,1,1,4,1], "warp_k1": 32}, + "64_128_32_f": {"wave": [2,4,1,4,1,1,2,4,1], "warp_k1": 16}, + "64_128_64_f": {"wave": [2,4,1,4,1,1,2,4,1], "warp_k1": 16}, + "64_128_128_f": {"wave": [2,4,1,4,1,1,2,4,1], "warp_k1": 16} + } +} diff --git a/dispatcher/codegen/fmha/generate_fallback.py b/dispatcher/codegen/fmha/generate_fallback.py new file mode 100644 index 0000000000..317938f757 --- /dev/null +++ b/dispatcher/codegen/fmha/generate_fallback.py @@ -0,0 +1,261 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +"""Generate FMHA fallback kernel + dispatch header for the Python ctypes library. + +Mirrors generate_conv_dispatch_header.py: generates a single FMHA forward +kernel and creates a dispatch header that can be force-included into +fmha_ctypes_lib.cpp. + +Usage: + python3 generate_fmha_fallback.py --output-dir /path/to/output --gpu-target gfx950 +""" + +import argparse +import json +import subprocess +import sys +from pathlib import Path + + +# Default kernel config for fallback — a single fwd fp16 kernel with +# known-good tile (128x128x32, qr_async) for basic smoke-test capability. +# Source: tile dims from fmha_fwd.py FmhaFwdTileSize for hdim=128 fp16. +DEFAULT_CONFIG = { + "arch": "gfx950", + "signature": { + "family": "fwd", + "data_type": "fp16", + "mode": "batch", + "vlayout": "r", + "hdim_q": 128, + "hdim_v": 128, + "mask": "no", + "bias": "no", + "lse": False, + "dropout": False, + "qscale": "no", + "rope": "none", + "logits": False, + "paged_kv": False, + "fp8_static_quant": False, + "skip_min_seqlen_q": False, + "sink": False, + "dbias": False, + "store_randval": False, + "deterministic": False, + "kv_memory_layout": "vectorized", + "kv_lookup_table": "sglang", + "page_size": 1, + }, + "algorithm": { + "pipeline": "qr_async", + "tile": [128, 128, 32, 128, 32, 128], + "wave": [4, 1, 1, 4, 1, 1, 1, 1, 1], + "warp": [32, 32, 16, 32, 32, 16, 16, 16, 16], + "padding": [True, True, True, True], + "block_per_cu": 1, + "num_wave_groups": 1, + "max_splits_log2": 0, + "max_seq_len_q": 0, + }, +} + + +def generate_dispatch_header(output_dir: Path, wrapper_dir: Path) -> Path: + """Generate fmha_python_dispatch.hpp from the wrapper headers.""" + wrappers = sorted(wrapper_dir.glob("dispatcher_wrapper_fmha_*.hpp")) + if not wrappers: + raise RuntimeError(f"No FMHA dispatcher wrappers found in {wrapper_dir}") + + kernel_names = [] + make_calls = [] + for w in wrappers: + stem = w.stem.replace("dispatcher_wrapper_", "") + kernel_names.append(stem) + make_calls.append( + f" registry.register_kernel(" + f"ck_tile::dispatcher::generated::make_{stem}(arch));" + ) + + lines = [ + "// Auto-generated FMHA dispatch header for Python ctypes library", + "#pragma once", + "", + ] + for w in wrappers: + lines.append(f'#include "dispatcher_wrappers/{w.name}"') + + lines += [ + "", + '#include "ck_tile/dispatcher/fmha_registry.hpp"', + '#include "ck_tile/dispatcher/fmha_dispatcher.hpp"', + "", + "namespace generated {", + "", + "inline void register_fmha_python_kernels(" + "ck_tile::dispatcher::FmhaRegistry& registry, const std::string& arch) {", + " (void)arch;", + ] + lines += make_calls + lines += [ + "}", + "", + "} // namespace generated", + "", + "#ifndef REGISTER_GENERATED_KERNELS", + "#define REGISTER_GENERATED_KERNELS(registry, arch) " + "::generated::register_fmha_python_kernels(registry, arch)", + "#endif", + "", + "// Stable C ABI for dlopen/dlsym-based kernel registration.", + '// Plugins call dlsym(handle, "ck_fmha_register_kernels") to get this.', + 'extern "C" __attribute__((visibility("default")))', + "int ck_fmha_register_kernels(", + " ck_tile::dispatcher::FmhaRegistry& registry, const char* arch) {", + " ::generated::register_fmha_python_kernels(registry, arch ? std::string(arch) : std::string());", + f" return {len(kernel_names)};", + "}", + "", + "// Kernel inventory for Python introspection", + f"static const int FMHA_KERNEL_COUNT = {len(kernel_names)};", + "static const char* FMHA_KERNEL_NAMES[] = {" + + ", ".join(f'"{n}"' for n in kernel_names) + + "};", + "", + ] + + header_path = output_dir / "fmha_python_dispatch.hpp" + header_path.write_text("\n".join(lines) + "\n") + return header_path + + +def compile_kernels(output_dir: Path, gpu_target: str, include_dirs: str) -> Path: + """Compile kernel .cpp files into a static library.""" + import shutil + + hipcc = shutil.which("hipcc") or "/opt/rocm/bin/hipcc" + kernel_cpps = sorted(output_dir.glob("fmha_*.cpp")) + if not kernel_cpps: + raise RuntimeError("No kernel .cpp files to compile") + + import re + + # Use the shared compile flags from fmha_utils + sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "python")) + from fmha_utils import fmha_compile_flags # noqa: E402 + + base_flags = fmha_compile_flags(gpu_target, hipcc, family="bwd") + + inc_flags = [] + for d in re.split(r"[;:]", include_dirs): + d = d.strip() + if d: + inc_flags.extend(["-I", d]) + + objs = [] + for cpp in kernel_cpps: + obj = cpp.with_suffix(".o") + cmd = base_flags + inc_flags + [str(cpp), "-o", str(obj)] + print(f" Compiling {cpp.name}...") + r = subprocess.run(cmd, capture_output=True, text=True) + if r.returncode != 0: + print(f" FAILED: {r.stderr}", file=sys.stderr) + raise RuntimeError(f"Failed to compile {cpp.name}") + objs.append(str(obj)) + + lib_path = output_dir / "libfmha_python_fallback.a" + ar_cmd = ["ar", "rcs", str(lib_path)] + objs + subprocess.check_call(ar_cmd) + print(f" Static lib: {lib_path}") + return lib_path + + +def main(): + parser = argparse.ArgumentParser( + description="Generate FMHA fallback kernel for Python library" + ) + parser.add_argument("--output-dir", required=True, type=Path) + parser.add_argument("--gpu-target", default="gfx950") + parser.add_argument( + "--config-json", + default=None, + help="Override default kernel config (JSON string)", + ) + parser.add_argument( + "--compile", action="store_true", help="Also compile the kernel .cpp into a .a" + ) + parser.add_argument( + "--include-dirs", + default="", + help="Semicolon-separated include directories for compilation", + ) + args = parser.parse_args() + + output_dir = args.output_dir + output_dir.mkdir(parents=True, exist_ok=True) + + codegen_dir = Path(__file__).parent + codegen_script = codegen_dir / "codegen.py" + + # Accept either a single config dict or a list of configs + if args.config_json: + parsed = json.loads(args.config_json) + if isinstance(parsed, list): + # Multi-config: pass list directly to unified_fmha_codegen + codegen_input = parsed + else: + # Single config: merge with defaults + config = dict(DEFAULT_CONFIG) + config["arch"] = args.gpu_target + config["signature"] = dict(DEFAULT_CONFIG["signature"]) + config["algorithm"] = dict(DEFAULT_CONFIG["algorithm"]) + config.update(parsed) + codegen_input = config + else: + config = dict(DEFAULT_CONFIG) + config["arch"] = args.gpu_target + config["signature"] = dict(DEFAULT_CONFIG["signature"]) + config["algorithm"] = dict(DEFAULT_CONFIG["algorithm"]) + codegen_input = config + + print(f"Generating FMHA fallback kernel for {args.gpu_target}...") + print(f" Output: {output_dir}") + + cmd = [ + sys.executable, + str(codegen_script), + "--output-dir", + str(output_dir), + "--gpu-target", + args.gpu_target, + "--config-json", + json.dumps(codegen_input), + ] + + result = subprocess.run(cmd, capture_output=True, text=True, cwd=str(codegen_dir)) + if result.returncode != 0: + print(f" Codegen FAILED:\n{result.stderr}", file=sys.stderr) + return 1 + + wrapper_dir = output_dir / "dispatcher_wrappers" + if not wrapper_dir.exists(): + print(" ERROR: No dispatcher_wrappers dir created", file=sys.stderr) + return 1 + + header_path = generate_dispatch_header(output_dir, wrapper_dir) + print(f" Dispatch header: {header_path}") + + kernel_cpps = list(output_dir.glob("fmha_*.cpp")) + print(f" Kernel TUs: {len(kernel_cpps)}") + + if args.compile and kernel_cpps: + compile_kernels(output_dir, args.gpu_target, args.include_dirs) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/codegen/fmha/instance_gen.py b/dispatcher/codegen/fmha/instance_gen.py new file mode 100644 index 0000000000..20536cabdf --- /dev/null +++ b/dispatcher/codegen/fmha/instance_gen.py @@ -0,0 +1,2692 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +FMHA instance generation — generates tile configs and expands kernel instances. + +Three layers: + 1. Tile generation — enumerate valid (bm0, bn0, bk0, warp) combinations + 2. Feature enumeration — enumerate valid (mask, bias, lse, dropout, padding) combinations + 3. Instance expansion — cross-product tiles × features × modes → kernel configs + +All hardware facts and constraints come from specs.py. +All symbol mappings come from symbol_map.py. + +Usage: + python -m fmha.instance_gen configs/receipt0_fwd.json --arch gfx950 + python -m fmha.instance_gen configs/fwd_ci.json --arch gfx950 --list +""" + +import argparse +import itertools +import json +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Set, Tuple + +_THIS_DIR = Path(__file__).resolve().parent +_DISPATCHER_ROOT = _THIS_DIR.parents[1] +sys.path.insert(0, str(_DISPATCHER_ROOT / "python")) +sys.path.insert(0, str(_THIS_DIR)) + +from validation import ( # noqa: E402 + ARCH_DTYPES, + BIASES, + BOOLS, + BWD_CONVERT_DQ_HDIMS, + BWD_CONVERT_DQ_TILE_GROUPS, + BWD_DOT_DO_O_HDIMS, + BWD_DQ_DK_DV_EXTRA_TILES, + BWD_DQ_DK_DV_TILES_FP16, + BWD_DQ_WAVE_WARP, + BWD_DROPOUTS, + BWD_EXTRA_PAD_COMBOS, + BWD_PAD_COMBOS, + BWD_SMALL_DROPOUTS, + DT_FP16_BF16, + DT_FP32, + DT_FP8, + DT_FP8FP32, + ELEMENT_SIZES, + K0_MAX_SUBMAX_MAP, + LDS_LIMITS, + MASKS, + SPLITKV_COMBINE_HDIMS_FP16, + SPLITKV_COMBINE_HDIMS_FP8, + SUPPORTED_HDIMS, + VALID_BK0, + VALID_BM0, + VALID_BN0, + WARP_CLASSES, + check_gfx9_tile_constraints, + check_gfx950_tile_constraints, + check_group_mode_padding, + check_logits_bias, + check_qr_mfma_insts, + receipt_filter, + tile_passes_all_constraints, +) +from fmha_utils import FmhaKernelConfig # noqa: E402 (from dispatcher/python/) + + +# ============================================================================= +# Tile configuration dataclass +# ============================================================================= + + +@dataclass(frozen=True) +class FmhaTileConfig: + """Complete FMHA tile configuration with all derived parameters. + + Field naming follows CK's TileFmhaShape template parameters: + - bm0/bn0/bk0: block tile for Gemm0 (Q*K^T), from sequence + - bn1/bk1: block tile for Gemm1 (P*V) + - bk0max: kSubQKHeaddim from tile_fmha_shape.hpp + - rm0: wave repeat in M direction = bm0/wm0 + - wm0/wn0/wk0: MFMA/WMMA warp tile from warp_gemm_dispatcher.hpp + """ + + bm0: int + bn0: int + bk0: int + bn1: int # = hdim_v + bk1: int # = 32 typically + bk0max: int # = K0_MAX_SUBMAX_MAP[hdim_q] + rm0: int # wave repeat = bm0/wm0 + wm0: int + wn0: int + wk0: int + wm1: int + wn1: int + wk1: int + rn0: int = 1 + rk0: int = 1 + rm1: int = 1 + rn1: int = 1 + rk1: int = 1 + + @property + def tile_6(self) -> Tuple[int, int, int, int, int, int]: + return (self.bm0, self.bn0, self.bk0, self.bn1, self.bk1, self.bk0max) + + +# ============================================================================= +# BK1 derivation +# ============================================================================= + + +def derive_bk1(bm0: int, bn0: int, bk0: int, hdim_q: int, hdim_v: int) -> int: + """Derive bk1 from tile config for fp16/bf16/fp32. + + Source: fmha_fwd.py FmhaFwdTileSize definitions — bk1 (element 4) is + always 32 except for three specific configs where it's 16. + These special cases come from the CK example's hand-tuned tile tables. + """ + if (bm0, bn0, bk0, hdim_q) in ( + (128, 64, 32, 128), + (32, 128, 32, 128), + (32, 128, 16, 48), + ): + return 16 + return 32 + + +def derive_bk1_fp8(bm0: int, bn0: int, bk0: int, hdim_q: int, hdim_v: int) -> int: + """Derive bk1 for fp8 dtypes. + + Source: fmha_fwd.py FP8 tile definitions — bk1 always equals bk0. + """ + return bk0 + + +# ============================================================================= +# Tile generation +# ============================================================================= + + +def generate_fwd_tiles( + arch: str, + dtype: str, + hdim_q: int, + hdim_v: int, + pipeline: str = "qr_async", + apply_constraints: bool = True, +) -> List[FmhaTileConfig]: + """Generate fwd tile configurations. + + apply_constraints=True (default): filter through tile_passes_all_constraints + — used by rules-mode benchmarking and codegen. + apply_constraints=False: only basic sanity (warp alignment, bk0<=hdim_q) + — used by exhaustive-mode benchmarking to find tiles the C++ compiler + accepts that our rules might reject. + """ + warp_classes = WARP_CLASSES.get(dtype, [(32, 32, 16)]) + bk0max = K0_MAX_SUBMAX_MAP.get(hdim_q, hdim_q) + is_fp8 = "fp8" in dtype or dtype in ("bf8", "mxfp8", "mxfp4") + + tiles: List[FmhaTileConfig] = [] + for bm0 in VALID_BM0: + for bn0 in VALID_BN0: + for bk0 in VALID_BK0: + if bk0 > hdim_q: + continue + for wm0, wn0, wk0 in warp_classes: + if bm0 % wm0 != 0 or bn0 % wn0 != 0 or bk0 % wk0 != 0: + continue + if apply_constraints and not tile_passes_all_constraints( + arch, + dtype, + hdim_q, + hdim_v, + pipeline, + bm0, + bn0, + bk0, + wm0, + wn0, + wk0, + ): + continue + + rm0 = bm0 // wm0 + bk1 = ( + derive_bk1_fp8(bm0, bn0, bk0, hdim_q, hdim_v) + if is_fp8 + else derive_bk1(bm0, bn0, bk0, hdim_q, hdim_v) + ) + + tiles.append( + FmhaTileConfig( + bm0=bm0, + bn0=bn0, + bk0=bk0, + bn1=hdim_v, + bk1=bk1, + bk0max=bk0max, + rm0=rm0, + rm1=rm0, + wm0=wm0, + wn0=wn0, + wk0=wk0, + wm1=wm0, + wn1=wn0, + wk1=wk0, + ) + ) + + return tiles + + +def generate_splitkv_tiles( + arch: str, + dtype: str, + hdim_q: int, + hdim_v: int, + apply_constraints: bool = True, +) -> List[FmhaTileConfig]: + """Generate splitkv tiles. + + Uses fixed warp class per dtype: (16,16,16) for fp16/bf16/fp32, + (32,32,32) for fp8. These match the warp tiles used in the CK example's + splitkv tile definitions (fmha_fwd.py KernelComponentFactory*.get_splitkv_tiles()). + LDS limit: 64 KiB (non-async pipeline, arch.hpp get_smem_capacity for non-gfx950). + + apply_constraints=False skips LDS check (for exhaustive mode). + """ + bk0max = K0_MAX_SUBMAX_MAP.get(hdim_q, hdim_q) + is_fp8 = "fp8" in dtype or dtype == "bf8" + wm0, wn0, wk0 = (32, 32, 32) if is_fp8 else (16, 16, 16) + + tiles: List[FmhaTileConfig] = [] + for bm0 in VALID_BM0: + for bn0 in VALID_BN0: + for bk0 in VALID_BK0: + if bk0 > hdim_q: + continue + if bm0 % wm0 != 0 or bk0 % wk0 != 0 or bn0 % wn0 != 0: + continue + if apply_constraints: + elem_size = ELEMENT_SIZES.get(dtype, 2) + lds_limit = LDS_LIMITS.get("qr", 65536) + if (bm0 * bk0 + bn0 * bk0) * elem_size > lds_limit: + continue + + rm0 = bm0 // wm0 + bk1 = bk0 if is_fp8 else 32 + + tiles.append( + FmhaTileConfig( + bm0=bm0, + bn0=bn0, + bk0=bk0, + bn1=hdim_v, + bk1=bk1, + bk0max=bk0max, + rm0=rm0, + rm1=rm0, + wm0=wm0, + wn0=wn0, + wk0=wk0, + wm1=wm0, + wn1=wn0, + wk1=wk0, + ) + ) + + return tiles + + +def generate_pagedkv_tiles( + arch: str, + dtype: str, + hdim_q: int, + hdim_v: int, + apply_constraints: bool = True, +) -> List[FmhaTileConfig]: + """PagedKV uses same tile rules as splitkv.""" + return generate_splitkv_tiles(arch, dtype, hdim_q, hdim_v, apply_constraints) + + +def generate_bwd_tiles( + arch: str, + dtype: str, + hdim_q: int, + hdim_v: int, + apply_constraints: bool = True, +) -> List[FmhaTileConfig]: + """Generate BWD tile configurations. + + apply_constraints=False skips LDS check (for exhaustive mode). + """ + warp_classes = WARP_CLASSES.get(dtype, [(32, 32, 16)]) + bk0max = K0_MAX_SUBMAX_MAP.get(hdim_q, hdim_q) + is_fp8 = "fp8" in dtype or dtype in ("bf8", "mxfp8", "mxfp4") + + tiles: List[FmhaTileConfig] = [] + for bm0 in VALID_BM0: + for bn0 in VALID_BN0: + for bk0 in VALID_BK0: + if bk0 > hdim_q: + continue + + for wm0, wn0, wk0 in warp_classes: + if bm0 % wm0 != 0 or bk0 % wk0 != 0 or bn0 % wn0 != 0: + continue + if apply_constraints: + elem_size = ELEMENT_SIZES.get(dtype, 2) + lds_limit = LDS_LIMITS.get("qs", 65536) + if (bm0 * bk0 + bn0 * bk0) * elem_size > lds_limit: + continue + + rm0 = bm0 // wm0 + bk1 = bk0 if is_fp8 else 32 + + tiles.append( + FmhaTileConfig( + bm0=bm0, + bn0=bn0, + bk0=bk0, + bn1=hdim_v, + bk1=bk1, + bk0max=bk0max, + rm0=rm0, + rm1=rm0, + wm0=wm0, + wn0=wn0, + wk0=wk0, + wm1=wm0, + wn1=wn0, + wk1=wk0, + ) + ) + + return tiles + + +def validate_tile( + tile: "FmhaTileConfig", + arch: str, + dtype: str, + hdim_q: int, + hdim_v: int, + pipeline: str = "qr_async", +) -> bool: + """Validate a single tile configuration against all constraints.""" + return tile_passes_all_constraints( + arch, + dtype, + hdim_q, + hdim_v, + pipeline, + tile.bm0, + tile.bn0, + tile.bk0, + tile.wm0, + tile.wn0, + tile.wk0, + ) + + +# ============================================================================= +# Pipeline spec dataclasses +# ============================================================================= + + +@dataclass(frozen=True) +class PipelineSpec: + """One FWD pipeline variant with its feature flags and padding.""" + + tag: str + mask: str + bias: str + lse: str + dropout: str + logits: str + skip: str + sink: str + qscale: str = "no" + spad: str = "f" + skpad: str = "f" + dpad: str = "f" + dvpad: str = "f" + + +@dataclass(frozen=True) +class SplitKVPipelineSpec: + """Split-KV main kernel pipeline variant.""" + + tag: str + mask: str + bias: str + logits: str + sink: str + pagedkv: str = "f" + squant: str = "f" + spad: str = "f" + skpad: str = "f" + dpad: str = "f" + dvpad: str = "f" + lse: str = "t" + + +@dataclass(frozen=True) +class SplitKVCombineSpec: + """Split-KV combine kernel pipeline variant.""" + + spad: str + dvpad: str + lse: str + squant: str = "f" + + +@dataclass(frozen=True) +class AppendKVPipelineSpec: + """Append-KV pipeline variant.""" + + rope: str = "none" + pagedkv: str = "f" + spad: str = "t" + skpad: str = "t" + dpad: str = "t" + dvpad: str = "t" + + +@dataclass(frozen=True) +class BatchPrefillPipelineSpec: + """Batch prefill pipeline variant.""" + + mask: str + bias: str + logits: str + sink: str + lse: str = "f" + dropout: str = "f" + skip: str = "f" + qscale: str = "no" + page_size: int = 0 + kv_memory_layout: str = "vectorized" + kv_lookup_table: str = "sglang" + spad: str = "t" + skpad: str = "t" + dpad: str = "t" + dvpad: str = "t" + + +@dataclass(frozen=True) +class BwdPipelineSpec: + """BWD pipeline variant.""" + + family: str + mask: str = "no" + bias: str = "no" + dbias: str = "f" + dropout: str = "f" + deterministic: str = "f" + spad: str = "t" + skpad: str = "t" + dpad: str = "t" + dvpad: str = "t" + + +# ============================================================================= +# Feature-product generators +# ============================================================================= + + +def _fwd_specs_fp16bf16( + hdim: int, + hdim_v: int, + receipt: int, +) -> List[PipelineSpec]: + """Pipeline specs for fp16/bf16 on gfx9/gfx950. + + Source: fmha_fwd.py KernelComponentFactoryGfx9.get_pipelines() — + hdim=256 always uses 'qr' (non-async, since bk0 can equal 256). + Non-256 hdims use 'qr_async' for non-bias configs (async DMA), + 'qr' for bias configs (bias requires Q in LDS). + Receipt=1 (ck_extended) adds extra 'qr' variants for non-bias. + """ + specs: List[PipelineSpec] = [] + + for logits, mask, bias, lse, dropout, skip, sink in itertools.product( + BOOLS, + MASKS, + BIASES, + BOOLS, + BOOLS, + BOOLS, + BOOLS, + ): + if hdim == 256 and hdim_v == 256: + specs.append( + PipelineSpec( + "qr", + mask, + bias, + lse, + dropout, + logits, + skip, + sink, + spad="f", + skpad="f", + dpad="f", + dvpad="f", + ) + ) + specs.append( + PipelineSpec( + "qr", + mask, + bias, + lse, + dropout, + logits, + skip, + sink, + spad="t", + skpad="t", + dpad="f", + dvpad="f", + ) + ) + specs.append( + PipelineSpec( + "qr", + mask, + bias, + lse, + dropout, + logits, + skip, + sink, + spad="t", + skpad="t", + dpad="t", + dvpad="t", + ) + ) + else: + if bias == "bias": + specs.append( + PipelineSpec( + "qr", + mask, + bias, + lse, + dropout, + logits, + skip, + sink, + spad="f", + skpad="f", + dpad="f", + dvpad="f", + ) + ) + specs.append( + PipelineSpec( + "qr", + mask, + bias, + lse, + dropout, + logits, + skip, + sink, + spad="t", + skpad="t", + dpad="t", + dvpad="t", + ) + ) + else: + specs.append( + PipelineSpec( + "qr_async", + mask, + bias, + lse, + dropout, + logits, + skip, + sink, + spad="t", + skpad="f", + dpad="t", + dvpad="t", + ) + ) + specs.append( + PipelineSpec( + "qr_async", + mask, + bias, + lse, + dropout, + logits, + skip, + sink, + spad="t", + skpad="t", + dpad="t", + dvpad="t", + ) + ) + if receipt == 1 and bias != "bias": + specs.append( + PipelineSpec( + "qr", + mask, + bias, + lse, + dropout, + logits, + skip, + sink, + spad="t", + skpad="t", + dpad="t", + dvpad="t", + ) + ) + + return specs + + +def _fwd_specs_gfx950_extra(hdim: int, hdim_v: int) -> List[PipelineSpec]: + """Additional trload/v3 pipelines for gfx950 fp16/bf16. + + Source: fmha_fwd.py CompatibilityRuleFactoryGfx950 — + qr_async_trload only supports hdims (64,64) and (128,128), + requires no logits/bias/dropout/skip. + qr_async_trload_v3 only supports (128,128), no/causal mask only. + """ + specs: List[PipelineSpec] = [] + + for logits, mask, bias, lse, dropout, skip, sink in itertools.product( + BOOLS, + MASKS, + BIASES, + BOOLS, + BOOLS, + BOOLS, + BOOLS, + ): + if ( + (hdim, hdim_v) in [(64, 64), (128, 128)] + and logits == "f" + and bias == "no" + and dropout == "f" + and skip == "f" + ): + specs.append( + PipelineSpec( + "qr_async_trload", + mask, + bias, + lse, + dropout, + logits, + skip, + sink, + spad="f", + skpad="f", + dpad="f", + dvpad="f", + ) + ) + specs.append( + PipelineSpec( + "qr_async_trload", + mask, + bias, + lse, + dropout, + logits, + skip, + sink, + spad="f", + skpad="f", + dpad="t", + dvpad="t", + ) + ) + + if (hdim, hdim_v) == (128, 128): + for logits, mask in itertools.product(BOOLS, ["no", "causal"]): + specs.append( + PipelineSpec( + "qr_async_trload_v3", + mask, + "no", + "f", + "f", + logits, + "f", + "f", + spad="t", + skpad="t", + dpad="f", + dvpad="f", + ) + ) + + return specs + + +def _fwd_specs_fp8(hdim: int, hdim_v: int) -> List[PipelineSpec]: + """Pipeline specs for fp8bf16/fp8fp32. + + Source: fmha_fwd.py KernelComponentFactoryGfx9._DT_FP8 pipelines — + hdim=64 uses 'qr' (non-async), others use 'qr_async'. + FP8 supports pertensor and blockscale quantization (qscale). + No lse, dropout, skip, or bias for fp8. + """ + specs: List[PipelineSpec] = [] + + for logits, qscale, mask, bias, sink in itertools.product( + BOOLS, + ["no", "pertensor", "blockscale"], + MASKS, + ["no"], + BOOLS, + ): + tag = "qr" if hdim == 64 else "qr_async" + specs.append( + PipelineSpec( + tag, + mask, + bias, + "f", + "f", + logits, + "f", + sink, + qscale=qscale, + spad="t", + skpad="f", + dpad="t", + dvpad="t", + ) + ) + specs.append( + PipelineSpec( + tag, + mask, + bias, + "f", + "f", + logits, + "f", + sink, + qscale=qscale, + spad="t", + skpad="t", + dpad="t", + dvpad="t", + ) + ) + + return specs + + +def _fwd_specs_fp32(hdim: int, hdim_v: int) -> List[PipelineSpec]: + """Pipeline specs for fp32. + + Source: fmha_fwd.py KernelComponentFactoryGfx9._DT_FP32 — + always uses 'qr' pipeline (no async for fp32). + Full feature set (mask, bias, lse, dropout, logits, etc.). + """ + specs: List[PipelineSpec] = [] + + for logits, mask, bias, lse, dropout, skip, sink in itertools.product( + BOOLS, + MASKS, + BIASES, + BOOLS, + BOOLS, + BOOLS, + BOOLS, + ): + specs.append( + PipelineSpec( + "qr", + mask, + bias, + lse, + dropout, + logits, + skip, + sink, + spad="f", + skpad="f", + dpad="f", + dvpad="f", + ) + ) + specs.append( + PipelineSpec( + "qr", + mask, + bias, + lse, + dropout, + logits, + skip, + sink, + spad="f", + skpad="t", + dpad="f", + dvpad="f", + ) + ) + specs.append( + PipelineSpec( + "qr", + mask, + bias, + lse, + dropout, + logits, + skip, + sink, + spad="t", + skpad="t", + dpad="t", + dvpad="t", + ) + ) + + return specs + + +def get_pipelines_for_config( + arch: str, + dtype: str, + hdim: int, + hdim_v: int, + receipt: int = 0, +) -> List[PipelineSpec]: + """Get all valid pipeline specs for a given (arch, dtype, hdim, hdim_v, receipt).""" + if dtype in DT_FP32: + specs = _fwd_specs_fp32(hdim, hdim_v) + elif dtype in DT_FP16_BF16: + specs = _fwd_specs_fp16bf16(hdim, hdim_v, receipt) + if arch == "gfx950": + specs.extend(_fwd_specs_gfx950_extra(hdim, hdim_v)) + elif dtype in DT_FP8 or dtype in DT_FP8FP32: + specs = _fwd_specs_fp8(hdim, hdim_v) + else: + return [] + + return [ + s + for s in specs + if check_logits_bias(s.logits, s.bias) and receipt_filter(receipt, dtype, s) + ] + + +# --- SplitKV --- + + +def get_splitkv_pipelines( + dtype: str, hdim: int, receipt: int = 0 +) -> List[SplitKVPipelineSpec]: + """Split-KV main kernel pipelines.""" + specs: List[SplitKVPipelineSpec] = [] + SPLITKV_MASKS = ["no", "causal"] + + if dtype in DT_FP16_BF16: + for logits, mask, bias, pagedkv, sink in itertools.product( + BOOLS, SPLITKV_MASKS, BIASES, BOOLS, BOOLS + ): + if not check_logits_bias(logits, bias): + continue + specs.append( + SplitKVPipelineSpec( + "qr", + mask, + bias, + logits, + sink, + pagedkv, + spad="f", + skpad="t", + dpad="f", + dvpad="f", + ) + ) + specs.append( + SplitKVPipelineSpec( + "qr", + mask, + bias, + logits, + sink, + pagedkv, + spad="t", + skpad="f", + dpad="f", + dvpad="f", + ) + ) + specs.append( + SplitKVPipelineSpec( + "qr", + mask, + bias, + logits, + sink, + pagedkv, + spad="t", + skpad="t", + dpad="f", + dvpad="f", + ) + ) + specs.append( + SplitKVPipelineSpec( + "qr", + mask, + bias, + logits, + sink, + pagedkv, + spad="t", + skpad="t", + dpad="t", + dvpad="t", + ) + ) + elif dtype in ("fp8", "bf8"): + for logits, mask, bias in itertools.product(BOOLS, SPLITKV_MASKS, BIASES): + if not check_logits_bias(logits, bias): + continue + specs.append( + SplitKVPipelineSpec( + "qr", + mask, + bias, + logits, + "f", + "f", + squant="t", + spad="f", + skpad="f", + dpad="f", + dvpad="f", + ) + ) + specs.append( + SplitKVPipelineSpec( + "qr", + mask, + bias, + logits, + "f", + "f", + squant="t", + spad="t", + skpad="t", + dpad="f", + dvpad="f", + ) + ) + + if receipt != 0: + specs = [s for s in specs if _splitkv_receipt_filter(receipt, dtype, s)] + return specs + + +def _splitkv_receipt_filter( + receipt: int, dtype: str, spec: SplitKVPipelineSpec +) -> bool: + if receipt == 2: + return ( + dtype in ("fp16", "bf16") + and spec.bias in ("no", "alibi") + and spec.squant == "f" + and spec.sink == "f" + ) + if receipt == 4: + return ( + dtype in ("fp16", "bf16") + and spec.bias in ("no", "bias") + and spec.squant == "f" + and spec.sink == "f" + ) + if receipt == 200: + return dtype in ("fp16", "bf16") and spec.squant == "f" + if receipt == 600: + return dtype in ("fp16", "bf16") and spec.squant == "f" + if receipt in (800, 801): + return dtype == "fp32" + return True + + +def get_splitkv_combine_pipelines( + dtype: str, receipt: int = 0 +) -> List[SplitKVCombineSpec]: + """Split-KV combine kernel pipelines.""" + specs: List[SplitKVCombineSpec] = [] + squant = "t" if dtype in ("fp8", "bf8") else "f" + + if dtype in DT_FP16_BF16: + for spad, dvpad, lse in itertools.product(BOOLS, BOOLS, BOOLS): + specs.append(SplitKVCombineSpec(spad, dvpad, lse, squant)) + elif dtype in ("fp8", "bf8"): + for spad, dvpad in itertools.product(BOOLS, BOOLS): + specs.append(SplitKVCombineSpec(spad, dvpad, "f", squant)) + return specs + + +# --- PagedKV --- + + +def get_pagedkv_pipelines( + dtype: str, hdim: int, receipt: int = 0 +) -> List[PipelineSpec]: + """PagedKV prefill pipelines.""" + specs: List[PipelineSpec] = [] + + if dtype in DT_FP16_BF16: + for logits, mask, bias, sink in itertools.product(BOOLS, MASKS, BIASES, BOOLS): + if not check_logits_bias(logits, bias): + continue + specs.append( + PipelineSpec( + "qr_pagedkv", + mask, + bias, + "f", + "f", + logits, + "f", + sink, + spad="t", + skpad="f", + dpad="f", + dvpad="f", + ) + ) + specs.append( + PipelineSpec( + "qr_pagedkv", + mask, + bias, + "f", + "f", + logits, + "f", + sink, + spad="t", + skpad="t", + dpad="f", + dvpad="f", + ) + ) + elif dtype in ("fp8", "bf8"): + for logits, mask, bias in itertools.product(BOOLS, MASKS, BIASES): + if not check_logits_bias(logits, bias): + continue + specs.append( + PipelineSpec( + "qr_pagedkv", + mask, + bias, + "f", + "f", + logits, + "f", + "f", + spad="f", + skpad="f", + dpad="f", + dvpad="f", + ) + ) + specs.append( + PipelineSpec( + "qr_pagedkv", + mask, + bias, + "f", + "f", + logits, + "f", + "f", + spad="t", + skpad="t", + dpad="f", + dvpad="f", + ) + ) + + if receipt != 0: + specs = [s for s in specs if receipt_filter(receipt, dtype, s)] + return specs + + +# --- AppendKV --- + + +def get_appendkv_pipelines( + dtype: str, hdim: int, receipt: int = 0 +) -> List[AppendKVPipelineSpec]: + """Append-KV pipelines.""" + specs: List[AppendKVPipelineSpec] = [] + + if dtype in DT_FP16_BF16: + for pagedkv in ["t", "f"]: + specs.append( + AppendKVPipelineSpec( + rope="none", + pagedkv=pagedkv, + spad="f", + skpad="t", + dpad="f", + dvpad="f", + ) + ) + specs.append( + AppendKVPipelineSpec( + rope="none", + pagedkv=pagedkv, + spad="t", + skpad="t", + dpad="t", + dvpad="t", + ) + ) + specs.append( + AppendKVPipelineSpec( + rope="interleaved", + pagedkv=pagedkv, + spad="f", + skpad="t", + dpad="t", + dvpad="f", + ) + ) + specs.append( + AppendKVPipelineSpec( + rope="interleaved", + pagedkv=pagedkv, + spad="t", + skpad="t", + dpad="t", + dvpad="t", + ) + ) + specs.append( + AppendKVPipelineSpec( + rope="half_rotated", + pagedkv=pagedkv, + spad="f", + skpad="t", + dpad="t", + dvpad="f", + ) + ) + specs.append( + AppendKVPipelineSpec( + rope="half_rotated", + pagedkv=pagedkv, + spad="t", + skpad="t", + dpad="t", + dvpad="t", + ) + ) + elif dtype in ("fp8", "bf8"): + specs.append( + AppendKVPipelineSpec( + rope="none", pagedkv="f", spad="t", skpad="t", dpad="t", dvpad="t" + ) + ) + return specs + + +# --- Batch Prefill --- + + +def get_batch_prefill_pipelines( + dtype: str, hdim: int, receipt: int = 0 +) -> List[BatchPrefillPipelineSpec]: + """Batch prefill pipelines.""" + specs: List[BatchPrefillPipelineSpec] = [] + PREFILL_MASKS = ["no", "causal"] + + if dtype in DT_FP16_BF16: + for logits, mask, bias, lse, dropout, kvl, kvt in itertools.product( + BOOLS, + PREFILL_MASKS, + BIASES, + BOOLS, + BOOLS, + ["vectorized", "linear"], + ["vllm", "sglang"], + ): + if not check_logits_bias(logits, bias): + continue + specs.append( + BatchPrefillPipelineSpec( + mask, + bias, + logits, + "f", + lse, + dropout, + "f", + page_size=0, + kv_memory_layout=kvl, + kv_lookup_table=kvt, + ) + ) + elif dtype == "fp8bf16": + for logits, qscale, mask, bias, kvl, kvt in itertools.product( + BOOLS, + ["pertensor", "kv_blockscale"], + MASKS, + ["no"], + ["vectorized", "linear"], + ["vllm", "sglang"], + ): + if not check_logits_bias(logits, bias): + continue + specs.append( + BatchPrefillPipelineSpec( + mask, + bias, + logits, + "f", + "f", + "f", + "f", + qscale=qscale, + page_size=0, + kv_memory_layout=kvl, + kv_lookup_table=kvt, + ) + ) + return specs + + +# --- BWD --- + + +def get_bwd_dq_dk_dv_pipelines(dtype: str, receipt: int = 0) -> List[BwdPipelineSpec]: + """BWD dq_dk_dv feature product.""" + if dtype not in DT_FP16_BF16: + return [] + specs: List[BwdPipelineSpec] = [] + for mask, bias, dbias, dropout, deterministic in itertools.product( + MASKS, + BIASES, + BOOLS, + BWD_DROPOUTS, + BOOLS, + ): + if bias != "bias" and dbias == "t": + continue + for dpad, dvpad in BWD_PAD_COMBOS: + specs.append( + BwdPipelineSpec( + "bwd_dq_dk_dv", + mask, + bias, + dbias, + dropout, + deterministic, + dpad=dpad, + dvpad=dvpad, + ) + ) + return specs + + +def get_bwd_dq_dk_dv_extra_pipelines( + dtype: str, is_small: bool = False, receipt: int = 0 +) -> List[BwdPipelineSpec]: + """BWD dq_dk_dv extra tile pipelines (reduced feature set).""" + if dtype not in DT_FP16_BF16: + return [] + specs: List[BwdPipelineSpec] = [] + dropouts = BWD_SMALL_DROPOUTS if is_small else BWD_DROPOUTS + for mask, bias, dbias, dropout, deterministic in itertools.product( + MASKS, + BIASES, + BOOLS, + dropouts, + BOOLS, + ): + if bias != "bias" and dbias == "t": + continue + for dpad, dvpad in BWD_EXTRA_PAD_COMBOS: + specs.append( + BwdPipelineSpec( + "bwd_dq_dk_dv", + mask, + bias, + dbias, + dropout, + deterministic, + dpad=dpad, + dvpad=dvpad, + ) + ) + return specs + + +def get_bwd_dot_do_o_pipelines(dtype: str) -> List[BwdPipelineSpec]: + """BWD dot_do_o: spad x dvpad variants.""" + if dtype not in DT_FP16_BF16: + return [] + return [ + BwdPipelineSpec("bwd_dot_do_o", spad=s, dvpad=d) + for s, d in itertools.product(BOOLS, BOOLS) + ] + + +def get_bwd_convert_dq_pipelines(dtype: str, hdim: int = 0) -> List[BwdPipelineSpec]: + """BWD convert_dq: spad x deterministic x dpad.""" + if dtype not in DT_FP16_BF16: + return [] + dpads = ["f", "t", "8"] if hdim == 128 else BOOLS + return [ + BwdPipelineSpec("bwd_convert_dq", spad=s, deterministic=d, dpad=dp) + for s, d, dp in itertools.product(BOOLS, BOOLS, dpads) + ] + + +# ============================================================================= +# Tile compatibility (used by expand to double-check) +# ============================================================================= + + +def tile_compatible( + arch: str, + dtype: str, + hdim: int, + hdim_v: int, + pipeline_tag: str, + tile: Tuple[int, ...], +) -> bool: + """Check if a tile tuple passes arch-specific constraints (subset of tile_passes_all_constraints).""" + + bm0, bn0, bk0 = tile[0], tile[1], tile[2] + + if not check_gfx9_tile_constraints( + dtype, hdim, hdim_v, pipeline_tag, bm0, bn0, bk0 + ): + return False + if arch == "gfx950": + if not check_gfx950_tile_constraints(hdim, hdim_v, pipeline_tag, bm0, bn0): + return False + # Use default warp for mfma check + wn0, wk0 = 32, 16 + warp_classes = WARP_CLASSES.get(dtype, [(32, 32, 16)]) + if warp_classes: + _, wn0, wk0 = warp_classes[0] + if not check_qr_mfma_insts(arch, hdim, pipeline_tag, bn0, bk0, wn0, wk0): + return False + return True + + +# ============================================================================= +# BWD wave/warp lookup +# ============================================================================= + + +def bwd_dq_wave_warp(tile, hq, trload=False): + """Look up BWD wave/warp config for a tile.""" + trl = "t" if trload else "f" + key = (tile[0], tile[1], tile[2], trl) + entry = BWD_DQ_WAVE_WARP.get(key) + if entry is None: + for k, v in BWD_DQ_WAVE_WARP.items(): + if k[:3] == (tile[0], tile[1], tile[2]): + entry = v + break + if entry is None: + bn0 = tile[1] + wn = min(4, max(1, bn0 // 32)) + return { + "wave_m0": 1, + "wave_n0": wn, + "wave_k0": 1, + "wave_m1": 4, + "wave_n1": 1, + "wave_k1": 1, + "wave_m2": 1, + "wave_n2": wn, + "wave_k2": 1, + "warp_m0": 16, + "warp_n0": 16, + "warp_k0": 32, + "warp_m1": 16, + "warp_n1": 16, + "warp_k1": 16, + "warp_m2": 16, + "warp_n2": 16, + "warp_k2": 16, + } + w = entry["wave"] + wk1 = entry["warp_k1"] + return { + "wave_m0": w[0], + "wave_n0": w[1], + "wave_k0": w[2], + "wave_m1": w[3], + "wave_n1": w[4], + "wave_k1": w[5], + "wave_m2": w[6], + "wave_n2": w[7], + "wave_k2": w[8], + "warp_m0": 16, + "warp_n0": 16, + "warp_k0": 32, + "warp_m1": 16, + "warp_n1": 16, + "warp_k1": wk1, + "warp_m2": 16, + "warp_n2": 16, + "warp_k2": 16, + } + + +# ============================================================================= +# Instance expansion +# ============================================================================= + +VARIANT_TO_FAMILY = { + "fwd": "fwd", + "bwd": "bwd_dq_dk_dv", + "splitkv": "fwd_splitkv", + "appendkv": "fwd_appendkv", + "pagedkv": "fwd_pagedkv", + "batch_prefill": "batch_prefill", +} + +MODES = ["batch", "group"] + +_MASK_MAP = {"no": "no", "causal": "top_left", "generic": "generic"} +_BIAS_MAP = {"no": "no", "bias": "bias", "alibi": "alibi"} + + +def _pad_val(s: str) -> int: + if s == "f": + return 0 + if s == "t": + return 1 + return int(s) + + +def expand_sweep( + config_path: Optional[str], + arch: str, + receipt: int = 0, + mode: str = "rules", + restrict_hdims: Optional[List[Tuple[int, int]]] = None, + default_variant: str = "fwd", +) -> List[FmhaKernelConfig]: + """Expand sweep into full kernel instance list. + + Args: + config_path: Path to JSON sweep config, or None for defaults + (only valid with mode="exhaustive"). + arch: Target GPU arch ("gfx950" etc.). + receipt: Receipt level (0 = full, higher = filtered). + mode: "rules" applies tile_passes_all_constraints + receipt-driven + pipeline×feature coupling. "exhaustive" skips constraints and uses + a raw cartesian feature product (variant must be "fwd"). + restrict_hdims: If set, only generate configs for these (hq, hv) pairs. + default_variant: Variant to use when config_path is None. + """ + if config_path is None: + if mode != "exhaustive": + raise ValueError("config_path is required for mode='rules'") + config = {"variant": default_variant, "trait_config": {}} + else: + with open(config_path) as f: + config = json.load(f) + + variant = config.get("variant", default_variant) + + # Build allow-list filters from JSON trait_config + trait_cfg = config.get("trait_config", {}) + + def _allow(key: str) -> Optional[Set[str]]: + entry = trait_cfg.get(key) + if entry is None: + return None + return set(entry.get("values", [])) + + allowed_dtypes = _allow("data_type") + allowed_pipes = _allow("pipeline") + allowed_masks = _allow("mask") + allowed_biases = _allow("bias") + allowed_modes = _allow("mode") + allowed_lse = _allow("lse") + allowed_dropout = _allow("dropout") + allowed_logits = _allow("logits") + allowed_sink = _allow("sink") + allowed_paged_kv = _allow("paged_kv") + + # block_per_cu: int or list of ints to sweep + bpc_entry = trait_cfg.get("block_per_cu", {}) + block_per_cu_values = bpc_entry.get("values", [-1]) + if isinstance(block_per_cu_values, int): + block_per_cu_values = [block_per_cu_values] + + # Intersect with arch support + arch_dtypes = set(ARCH_DTYPES.get(arch, ARCH_DTYPES.get("gfx950", []))) + dtypes = ( + sorted(allowed_dtypes & arch_dtypes) if allowed_dtypes else sorted(arch_dtypes) + ) + + configs: List[FmhaKernelConfig] = [] + + if mode == "exhaustive": + if variant == "fwd": + configs = _expand_fwd_exhaustive( + arch, + dtypes, + allowed_pipes, + allowed_masks, + allowed_biases, + allowed_modes, + allowed_lse, + allowed_dropout, + allowed_logits, + allowed_sink, + block_per_cu_values, + restrict_hdims, + ) + elif variant == "splitkv": + configs = _expand_splitkv_exhaustive( + arch, + dtypes, + allowed_masks, + allowed_biases, + allowed_modes, + allowed_logits, + allowed_sink, + allowed_paged_kv, + restrict_hdims, + ) + elif variant == "pagedkv": + configs = _expand_pagedkv_exhaustive( + arch, + dtypes, + allowed_masks, + allowed_biases, + allowed_modes, + restrict_hdims, + ) + elif variant == "bwd": + configs = _expand_bwd_exhaustive( + arch, + dtypes, + allowed_masks, + allowed_biases, + allowed_modes, + restrict_hdims, + ) + elif variant in ("appendkv", "batch_prefill"): + # These have fixed tiles (no tile sweep), so exhaustive = rules mode + if variant == "appendkv": + configs = _expand_appendkv( + arch, dtypes, 0, restrict_hdims=restrict_hdims + ) + else: + configs = _expand_batch_prefill( + arch, + dtypes, + 0, + allowed_masks, + allowed_biases, + restrict_hdims=restrict_hdims, + ) + else: + raise ValueError(f"Exhaustive mode not supported for variant {variant!r}") + elif variant == "fwd": + configs = _expand_fwd( + arch, + dtypes, + receipt, + allowed_pipes, + allowed_masks, + allowed_biases, + allowed_modes, + allowed_lse, + allowed_dropout, + allowed_logits, + allowed_sink, + block_per_cu_values, + restrict_hdims=restrict_hdims, + ) + elif variant == "splitkv": + configs = _expand_splitkv( + arch, + dtypes, + receipt, + allowed_masks, + allowed_biases, + allowed_modes, + allowed_logits, + allowed_sink, + allowed_paged_kv, + restrict_hdims=restrict_hdims, + ) + elif variant == "pagedkv": + configs = _expand_pagedkv( + arch, + dtypes, + receipt, + allowed_masks, + allowed_biases, + allowed_modes, + restrict_hdims=restrict_hdims, + ) + elif variant == "appendkv": + configs = _expand_appendkv(arch, dtypes, receipt, restrict_hdims=restrict_hdims) + elif variant == "batch_prefill": + configs = _expand_batch_prefill( + arch, + dtypes, + receipt, + allowed_masks, + allowed_biases, + restrict_hdims=restrict_hdims, + ) + elif variant == "bwd": + configs = _expand_bwd( + arch, + dtypes, + receipt, + allowed_masks, + allowed_biases, + allowed_modes, + restrict_hdims=restrict_hdims, + ) + + # Dedup + seen: set = set() + unique: List[FmhaKernelConfig] = [] + for c in configs: + if c.name not in seen: + seen.add(c.name) + unique.append(c) + return unique + + +def _build_fwd_kernel_config( + *, + arch: str, + dtype: str, + mode: str, + hq: int, + hv: int, + pipeline: str, + tc: FmhaTileConfig, + pad_s: int = 0, + pad_sk: int = 0, + pad_d: int = 0, + pad_dv: int = 0, + mask: str = "no", + bias: str = "no", + lse: bool = False, + dropout: bool = False, + logits: bool = False, + sink: bool = False, + skip_min_seqlen_q: bool = False, + qscale: str = "no", + block_per_cu: int = -1, +) -> FmhaKernelConfig: + """Single source of truth for fwd FmhaKernelConfig kwargs derived from a tile.""" + return FmhaKernelConfig( + family="fwd", + data_type=dtype, + mode=mode, + hdim_q=hq, + hdim_v=hv, + pipeline=pipeline, + tile_m0=tc.bm0, + tile_n0=tc.bn0, + tile_k0=tc.bk0, + tile_n1=tc.bn1, + tile_k1=tc.bk1, + tile_k0max=tc.bk0max, + wave_m0=tc.rm0, + wave_n0=1, + wave_k0=1, + wave_m1=tc.rm0, + wave_n1=1, + wave_k1=1, + warp_m0=tc.wm0, + warp_n0=tc.wn0, + warp_k0=tc.wk0, + warp_m1=tc.wm1, + warp_n1=tc.wn1, + warp_k1=tc.wk1, + pad_s=pad_s, + pad_sk=pad_sk, + pad_d=pad_d, + pad_dv=pad_dv, + mask=mask, + bias=bias, + lse=lse, + dropout=dropout, + logits=logits, + sink=sink, + skip_min_seqlen_q=skip_min_seqlen_q, + qscale=qscale, + block_per_cu=block_per_cu, + gfx_arch=arch, + ) + + +def _expand_fwd( + arch, + dtypes, + receipt, + allowed_pipes, + allowed_masks, + allowed_biases, + allowed_modes, + allowed_lse, + allowed_dropout, + allowed_logits, + allowed_sink, + block_per_cu_values=None, + restrict_hdims=None, +): + if block_per_cu_values is None: + block_per_cu_values = [-1] + configs = [] + for dtype in dtypes: + hdims = SUPPORTED_HDIMS.get(dtype, []) + if restrict_hdims is not None: + hdims = [hv for hv in hdims if hv in restrict_hdims] + for hq, hv in hdims: + pipeline_specs = get_pipelines_for_config(arch, dtype, hq, hv, receipt) + _tile_cache: Dict[str, List[FmhaTileConfig]] = {} + for mode in MODES: + if allowed_modes is not None and mode not in allowed_modes: + continue + for spec in pipeline_specs: + if not check_group_mode_padding(mode, spec.spad, spec.skpad): + continue + if allowed_pipes is not None and spec.tag not in allowed_pipes: + continue + mm = _MASK_MAP.get(spec.mask, spec.mask) + mb = _BIAS_MAP.get(spec.bias, spec.bias) + lv = spec.lse == "t" + dv = spec.dropout == "t" + lgv = spec.logits == "t" + sv = spec.sink == "t" + skv = spec.skip == "t" + if allowed_masks is not None and mm not in allowed_masks: + continue + if allowed_biases is not None and mb not in allowed_biases: + continue + if allowed_lse is not None and lv not in allowed_lse: + continue + if allowed_dropout is not None and dv not in allowed_dropout: + continue + if allowed_logits is not None and lgv not in allowed_logits: + continue + if allowed_sink is not None and sv not in allowed_sink: + continue + if spec.tag not in _tile_cache: + _tile_cache[spec.tag] = generate_fwd_tiles( + arch, dtype, hq, hv, spec.tag + ) + for tc in _tile_cache[spec.tag]: + t6 = (tc.bm0, tc.bn0, tc.bk0, tc.bn1, tc.bk1, tc.bk0max) + if not tile_compatible(arch, dtype, hq, hv, spec.tag, t6): + continue + for bpc in block_per_cu_values: + configs.append( + _build_fwd_kernel_config( + arch=arch, + dtype=dtype, + mode=mode, + hq=hq, + hv=hv, + pipeline=spec.tag, + tc=tc, + pad_s=_pad_val(spec.spad), + pad_sk=_pad_val(spec.skpad), + pad_d=_pad_val(spec.dpad), + pad_dv=_pad_val(spec.dvpad), + mask=mm, + bias=mb, + lse=lv, + dropout=dv, + logits=lgv, + sink=sv, + skip_min_seqlen_q=skv, + qscale=spec.qscale, + block_per_cu=bpc, + ) + ) + return configs + + +def _expand_fwd_exhaustive( + arch, + dtypes, + allowed_pipes, + allowed_masks, + allowed_biases, + allowed_modes, + allowed_lse, + allowed_dropout, + allowed_logits, + allowed_sink, + block_per_cu_values, + restrict_hdims, +): + """Exhaustive fwd expansion: ALL tiles (no constraint filter) × full feature cross-product. + + Differs from _expand_fwd in two ways: + 1. Tiles come from generate_fwd_tiles(..., apply_constraints=False) + 2. Features are a raw cartesian product (no pipeline-receipt coupling) + + Used by --tiles=exhaustive in the benchmark to discover compilable tiles + that the rules engine rejects. + """ + pipelines = ( + sorted(allowed_pipes) + if allowed_pipes + else ["qr", "qr_async", "qr_async_trload", "qr_async_trload_v3"] + ) + modes = sorted(allowed_modes) if allowed_modes else MODES + masks = ( + sorted(allowed_masks) if allowed_masks else ["no", "top_left", "bottom_right"] + ) + biases = sorted(allowed_biases) if allowed_biases else ["no", "bias", "alibi"] + lse_vals = sorted(allowed_lse) if allowed_lse else [False, True] + dropout_vals = sorted(allowed_dropout) if allowed_dropout else [False, True] + logits_vals = sorted(allowed_logits) if allowed_logits else [False, True] + sink_vals = sorted(allowed_sink) if allowed_sink else [False] + bpc_vals = block_per_cu_values if block_per_cu_values else [-1, 1, 2] + + configs: List[FmhaKernelConfig] = [] + for dtype in dtypes: + hdims = SUPPORTED_HDIMS.get(dtype, []) + if restrict_hdims is not None: + hdims = [hv for hv in hdims if hv in restrict_hdims] + for hq, hv in hdims: + for pipeline in pipelines: + tiles = generate_fwd_tiles( + arch, dtype, hq, hv, pipeline, apply_constraints=False + ) + for tc in tiles: + for mode, mask, bias, lv, dv, lgv, sv, bpc in itertools.product( + modes, + masks, + biases, + lse_vals, + dropout_vals, + logits_vals, + sink_vals, + bpc_vals, + ): + configs.append( + _build_fwd_kernel_config( + arch=arch, + dtype=dtype, + mode=mode, + hq=hq, + hv=hv, + pipeline=pipeline, + tc=tc, + mask=mask, + bias=bias, + lse=lv, + dropout=dv, + logits=lgv, + sink=sv, + block_per_cu=bpc, + ) + ) + return configs + + +def _expand_splitkv_exhaustive( + arch, + dtypes, + allowed_masks, + allowed_biases, + allowed_modes, + allowed_logits, + allowed_sink, + allowed_paged_kv, + restrict_hdims, +): + """Exhaustive splitkv: ALL tiles (no LDS filter) × full feature product.""" + modes = sorted(allowed_modes) if allowed_modes else MODES + masks = ( + sorted(allowed_masks) if allowed_masks else ["no", "top_left", "bottom_right"] + ) + biases = sorted(allowed_biases) if allowed_biases else ["no", "bias", "alibi"] + logits_vals = sorted(allowed_logits) if allowed_logits else [False, True] + sink_vals = sorted(allowed_sink) if allowed_sink else [False] + + configs: List[FmhaKernelConfig] = [] + for dtype in dtypes: + hdims = SUPPORTED_HDIMS.get(dtype, []) + if restrict_hdims is not None: + hdims = [hv for hv in hdims if hv in restrict_hdims] + for hq, hv in hdims: + tiles = generate_splitkv_tiles(arch, dtype, hq, hv, apply_constraints=False) + for tc in tiles: + for mode, mask, bias, lgv, sv in itertools.product( + modes, + masks, + biases, + logits_vals, + sink_vals, + ): + configs.append( + FmhaKernelConfig( + family="fwd_splitkv", + data_type=dtype, + mode=mode, + hdim_q=hq, + hdim_v=hv, + pipeline="qr", + tile_m0=tc.bm0, + tile_n0=tc.bn0, + tile_k0=tc.bk0, + tile_n1=tc.bn1, + tile_k1=tc.bk1, + tile_k0max=tc.bk0max, + wave_m0=tc.rm0, + wave_n0=1, + wave_k0=1, + wave_m1=tc.rm0, + wave_n1=1, + wave_k1=1, + warp_m0=tc.wm0, + warp_n0=tc.wn0, + warp_k0=tc.wk0, + warp_m1=tc.wm1, + warp_n1=tc.wn1, + warp_k1=tc.wk1, + mask=mask, + bias=bias, + lse=True, + logits=lgv, + sink=sv, + gfx_arch=arch, + ) + ) + return configs + + +def _expand_pagedkv_exhaustive( + arch, + dtypes, + allowed_masks, + allowed_biases, + allowed_modes, + restrict_hdims, +): + """Exhaustive pagedkv: ALL tiles (no LDS filter) × full feature product.""" + modes = sorted(allowed_modes) if allowed_modes else MODES + masks = ( + sorted(allowed_masks) if allowed_masks else ["no", "top_left", "bottom_right"] + ) + biases = sorted(allowed_biases) if allowed_biases else ["no", "bias", "alibi"] + + configs: List[FmhaKernelConfig] = [] + for dtype in dtypes: + hdims = SUPPORTED_HDIMS.get(dtype, []) + if restrict_hdims is not None: + hdims = [hv for hv in hdims if hv in restrict_hdims] + for hq, hv in hdims: + tiles = generate_pagedkv_tiles(arch, dtype, hq, hv, apply_constraints=False) + for tc in tiles: + for mode, mask, bias in itertools.product(modes, masks, biases): + configs.append( + FmhaKernelConfig( + family="fwd_pagedkv", + data_type=dtype, + mode=mode, + hdim_q=hq, + hdim_v=hv, + pipeline="qr_pagedkv", + tile_m0=tc.bm0, + tile_n0=tc.bn0, + tile_k0=tc.bk0, + tile_n1=tc.bn1, + tile_k1=tc.bk1, + tile_k0max=tc.bk0max, + wave_m0=tc.rm0, + wave_n0=1, + wave_k0=1, + wave_m1=tc.rm0, + wave_n1=1, + wave_k1=1, + warp_m0=tc.wm0, + warp_n0=tc.wn0, + warp_k0=tc.wk0, + warp_m1=tc.wm1, + warp_n1=tc.wn1, + warp_k1=tc.wk1, + mask=mask, + bias=bias, + paged_kv=True, + gfx_arch=arch, + ) + ) + return configs + + +def _expand_bwd_exhaustive( + arch, + dtypes, + allowed_masks, + allowed_biases, + allowed_modes, + restrict_hdims, +): + """Exhaustive bwd: ALL tiles (no LDS filter) × full feature product. + + Note: BWD uses spec-defined fixed tiles for dq_dk_dv, but we can still + exhaust the dot_do_o and convert_dq with unfiltered tile generation. + For dq_dk_dv we use generate_bwd_tiles(apply_constraints=False) since + CK's bwd tile tables are hand-curated and the exhaustive sweep should + explore beyond them. + """ + modes = sorted(allowed_modes) if allowed_modes else MODES + masks = ( + sorted(allowed_masks) if allowed_masks else ["no", "top_left", "bottom_right"] + ) + biases = sorted(allowed_biases) if allowed_biases else ["no", "bias", "alibi"] + deterministic_vals = [False, True] + dropout_vals = ["no", "p", "rp"] + + configs: List[FmhaKernelConfig] = [] + for dtype in dtypes: + if dtype not in ("fp16", "bf16"): + continue + + # dot_do_o — fixed tile, just sweep features + dot_specs = get_bwd_dot_do_o_pipelines(dtype) + for hd in BWD_DOT_DO_O_HDIMS: + if restrict_hdims is not None and (hd, hd) not in restrict_hdims: + continue + for mode in modes: + for spec in dot_specs: + configs.append( + FmhaKernelConfig( + family="bwd_dot_do_o", + data_type=dtype, + mode=mode, + hdim_q=hd, + hdim_v=hd, + pipeline="qr", + tile_m0=64, + pad_s=_pad_val(spec.spad), + pad_dv=_pad_val(spec.dvpad), + gfx_arch=arch, + ) + ) + + # dq_dk_dv — exhaustive tiles + hdims = SUPPORTED_HDIMS.get(dtype, []) + if restrict_hdims is not None: + hdims = [hv for hv in hdims if hv in restrict_hdims] + for hq, hv in hdims: + tiles = generate_bwd_tiles(arch, dtype, hq, hv, apply_constraints=False) + for tc in tiles: + for mode, mask, bias, dropout, det in itertools.product( + modes, + masks, + biases, + dropout_vals, + deterministic_vals, + ): + ww = bwd_dq_wave_warp((tc.bm0, tc.bn0, tc.bk0), hq) + configs.append( + FmhaKernelConfig( + family="bwd_dq_dk_dv", + data_type=dtype, + mode=mode, + hdim_q=hq, + hdim_v=hv, + pipeline="qr", + tile_m0=tc.bm0, + tile_n0=tc.bn0, + tile_k0=tc.bk0, + tile_n1=tc.bn1, + tile_k1=tc.bk1, + tile_k0max=tc.bk0max, + mask=mask, + bias=bias, + dropout=(dropout != "no"), + dropout_variant=dropout, + deterministic=det, + gfx_arch=arch, + **ww, + ) + ) + + # convert_dq — no tile sweep (fixed tile), just feature sweep + for hd in BWD_CONVERT_DQ_HDIMS: + if restrict_hdims is not None and (hd, hd) not in restrict_hdims: + continue + for mode, det in itertools.product(modes, deterministic_vals): + configs.append( + FmhaKernelConfig( + family="bwd_convert_dq", + data_type=dtype, + mode=mode, + hdim_q=hd, + hdim_v=hd, + pipeline="qr", + tile_m0=64, + deterministic=det, + gfx_arch=arch, + ) + ) + return configs + + +def _expand_splitkv( + arch, + dtypes, + receipt, + allowed_masks, + allowed_biases, + allowed_modes, + allowed_logits=None, + allowed_sink=None, + allowed_paged_kv=None, + restrict_hdims=None, +): + configs = [] + for dtype in dtypes: + hdims = SUPPORTED_HDIMS.get(dtype, []) + if restrict_hdims is not None: + hdims = [hv for hv in hdims if hv in restrict_hdims] + for hq, hv in hdims: + tiles = generate_splitkv_tiles(arch, dtype, hq, hv) + sk_specs = get_splitkv_pipelines(dtype, hq, receipt) + for tc in tiles: + for mode in MODES: + if allowed_modes is not None and mode not in allowed_modes: + continue + for spec in sk_specs: + if mode == "group" and not ( + spec.spad == "t" and spec.skpad == "t" + ): + continue + mm = _MASK_MAP.get(spec.mask, spec.mask) + mb = _BIAS_MAP.get(spec.bias, spec.bias) + if allowed_masks is not None and mm not in allowed_masks: + continue + if allowed_biases is not None and mb not in allowed_biases: + continue + lgv = spec.logits == "t" + sv = spec.sink == "t" + pkv = spec.pagedkv == "t" + if allowed_logits is not None and lgv not in allowed_logits: + continue + if allowed_sink is not None and sv not in allowed_sink: + continue + if allowed_paged_kv is not None and pkv not in allowed_paged_kv: + continue + configs.append( + FmhaKernelConfig( + family="fwd_splitkv", + data_type=dtype, + mode=mode, + hdim_q=hq, + hdim_v=hv, + pipeline=spec.tag, + tile_m0=tc.bm0, + tile_n0=tc.bn0, + tile_k0=tc.bk0, + tile_n1=tc.bn1, + tile_k1=tc.bk1, + tile_k0max=tc.bk0max, + wave_m0=tc.rm0, + wave_n0=1, + wave_k0=1, + wave_m1=tc.rm0, + wave_n1=1, + wave_k1=1, + warp_m0=tc.wm0, + warp_n0=tc.wn0, + warp_k0=tc.wk0, + warp_m1=tc.wm1, + warp_n1=tc.wn1, + warp_k1=tc.wk1, + pad_s=_pad_val(spec.spad), + pad_sk=_pad_val(spec.skpad), + pad_d=_pad_val(spec.dpad), + pad_dv=_pad_val(spec.dvpad), + mask=mm, + bias=mb, + lse=True, + logits=lgv, + sink=sv, + paged_kv=pkv, + gfx_arch=arch, + ) + ) + # Combine kernels + for dtype in dtypes: + comb_specs = get_splitkv_combine_pipelines(dtype, receipt) + if not comb_specs: + continue + hdims = ( + SPLITKV_COMBINE_HDIMS_FP16 + if dtype in ("fp16", "bf16") + else SPLITKV_COMBINE_HDIMS_FP8 + if dtype in ("fp8", "bf8") + else [] + ) + for hv in hdims: + for mode in MODES: + if allowed_modes is not None and mode not in allowed_modes: + continue + for spec in comb_specs: + if mode == "group" and spec.spad != "t": + continue + configs.append( + FmhaKernelConfig( + family="fwd_splitkv_combine", + data_type=dtype, + mode=mode, + hdim_q=hv, + hdim_v=hv, + pipeline="splitkv_combine", + tile_m0=32, + tile_n0=hv, + tile_k0=32, + tile_n1=32, + pad_s=_pad_val(spec.spad), + pad_dv=_pad_val(spec.dvpad), + lse=(spec.lse == "t"), + gfx_arch=arch, + ) + ) + return configs + + +def _expand_pagedkv( + arch, + dtypes, + receipt, + allowed_masks, + allowed_biases, + allowed_modes, + restrict_hdims=None, +): + configs = [] + for dtype in dtypes: + hdims = SUPPORTED_HDIMS.get(dtype, []) + if restrict_hdims is not None: + hdims = [hv for hv in hdims if hv in restrict_hdims] + for hq, hv in hdims: + tiles = generate_pagedkv_tiles(arch, dtype, hq, hv) + pk_specs = get_pagedkv_pipelines(dtype, hq, receipt) + for tc in tiles: + for mode in MODES: + if allowed_modes is not None and mode not in allowed_modes: + continue + for spec in pk_specs: + if mode == "group" and not ( + spec.spad == "t" and spec.skpad == "t" + ): + continue + mm = _MASK_MAP.get(spec.mask, spec.mask) + mb = _BIAS_MAP.get(spec.bias, spec.bias) + if allowed_masks is not None and mm not in allowed_masks: + continue + if allowed_biases is not None and mb not in allowed_biases: + continue + configs.append( + FmhaKernelConfig( + family="fwd_pagedkv", + data_type=dtype, + mode=mode, + hdim_q=hq, + hdim_v=hv, + pipeline=spec.tag, + tile_m0=tc.bm0, + tile_n0=tc.bn0, + tile_k0=tc.bk0, + tile_n1=tc.bn1, + tile_k1=tc.bk1, + tile_k0max=tc.bk0max, + wave_m0=tc.rm0, + wave_n0=1, + wave_k0=1, + wave_m1=tc.rm0, + wave_n1=1, + wave_k1=1, + warp_m0=tc.wm0, + warp_n0=tc.wn0, + warp_k0=tc.wk0, + warp_m1=tc.wm1, + warp_n1=tc.wn1, + warp_k1=tc.wk1, + pad_s=_pad_val(spec.spad), + pad_sk=_pad_val(spec.skpad), + pad_d=_pad_val(spec.dpad), + pad_dv=_pad_val(spec.dvpad), + mask=mm, + bias=mb, + logits=(spec.logits == "t"), + skip_min_seqlen_q=(spec.skip == "t"), + sink=(spec.sink == "t"), + paged_kv=True, + gfx_arch=arch, + ) + ) + return configs + + +def _expand_appendkv(arch, dtypes, receipt, restrict_hdims=None): + configs = [] + for dtype in dtypes: + ak_specs = get_appendkv_pipelines(dtype, 0, receipt) + hdims = SUPPORTED_HDIMS.get(dtype, []) + if restrict_hdims is not None: + hdims = [hv for hv in hdims if hv in restrict_hdims] + for hq, hv in hdims: + for spec in ak_specs: + configs.append( + FmhaKernelConfig( + family="fwd_appendkv", + data_type=dtype, + mode="batch", + hdim_q=hq, + hdim_v=hv, + pipeline="appendkv", + tile_m0=64, + tile_n0=64, + tile_k0=hq, + tile_n1=hv, + pad_s=_pad_val(spec.spad), + pad_sk=_pad_val(spec.skpad), + pad_d=_pad_val(spec.dpad), + pad_dv=_pad_val(spec.dvpad), + rope={ + "none": "none", + "interleaved": "interleaved", + "half_rotated": "half_rotated", + }.get(spec.rope, spec.rope), + paged_kv=(spec.pagedkv == "t"), + gfx_arch=arch, + ) + ) + return configs + + +def _expand_batch_prefill( + arch, dtypes, receipt, allowed_masks, allowed_biases, restrict_hdims=None +): + configs = [] + page_sizes = [1, 16, 1024] + + def _bp_bk1(bm0, bn0, bk0, hq): + if bm0 == 64 and bn0 == 128 and bk0 == 64 and hq == 128: + return 64 + return 32 + + for dtype in dtypes: + hdims = SUPPORTED_HDIMS.get(dtype, []) + if restrict_hdims is not None: + hdims = [hv for hv in hdims if hv in restrict_hdims] + for hq, hv in hdims: + tiles = generate_splitkv_tiles(arch, dtype, hq, hv) + bp_specs = get_batch_prefill_pipelines(dtype, hq, receipt) + for tc in tiles: + bk1 = _bp_bk1(tc.bm0, tc.bn0, tc.bk0, hq) + for spec in bp_specs: + mm = _MASK_MAP.get(spec.mask, spec.mask) + mb = _BIAS_MAP.get(spec.bias, spec.bias) + if allowed_masks is not None and mm not in allowed_masks: + continue + if allowed_biases is not None and mb not in allowed_biases: + continue + for ps in page_sizes: + if ps == 1 and spec.kv_memory_layout != "linear": + continue + if spec.qscale == "kv_blockscale" and ps < tc.bn0: + continue + configs.append( + FmhaKernelConfig( + family="batch_prefill", + data_type=dtype, + mode="group", + hdim_q=hq, + hdim_v=hv, + pipeline="qr_async", + tile_m0=tc.bm0, + tile_n0=tc.bn0, + tile_k0=tc.bk0, + tile_n1=tc.bn1, + tile_k1=bk1, + tile_k0max=tc.bk0max, + wave_m0=tc.rm0, + wave_n0=1, + wave_k0=1, + wave_m1=tc.rm0, + wave_n1=1, + wave_k1=1, + warp_m0=tc.wm0, + warp_n0=tc.wn0, + warp_k0=tc.wk0, + warp_m1=tc.wm1, + warp_n1=tc.wn1, + warp_k1=tc.wk1, + pad_s=1, + pad_sk=1, + pad_d=1, + pad_dv=1, + mask=mm, + bias=mb, + lse=(spec.lse == "t"), + dropout=(spec.dropout == "t"), + logits=(spec.logits == "t"), + paged_kv=True, + page_size=ps, + kv_memory_layout=spec.kv_memory_layout, + kv_lookup_table=spec.kv_lookup_table, + qscale=spec.qscale, + gfx_arch=arch, + ) + ) + return configs + + +def _expand_bwd( + arch, + dtypes, + receipt, + allowed_masks, + allowed_biases, + allowed_modes, + restrict_hdims=None, +): + configs = [] + for dtype in dtypes: + if dtype not in ("fp16", "bf16"): + continue + + # dot_do_o + dot_specs = get_bwd_dot_do_o_pipelines(dtype) + for hd in BWD_DOT_DO_O_HDIMS: + for mode in MODES: + if allowed_modes is not None and mode not in allowed_modes: + continue + for spec in dot_specs: + if mode == "group" and spec.spad != "t": + continue + configs.append( + FmhaKernelConfig( + family="bwd_dot_do_o", + data_type=dtype, + mode=mode, + hdim_q=hd, + hdim_v=hd, + pipeline="qr", + tile_m0=64, + pad_s=_pad_val(spec.spad), + pad_dv=_pad_val(spec.dvpad), + gfx_arch=arch, + ) + ) + + # dq_dk_dv: main tiles + dq_specs = get_bwd_dq_dk_dv_pipelines(dtype, receipt) + for (hq, hv), tile in sorted(BWD_DQ_DK_DV_TILES_FP16.items()): + for mode in MODES: + if allowed_modes is not None and mode not in allowed_modes: + continue + for spec in dq_specs: + mm = _MASK_MAP.get(spec.mask, spec.mask) + mb = _BIAS_MAP.get(spec.bias, spec.bias) + if allowed_masks is not None and mm not in allowed_masks: + continue + if allowed_biases is not None and mb not in allowed_biases: + continue + ww = bwd_dq_wave_warp(tile, hq) + configs.append( + FmhaKernelConfig( + family="bwd_dq_dk_dv", + data_type=dtype, + mode=mode, + hdim_q=hq, + hdim_v=hv, + pipeline="qr", + tile_m0=tile[0], + tile_n0=tile[1], + tile_k0=tile[2], + tile_n1=tile[3] if len(tile) > 3 else hv, + tile_k1=tile[4] if len(tile) > 4 else tile[2], + tile_k0max=tile[5] if len(tile) > 5 else hq, + tile_bwd6=tile[6] if len(tile) > 6 else 0, + tile_bwd7=tile[7] if len(tile) > 7 else 0, + tile_bwd8=tile[8] if len(tile) > 8 else 0, + pad_s=_pad_val(spec.spad), + pad_sk=_pad_val(spec.skpad), + pad_d=_pad_val(spec.dpad), + pad_dv=_pad_val(spec.dvpad), + mask=mm, + bias=mb, + dbias=(spec.dbias == "t"), + dropout=(spec.dropout != "no"), + dropout_variant=spec.dropout, + deterministic=(spec.deterministic == "t"), + gfx_arch=arch, + **ww, + ) + ) + + # dq_dk_dv: extra tiles + for (hq, hv), extra_entries in BWD_DQ_DK_DV_EXTRA_TILES.items(): + for tile, tag, is_batch_only in extra_entries: + dq_extra_specs = get_bwd_dq_dk_dv_extra_pipelines( + dtype, is_small=is_batch_only, receipt=receipt + ) + for mode in ["batch"] if is_batch_only else MODES: + if allowed_modes is not None and mode not in allowed_modes: + continue + for spec in dq_extra_specs: + mm = _MASK_MAP.get(spec.mask, spec.mask) + mb = _BIAS_MAP.get(spec.bias, spec.bias) + if allowed_masks is not None and mm not in allowed_masks: + continue + if allowed_biases is not None and mb not in allowed_biases: + continue + ww = bwd_dq_wave_warp(tile, hq, trload=(tag == "trload")) + configs.append( + FmhaKernelConfig( + family="bwd_dq_dk_dv", + data_type=dtype, + mode=mode, + hdim_q=hq, + hdim_v=hv, + pipeline="qr", + tile_m0=tile[0], + tile_n0=tile[1], + tile_k0=tile[2], + tile_n1=tile[3] if len(tile) > 3 else hv, + tile_k1=tile[4] if len(tile) > 4 else tile[2], + tile_k0max=tile[5] if len(tile) > 5 else hq, + tile_bwd6=tile[6] if len(tile) > 6 else 0, + tile_bwd7=tile[7] if len(tile) > 7 else 0, + tile_bwd8=tile[8] if len(tile) > 8 else 0, + tile_tag=tag, + use_trload=(tag == "trload"), + pad_s=_pad_val(spec.spad), + pad_sk=_pad_val(spec.skpad), + pad_d=_pad_val(spec.dpad), + pad_dv=_pad_val(spec.dvpad), + mask=mm, + bias=mb, + dbias=(spec.dbias == "t"), + dropout=(spec.dropout != "no"), + dropout_variant=spec.dropout, + deterministic=(spec.deterministic == "t"), + gfx_arch=arch, + **ww, + ) + ) + + # convert_dq + for hd in BWD_CONVERT_DQ_HDIMS: + cvt_specs = get_bwd_convert_dq_pipelines(dtype, hd) + n_tile_groups = BWD_CONVERT_DQ_TILE_GROUPS.get(hd, 1) + for mode in MODES: + if allowed_modes is not None and mode not in allowed_modes: + continue + for spec in cvt_specs: + if mode == "group" and spec.spad != "t": + continue + for tile_grp in range(n_tile_groups): + configs.append( + FmhaKernelConfig( + family="bwd_convert_dq", + data_type=dtype, + mode=mode, + hdim_q=hd, + hdim_v=hd, + pipeline="qr", + tile_m0=64, + tile_tag=f"g{tile_grp}" if tile_grp > 0 else "", + pad_s=_pad_val(spec.spad), + pad_d=_pad_val(spec.dpad), + deterministic=(spec.deterministic == "t"), + gfx_arch=arch, + ) + ) + return configs + + +# ============================================================================= +# Filter utility +# ============================================================================= + + +def apply_filter( + configs: List[FmhaKernelConfig], expr: str = "", filter_file: str = "" +) -> List[FmhaKernelConfig]: + """Apply user-defined filters to a config list.""" + result = configs + + if filter_file: + import importlib.util + + spec = importlib.util.spec_from_file_location("user_filter", filter_file) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + fn = getattr(mod, "filter_config") + result = [c for c in result if fn(c)] + + if expr: + result = [c for c in result if eval(expr, {"c": c})] # noqa: S307 + + return result + + +# ============================================================================= +# CLI +# ============================================================================= + + +def main(): + parser = argparse.ArgumentParser(description="FMHA instance enumeration") + parser.add_argument("config", help="Sweep config JSON") + parser.add_argument("--arch", default="gfx950") + parser.add_argument("--receipt", type=int, default=0) + parser.add_argument( + "--filter", + dest="filter_expr", + default="", + help='Python expression per config, e.g. "c.hdim_q == 128"', + ) + parser.add_argument( + "--filter-file", + default="", + help="Path to .py file with filter_config(c) -> bool", + ) + parser.add_argument("--list", action="store_true") + parser.add_argument("--count-only", action="store_true") + args = parser.parse_args() + + configs = expand_sweep(args.config, args.arch, args.receipt) + before = len(configs) + configs = apply_filter(configs, args.filter_expr, args.filter_file) + filtered = before - len(configs) + + print( + f"Expanded {args.config} -> {before} configs" + f"{f' (filtered {filtered}, kept {len(configs)})' if filtered else ''}" + ) + + if args.count_only: + return + + if args.list: + for i, c in enumerate(configs): + print(f" [{i}] {c.name}") + + +if __name__ == "__main__": + main() diff --git a/dispatcher/codegen/fmha/symbol_map.py b/dispatcher/codegen/fmha/symbol_map.py new file mode 100644 index 0000000000..f6ab6adb4b --- /dev/null +++ b/dispatcher/codegen/fmha/symbol_map.py @@ -0,0 +1,333 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +import json +import hashlib + +# Architecture tag → C++ arch trait type. +# Source: CK's include/ck_tile/core/arch/amd_gpu_traits.hpp +# gfx9* → gfx9_t, gfx11* → gfx11_t, gfx12* → gfx12_t. +ARCH_TAG_MAP = { + "gfx90a": "ck_tile::gfx9_t", + "gfx942": "ck_tile::gfx9_t", + "gfx950": "ck_tile::gfx9_t", + "gfx1100": "ck_tile::gfx11_t", + "gfx1201": "ck_tile::gfx12_t", +} + +# Architecture → preprocessor guard for conditional compilation. +# Source: HIP compiler predefined macros (__gfx90a__, __gfx942__, etc.). +ARCH_PREPROC_MAP = { + "gfx90a": "defined(__gfx90a__)", + "gfx942": "defined(__gfx942__)", + "gfx950": "defined(__gfx950__)", + "gfx1100": "defined(__gfx1100__)", + "gfx1201": "defined(__gfx1201__)", +} + +# Forward dtype → C++ type config struct. +# Source: example/ck_tile/01_fmha/fmha_fwd.hpp FmhaFwdTypeConfig<> specializations +# and codegen/cpp_symbol_map.py FWD_DTYPE_MAP. +FWD_DTYPE_MAP = { + "fp32": "FmhaFwdFp32", + "fp16": "FmhaFwdFp16", + "bf16": "FmhaFwdBf16", + "fp8": "FmhaFwdFp8", + "bf8": "FmhaFwdBf8", + "fp8fp16": "FmhaFwdFp8Fp16", + "fp8bf16": "FmhaFwdFp8Bf16", + "fp8fp32": "FmhaFwdFp8Fp32", +} + +# Backward dtype → C++ type config struct. +# Source: example/ck_tile/01_fmha/fmha_bwd.hpp FmhaBwdTypeConfig<> specializations. +# BWD currently only supports fp16/bf16/fp32. +BWD_DTYPE_MAP = { + "fp32": "FmhaBwdFp32", + "fp16": "FmhaBwdFp16", + "bf16": "FmhaBwdBf16", +} + +# Kernel family → C++ enum. +# Source: include/ck_tile/dispatcher/fmha_types.hpp FmhaKernelFamily enum. +KERNEL_FAMILY_TO_ENUM = { + "fwd": "FmhaKernelFamily::Fwd", + "fwd_pagedkv": "FmhaKernelFamily::FwdPagedKv", + "fwd_splitkv": "FmhaKernelFamily::FwdSplitKv", + "fwd_splitkv_combine": "FmhaKernelFamily::FwdSplitKvCombine", + "fwd_appendkv": "FmhaKernelFamily::FwdAppendKv", + "batch_prefill": "FmhaKernelFamily::BatchPrefill", + "bwd_dot_do_o": "FmhaKernelFamily::BwdDotDoO", + "bwd_dq_dk_dv": "FmhaKernelFamily::BwdDqDkDv", + "bwd_convert_dq": "FmhaKernelFamily::BwdConvertDq", +} + +# API family → C++ enum. +# Source: include/ck_tile/dispatcher/fmha_types.hpp FmhaApiFamily enum. +API_FAMILY_TO_ENUM = { + "fwd": "FmhaApiFamily::Fwd", + "fwd_pagedkv": "FmhaApiFamily::FwdPagedKv", + "fwd_splitkv": "FmhaApiFamily::FwdSplitKv", + "fwd_appendkv": "FmhaApiFamily::FwdAppendKv", + "batch_prefill": "FmhaApiFamily::BatchPrefill", + "bwd": "FmhaApiFamily::Bwd", +} + +# Mask type → canonical form and C++ types. +# Source: include/ck_tile/ops/fmha/block/block_attention_mask.hpp +# SimplifiedGenericAttentionMask and GenericAttentionMask. +MASK_CANONICAL = { + "no": "no", + "no_mask": "no", + "causal": "top_left", + "top_left": "top_left", + "t": "top_left", + "bottom_right": "bottom_right", + "b": "bottom_right", + "generic": "generic", + "window_generic": "generic", + "g": "generic", +} + +MASK_TO_CPP = { + "no": "ck_tile::SimplifiedGenericAttentionMask", + "top_left": "ck_tile::SimplifiedGenericAttentionMask", + "bottom_right": "ck_tile::SimplifiedGenericAttentionMask", + "generic": "ck_tile::GenericAttentionMask", +} + +MASK_TO_CPP_GENERIC = { + "no": "FmhaMasks::NoMask", + "top_left": "FmhaMasks::CausalMask", + "bottom_right": "FmhaMasks::CausalMask", + "generic": "FmhaMasks::GenericMask", +} + +MASK_TO_INT = { + "no": 0, + "top_left": 1, + "bottom_right": 2, + "generic": 3, +} + +# Bias type → canonical form and C++ enum. +# Source: include/ck_tile/ops/fmha/block/block_attention_bias_enum.hpp. +BIAS_CANONICAL = { + "no": "no", + "no_bias": "no", + "bias": "bias", + "elementwise": "bias", + "elementwise_bias": "bias", + "alibi": "alibi", +} + +BIAS_TO_CPP = { + "no": "ck_tile::BlockAttentionBiasEnum::NO_BIAS", + "bias": "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS", + "alibi": "ck_tile::BlockAttentionBiasEnum::ALIBI", +} + +BIAS_TO_INT = { + "no": 0, + "bias": 1, + "alibi": 2, +} + +# Quantization scale type → canonical form and C++ enum. +# Source: include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp. +QSCALE_CANONICAL = { + "no": "no", + "no_scale": "no", + "pertensor": "pertensor", + "blockscale": "blockscale", + "kv_blockscale": "kv_blockscale", +} + +QSCALE_TO_CPP = { + "no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE", + "pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR", + "blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE", + "kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE", +} + +QSCALE_TO_INT = { + "no": 0, + "pertensor": 1, + "blockscale": 2, + "kv_blockscale": 3, +} + +# Rotary embedding type → canonical form and C++ enum. +# Source: include/ck_tile/ops/fmha/block/rotary_embedding_enum.hpp. +ROPE_CANONICAL = { + "none": "none", + "no": "none", + "inter": "inter", + "interleaved": "inter", + "half": "half", + "half_rotated": "half", +} + +ROPE_TO_CPP = { + "none": "ck_tile::RotaryEmbeddingEnum::NONE", + "inter": "ck_tile::RotaryEmbeddingEnum::INTERLEAVED", + "half": "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED", +} + +ROPE_TO_INT = { + "none": 0, + "inter": 1, + "half": 2, +} + +# V layout → C++ bool (true = row-major, false = column-major). +# Source: TileFmhaShape<..., IsVLayoutRowMajor> template parameter. +LAYOUT_TO_BOOL = { + "r": "true", + "row": "true", + "row_major": "true", + "c": "false", + "col": "false", + "col_major": "false", +} + +# KV cache memory layout → canonical form and C++ enum. +# Source: include/ck_tile/ops/fmha/block/block_attention_kv_cache.hpp. +KV_MEMORY_LAYOUT_CANONICAL = { + "vectorized": "vectorized", + "linear": "linear", +} + +KV_MEMORY_LAYOUT_TO_CPP = { + "vectorized": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT", + "linear": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT", +} + +KV_MEMORY_LAYOUT_TO_INT = { + "vectorized": 0, + "linear": 1, +} + +# KV lookup table type → canonical form and C++ enum. +# Source: include/ck_tile/ops/fmha/block/block_attention_kv_cache.hpp. +KV_LOOKUP_CANONICAL = { + "sglang": "sglang", + "vllm": "vllm", +} + +KV_LOOKUP_TO_CPP = { + "sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D", + "vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D", +} + +KV_LOOKUP_TO_INT = { + "vllm": 0, + "sglang": 1, +} + +# Pipeline tag → C++ pipeline class. +# Source: include/ck_tile/ops/fmha/pipeline/ — one header per pipeline variant. +PIPELINE_TO_CPP = { + "qr": "ck_tile::BlockFmhaPipelineQRKSVS", + "qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync", + "qs": "ck_tile::BlockFmhaPipelineQSKSVS", + "qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload", + "v3": "ck_tile::BlockFmhaFwdV3Pipeline", + "qr_async_trload_v3": "ck_tile::BlockFmhaFwdV3Pipeline", + "qr_pagedkv": "ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS", + "qr_nwarp_sshuffle": "ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS", + "appendkv": "ck_tile::BlockFmhaFwdAppendKVPipeline", + "batch_prefill_async": "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync", +} + +# Pipeline tag → C++ pipeline enum value. +# Source: include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp. +PIPELINE_ENUM_TO_CPP = { + "qr": "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", + "qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS", + "qr_async_trload": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD", + "v3": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3", + "qr_async_trload_v3": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3", + "qr_pagedkv": "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qr_nwarp_sshuffle": "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "batch_prefill_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", +} + +BOOL_MAP = { + True: "true", + False: "false", + "t": "true", + "f": "false", +} + + +def canonical_mask(value: str) -> str: + return MASK_CANONICAL.get(value, value) + + +def canonical_bias(value: str) -> str: + return BIAS_CANONICAL.get(value, value) + + +def canonical_qscale(value: str) -> str: + return QSCALE_CANONICAL.get(value, value) + + +def canonical_rope(value: str) -> str: + return ROPE_CANONICAL.get(value, value) + + +def canonical_kv_memory_layout(value: str) -> str: + return KV_MEMORY_LAYOUT_CANONICAL.get(value, value) + + +def canonical_kv_lookup(value: str) -> str: + return KV_LOOKUP_CANONICAL.get(value, value) + + +def sanitize_token(value) -> str: + return str(value).replace("::", "_").replace("/", "_").replace(" ", "_") + + +def kernel_name_from_config(config: dict) -> str: + sig = config["signature"] + alg = config["algorithm"] + + family = sanitize_token(sig["family"]) + dtype = sanitize_token(sig["data_type"]) + mode = sanitize_token(sig["mode"]) + vlayout = sanitize_token(sig["vlayout"]) + mask = sanitize_token(canonical_mask(sig["mask"])) + bias = sanitize_token(canonical_bias(sig["bias"])) + qscale = sanitize_token(canonical_qscale(sig["qscale"])) + rope = sanitize_token(canonical_rope(sig["rope"])) + kv_memory = sanitize_token(canonical_kv_memory_layout(sig["kv_memory_layout"])) + kv_lookup = sanitize_token(canonical_kv_lookup(sig["kv_lookup_table"])) + pipeline = sanitize_token(alg["pipeline"]) + + canonical_blob = json.dumps( + { + "family": family, + "dtype": dtype, + "mode": mode, + "vlayout": vlayout, + "mask": mask, + "bias": bias, + "qscale": qscale, + "rope": rope, + "kv_memory": kv_memory, + "kv_lookup": kv_lookup, + "sig": sig, + "alg": alg, + }, + sort_keys=True, + ).encode("utf-8") + digest = hashlib.sha1(canonical_blob).hexdigest()[:12] + + return ( + f"fmha_{family}_{dtype}_{mode}_h{sig['hdim_q']}x{sig['hdim_v']}" + f"_{pipeline}_{digest}" + ) diff --git a/dispatcher/codegen/fmha/validation.py b/dispatcher/codegen/fmha/validation.py new file mode 100644 index 0000000000..20b3a00540 --- /dev/null +++ b/dispatcher/codegen/fmha/validation.py @@ -0,0 +1,921 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +FMHA validation and kernel specifications. + +Architecture-specific data (dtypes, pipelines, hdims, tile tables) is stored in +``fmha_arch_specs.json`` so that it can be edited without touching Python code. +Common GPU hardware data (element sizes, warp size, LDS capacity) is imported +from the parent ``arch_specs_generated`` module (generated from ``arch_specs.json``). + +This file provides: + - JSON loading helpers + - Tile constraints (per-arch rules that reject invalid tiles) + - Feature compatibility rules (pipeline × feature flag interactions) + - Receipt filters and profiles (deployment-specific kernel subsets) + - Config validation for the AOT codegen path +""" + +import json +import sys +from dataclasses import dataclass, field +from enum import IntEnum +from pathlib import Path +from typing import Callable, Dict, Iterable, List, Optional, Tuple + +# Ensure this directory and parent codegen/ are on sys.path for sibling imports +_THIS_DIR = Path(__file__).resolve().parent +_CODEGEN_DIR = _THIS_DIR.parent +sys.path.insert(0, str(_THIS_DIR)) +sys.path.insert(0, str(_CODEGEN_DIR)) + +from symbol_map import ( # noqa: E402 + BWD_DTYPE_MAP, + FWD_DTYPE_MAP, + canonical_bias, + canonical_mask, + canonical_qscale, +) + +# Import shared hardware data from parent arch_specs_generated (generated from +# arch_specs.json by generate_arch_specs.py). Falls back to inline defaults if +# the generated module is unavailable (e.g. in standalone testing). +try: + from arch_specs_generated import ELEMENT_SIZE_MAP as _PARENT_ELEMENT_SIZES # noqa: E402 +except ImportError: + _PARENT_ELEMENT_SIZES = { + "fp16": 2, + "bf16": 2, + "fp32": 4, + "fp64": 8, + "fp8": 1, + "bf8": 1, + "int8": 1, + "int4": 0.5, + "pk_fp4": 0.5, + "int32": 4, + } + + +# ============================================================================= +# JSON data loading +# ============================================================================= + +_FMHA_SPECS_PATH = _THIS_DIR / "fmha_arch_specs.json" + + +def _load_fmha_specs() -> dict: + """Load fmha_arch_specs.json (cached after first call).""" + if not hasattr(_load_fmha_specs, "_cache"): + with open(_FMHA_SPECS_PATH) as f: + _load_fmha_specs._cache = json.load(f) + return _load_fmha_specs._cache + + +def _build_element_sizes() -> Dict[str, int]: + """Merge parent element sizes with FMHA-specific composite dtypes.""" + base = {k: int(v) for k, v in _PARENT_ELEMENT_SIZES.items()} + base.update(_load_fmha_specs().get("fmha_element_sizes", {})) + return base + + +# ============================================================================= +# 1. Architecture capabilities (loaded from fmha_arch_specs.json) +# ============================================================================= + + +def _build_arch_dtypes() -> Dict[str, List[str]]: + """Build ARCH_DTYPES from JSON architectures.""" + return { + arch: info["supported_dtypes"] + for arch, info in _load_fmha_specs()["architectures"].items() + } + + +def _build_supported_hdims() -> Dict[str, List[Tuple[int, int]]]: + """Build SUPPORTED_HDIMS from JSON, converting [q,v] lists to tuples.""" + return { + dtype: [tuple(pair) for pair in pairs] + for dtype, pairs in _load_fmha_specs()["supported_hdims"].items() + if dtype != "_comment" + } + + +def _build_arch_metadata() -> Dict[str, dict]: + """Build ARCH_METADATA from JSON architectures.""" + return dict(_load_fmha_specs()["architectures"]) + + +ARCH_DTYPES: Dict[str, List[str]] = _build_arch_dtypes() +SUPPORTED_HDIMS: Dict[str, List[Tuple[int, int]]] = _build_supported_hdims() +ARCH_METADATA: Dict[str, dict] = _build_arch_metadata() + + +# ============================================================================= +# 2. Tile hardware parameters (loaded from fmha_arch_specs.json + parent arch_specs) +# ============================================================================= + + +def _build_warp_classes() -> Dict[str, List[Tuple[int, int, int]]]: + """Build WARP_CLASSES from JSON fmha_warp_tiles.""" + return { + dtype: [tuple(w) for w in warps] + for dtype, warps in _load_fmha_specs()["fmha_warp_tiles"].items() + if dtype != "_comment" + } + + +def _build_lds_limits() -> Dict[str, int]: + """Build LDS_LIMITS from JSON.""" + return dict(_load_fmha_specs()["lds_limits"]) + + +def _build_k0max_map() -> Dict[int, int]: + """Build K0_MAX_SUBMAX_MAP from JSON (string keys → int keys).""" + return { + int(k): v for k, v in _load_fmha_specs()["k0max_map"].items() if k != "_comment" + } + + +_specs = _load_fmha_specs() +_tile_ranges = _specs["tile_sweep_ranges"] + +LDS_LIMITS: Dict[str, int] = _build_lds_limits() +WARP_CLASSES: Dict[str, List[Tuple[int, int, int]]] = _build_warp_classes() +ELEMENT_SIZES: Dict[str, int] = _build_element_sizes() +VALID_BM0: List[int] = _tile_ranges["valid_bm0"] +VALID_BN0: List[int] = _tile_ranges["valid_bn0"] +VALID_BK0: List[int] = _tile_ranges["valid_bk0"] +K0_MAX_SUBMAX_MAP: Dict[int, int] = _build_k0max_map() + + +# ============================================================================= +# 3. Tile constraints +# ============================================================================= + + +def check_gfx9_tile_constraints( + dtype: str, + hdim_q: int, + hdim_v: int, + pipeline: str, + bm0: int, + bn0: int, + bk0: int, +) -> bool: + """Gfx9 compatibility rules. + + Source: fmha_fwd.py CompatibilityRuleFactoryGfx9.check_hdim_tile(). + Applies to gfx90a, gfx942, gfx950 for pipelines in {qr, qr_async, qs}. + Note: CK factory is stricter (bm0==128 only for non-128 hdims); we allow + {64, 128, 192, 256} to let the tile engine explore more configurations. + """ + if dtype == "fp32": + return True + if pipeline not in ("qr", "qr_async", "qs"): + return True + if (hdim_q, hdim_v) == (128, 128) and bn0 != 128: + return False + if (hdim_q, hdim_v) == (128, 128) and pipeline == "qr_async" and bm0 != 128: + return False + if (hdim_q, hdim_v) != (128, 128) and bm0 not in (64, 128, 192, 256): + return False + if (hdim_q, hdim_v) == (128, 128) and pipeline != "qr_async" and bk0 == 64: + return False + return True + + +def check_gfx950_tile_constraints( + hdim_q: int, + hdim_v: int, + pipeline: str, + bm0: int, + bn0: int, +) -> bool: + """Gfx950 trload/v3 constraints. + + Source: fmha_fwd.py CompatibilityRuleFactoryGfx950.check_tile_pipeline(). + Note: CK enforces biconditional (v3_tile ↔ v3_pipeline); we only enforce + v3_pipeline → v3_tile since non-v3 pipelines may still use bm0=256. + """ + if pipeline == "qr_async_trload": + if (hdim_q, hdim_v) == (128, 128) and bn0 == 128: + return False + if (hdim_q, hdim_v) not in [(64, 64), (128, 128)]: + return False + is_v3_tile = bm0 == 256 + is_v3_pipeline = pipeline == "qr_async_trload_v3" + # v3 pipeline requires bm0=256; other pipelines also allow bm0=256 + if is_v3_pipeline and not is_v3_tile: + return False + return True + + +def check_qr_mfma_insts( + arch: str, + hdim_q: int, + pipeline: str, + bn0: int, + bk0: int, + wn0: int, + wk0: int, +) -> bool: + """NumMfmaInsts % 8 == 0 check. + + Source: block_fmha_pipeline_qr_ks_vs.hpp static_assert at line ~490. + Full C++ formula: (kM0/WarpM)*(kN0/WarpN)*(kK0/WarpK) / (MWarp*NWarp). + We simplify to (bn0/wn0)*(bk0/wk0), omitting (bm0/wm0)/(rm0*rn0) which + equals 1 for all current fp16/bf16/fp32/fp8 tiles, or a power-of-2 factor + for mxfp8/mxfp4 that doesn't change the mod-8 result. This is conservative: + it can only reject tiles the full formula would also reject, never the reverse. + Only applies to qr pipeline + hdim_q==256 + CDNA (gfx9*). + """ + if pipeline != "qr" or hdim_q != 256: + return True + if not arch.startswith("gfx9"): + return True + num_mfma = (bn0 // wn0) * (bk0 // wk0) + if num_mfma % 8 != 0: + return False + return True + + +def tile_passes_all_constraints( + arch: str, + dtype: str, + hdim_q: int, + hdim_v: int, + pipeline: str, + bm0: int, + bn0: int, + bk0: int, + wm0: int, + wn0: int, + wk0: int, +) -> bool: + """Master constraint check — returns True if the tile is valid.""" + elem_size = ELEMENT_SIZES.get(dtype, 2) + lds_limit = LDS_LIMITS.get(pipeline, 65536) + + # LDS capacity check (pipeline-dependent formula) + if pipeline in ("qr_async", "qr_async_trload", "qr_async_trload_v3"): + # Async pipeline: Q is in registers. LDS holds NumKVLdsBuffers (=3) copies of + # max(SingleKSize, SingleVSize). Derived from GetSmemSizeKV() in + # block_fmha_pipeline_qx_ks_vs_custom_policy.hpp. + # + # SingleVSize formula (MakeVLdsBlockDescriptor): + # Banks=32, PixelsPerRow = Banks*4/sizeof(dtype) = 32*4/elem_size + # kKPack = 16/elem_size (GetSmemKPackV) + # NPerRow = PixelsPerRow/kKPack + # SingleVSize = (bk1/kKPack) * (hdim_v/NPerRow) * (PixelsPerRow + kKPack) + # For bf16: PixelsPerRow=64, kKPack=8, NPerRow=8 + # SingleVSize = (32/8)*(hdim_v/8)*(64+8) = 4*(hdim_v/8)*72 = 36*hdim_v + # + # SingleKSize formula (GetSingleSmemElementSpaceSize, async branch): + # KPack = 16/elem_size, KVector = alignment (gfx950: 16/elem_size = 8 for bf16) + # LanesPerK = bk0/KVector, LaneGroups = 64/LanesPerK + # NumIssues = bn0/(LaneGroups*NumWarps) + # SingleKSize = NumIssues*NumWarps*(64*KVector + KPack) + # + bk1 = 32 # kK1 in TileFmhaShape — design choice from fmha_fwd.py tile defs + num_warps = bm0 // wm0 + # Banks: arch.hpp get_n_lds_banks() — 64 for gfx950, 32 for older + banks = 64 if arch == "gfx950" else 32 + pixels_per_row = banks * 4 // elem_size # Banks * 4bytes / sizeof(dtype) + k_pack = 16 // elem_size # GetSmemKPackV: 16 / sizeof(dtype) + n_per_row = pixels_per_row // k_pack + single_v = (bk1 // k_pack) * (hdim_v // n_per_row) * (pixels_per_row + k_pack) + + # KVector: GetAlignmentK in custom_policy.hpp — MaxLoadSizeInBytes / sizeof(dtype) + # gfx950 uses dwordx4 (16 bytes), older uses dword (4 bytes) + k_vector = 16 // elem_size if arch == "gfx950" else 4 // elem_size + lanes_per_k = bk0 // k_vector if k_vector > 0 else 1 + lane_groups = 64 // lanes_per_k if lanes_per_k > 0 else 1 # WarpSize=64 + num_issues = ( + bn0 // (lane_groups * num_warps) if (lane_groups * num_warps) > 0 else 0 + ) + single_k = num_issues * num_warps * (64 * k_vector + k_pack) + + single_buf_bytes = max(single_k, single_v) * elem_size + # NumPrefetchK = NumPrefetchV = 3 (async_default_policy.hpp) + num_kv_buffers = 3 + # Q uses registers (QLoadOnce=true), so GetSmemSizeQ() = 0. + total_lds = single_buf_bytes * num_kv_buffers + # gfx950 HW LDS limit: arch.hpp get_smem_capacity() = 163840 (160 KiB) + if total_lds > 163840: + return False + else: + # Non-async (qr/qs): Q and K tiles share LDS simultaneously + if (bm0 * bk0 + bn0 * bk0) * elem_size > lds_limit: + return False + # bk0 range + if bk0 > hdim_q: + return False + # hdim_q divisibility (tile_fmha_shape.hpp:60) + if hdim_q % bk0 != 0: + return False + # Warp alignment + if bm0 % wm0 != 0 or bk0 % wk0 != 0 or bn0 % wn0 != 0: + return False + # MFMA inst count + if not check_qr_mfma_insts(arch, hdim_q, pipeline, bn0, bk0, wn0, wk0): + return False + # Async DMA distribution constraint (MakeKLdsStoreBlockDescriptor, custom_policy.hpp). + # NumIssues = kNPerBlock / (LaneGroups * NumWarps) must be a positive integer, where + # LaneGroups = WarpSize / LanesPerK = 64 / (bk0 / KVector). + # Equivalently: (bn0 * bk0) % (kBlockSize * KVector) == 0. + # KVector = MaxLoadSizeInBytes / sizeof(dtype): gfx950=16/2=8, older=4/2=2 for bf16. + if pipeline == "qr_async" and arch.startswith("gfx9"): + kvector = 16 // elem_size if arch == "gfx950" else 4 // elem_size + num_warps = bm0 // wm0 + block_size = num_warps * 64 # WarpSize = 64 + if (bn0 * bk0) % (block_size * kvector) != 0: + return False + # Arch constraints + if arch in ("gfx90a", "gfx942", "gfx950"): + if not check_gfx9_tile_constraints( + dtype, hdim_q, hdim_v, pipeline, bm0, bn0, bk0 + ): + return False + if arch == "gfx950": + if not check_gfx950_tile_constraints(hdim_q, hdim_v, pipeline, bm0, bn0): + return False + return True + + +# ============================================================================= +# 4. Feature compatibility rules +# ============================================================================= + +# Supported mask, bias, and boolean values for feature products. +# These are the template enum values in CK's FMHA traits structs. +MASKS = ["no", "causal", "generic"] +BIASES = ["no", "bias", "alibi"] +BOOLS = ["t", "f"] + +# Dtype groups matching CK's _DT_* classification in fmha_fwd.py factory classes. +DT_FP16_BF16 = {"fp16", "bf16"} +DT_FP8 = {"fp8bf16", "fp8", "bf8"} +DT_FP8FP32 = {"fp8fp32"} +DT_FP32 = {"fp32"} + + +def check_logits_bias(logits: str, bias: str) -> bool: + """logits_soft_cap requires no bias. + + Source: fmha_fwd.py CompatibilityRuleFactory.check_feature(). + """ + return not (logits == "t" and bias != "no") + + +def check_group_mode_padding(mode: str, spad: str, skpad: str) -> bool: + """Group mode requires spad=t and skpad=t. + + Source: fmha_fwd.py CompatibilityRuleFactory.check_feature() + + block_fmha_pipeline static_asserts for padding. + """ + if mode == "group": + return spad == "t" and skpad == "t" + return True + + +# ============================================================================= +# 5. Variant-specific tile tables (loaded from fmha_arch_specs.json) +# ============================================================================= + + +def _build_bwd_tiles() -> Tuple[ + Dict[Tuple[int, int], Tuple[int, ...]], + Dict[Tuple[int, int], List[Tuple[Tuple[int, ...], str, bool]]], + Dict[Tuple[int, int, int, str], dict], +]: + """Build BWD tile tables from JSON.""" + bwd = _load_fmha_specs()["bwd_tiles"] + + # Main tiles: "hdimq_hdimv" -> 9-tuple + main = {} + for k, v in bwd["dq_dk_dv_fp16"].items(): + hq, hv = map(int, k.split("_")) + main[(hq, hv)] = tuple(v) + + # Extra tiles: "hdimq_hdimv" -> [(tile, tag, batch_only), ...] + extra = {} + for k, entries in bwd.get("dq_dk_dv_extra", {}).items(): + hq, hv = map(int, k.split("_")) + extra[(hq, hv)] = [ + (tuple(e["tile"]), e["tag"], e["batch_only"]) for e in entries + ] + + # Wave/warp lookup: "bm0_bn0_bk0_trload" -> {wave, warp_k1} + ww = {} + for k, v in _load_fmha_specs()["bwd_wave_warp"].items(): + if k.startswith("_"): + continue + parts = k.split("_") + key = (int(parts[0]), int(parts[1]), int(parts[2]), parts[3]) + ww[key] = {"wave": tuple(v["wave"]), "warp_k1": v["warp_k1"]} + + return main, extra, ww + + +def _build_splitkv_hdims() -> Tuple[List[int], List[int]]: + """Build SplitKV combine hdim lists from JSON.""" + skv = _load_fmha_specs()["splitkv_combine"] + return skv["hdims_fp16"], skv["hdims_fp8"] + + +_bwd_main, _bwd_extra, _bwd_ww = _build_bwd_tiles() +_skv_fp16, _skv_fp8 = _build_splitkv_hdims() + +SPLITKV_COMBINE_HDIMS_FP16: List[int] = _skv_fp16 +SPLITKV_COMBINE_HDIMS_FP8: List[int] = _skv_fp8 +BWD_DQ_DK_DV_TILES_FP16: Dict[Tuple[int, int], Tuple[int, ...]] = _bwd_main +BWD_DQ_DK_DV_EXTRA_TILES: Dict[ + Tuple[int, int], List[Tuple[Tuple[int, ...], str, bool]] +] = _bwd_extra +BWD_DQ_WAVE_WARP: Dict[Tuple[int, int, int, str], dict] = _bwd_ww + +_bwd_json = _load_fmha_specs()["bwd_tiles"] +BWD_EXTRA_PAD_COMBOS: List[Tuple[str, str]] = [ + tuple(p) for p in _bwd_json["extra_pad_combos"] +] +BWD_SMALL_DROPOUTS: List[str] = _bwd_json["small_dropouts"] +BWD_DOT_DO_O_HDIMS: List[int] = _bwd_json["dot_do_o_hdims"] +BWD_CONVERT_DQ_HDIMS: List[int] = _bwd_json["convert_dq_hdims"] +BWD_CONVERT_DQ_TILE_GROUPS: Dict[int, int] = { + int(k): v for k, v in _bwd_json["convert_dq_tile_groups"].items() +} +BWD_DROPOUTS: List[str] = _bwd_json["dropouts"] +BWD_PAD_COMBOS: List[Tuple[str, str]] = [tuple(p) for p in _bwd_json["pad_combos"]] + + +# ============================================================================= +# 6. Receipt filters +# ============================================================================= + + +class Receipt(IntEnum): + """Named receipt levels for deployment profiles. + + These are deployment-specific filters, not derived from C++ constraints. + They control which kernel subsets are emitted for different integration + targets (PyTorch, AITER, Flash-Attention, etc.). + """ + + CK_DEFAULT = 0 + CK_EXTENDED = 1 + FLASH_FWD = 2 + FLASH_BWD = 3 + PYTORCH = 4 + AITER_BATCH = 100 + AITER_GROUP = 200 + AITER_BWD_BATCH = 300 + AITER_BWD_GROUP = 400 + AITER_CPP = 600 + FP32_ALL = 800 + FP32_MIN = 801 + FP8_TEST = 888 + + +RECEIPT_FILTERS: Dict[int, Callable[[str, object], bool]] = { + 0: lambda dtype, spec: dtype != "fp32", + 2: lambda dtype, spec: ( + dtype in ("fp16", "bf16") + and getattr(spec, "bias", "no") in ("no", "alibi") + and getattr(spec, "qscale", "no") == "no" + and getattr(spec, "skip", "f") == "f" + and getattr(spec, "sink", "f") == "f" + ), + 4: lambda dtype, spec: ( + dtype in ("fp16", "bf16") + and getattr(spec, "bias", "no") in ("no", "bias") + and getattr(spec, "qscale", "no") == "no" + and getattr(spec, "skip", "f") == "f" + and getattr(spec, "logits", "f") == "f" + ), + 100: lambda dtype, spec: dtype in ("fp16", "bf16", "fp8bf16"), + 200: lambda dtype, spec: dtype in ("fp16", "bf16", "fp8bf16"), + 600: lambda dtype, spec: dtype in ("fp16", "bf16", "fp8bf16"), + 888: lambda dtype, spec: dtype in ("fp8bf16", "fp8fp32"), + 800: lambda dtype, spec: ( + dtype == "fp32" + and getattr(spec, "skip", "f") == "f" + and getattr(spec, "logits", "f") == "f" + ), +} + + +def receipt_filter(receipt: int, dtype: str, spec) -> bool: + """Apply receipt-level filter. Returns True if the kernel should be kept.""" + fn = RECEIPT_FILTERS.get(receipt) + if fn is None: + return dtype != "fp32" + return fn(dtype, spec) + + +# ============================================================================= +# 7. Profiles +# ============================================================================= + +PROFILE_ALIASES: Dict[str, str] = {str(r.value): r.name.lower() for r in Receipt} + + +@dataclass(frozen=True) +class FmhaProfile: + name: str + predicate: Callable[[dict], bool] + + def allows(self, config: dict) -> bool: + return self.predicate(config) + + +def _dtype_is(config: dict, allowed: Iterable[str]) -> bool: + return config["signature"]["data_type"] in set(allowed) + + +def _mode_is(config: dict, allowed: Iterable[str]) -> bool: + return config["signature"]["mode"] in set(allowed) + + +def _family_is(config: dict, allowed: Iterable[str]) -> bool: + return config["signature"]["family"] in set(allowed) + + +def _common_row_major_filter(config: dict) -> bool: + return config["signature"]["vlayout"] == "r" + + +def _bias_is(config: dict, allowed: Iterable[str]) -> bool: + return canonical_bias(config["signature"]["bias"]) in set(allowed) + + +def _qscale_is(config: dict, allowed: Iterable[str]) -> bool: + return canonical_qscale(config["signature"]["qscale"]) in set(allowed) + + +def _no_skip_or_logits(config: dict) -> bool: + return (not config["signature"]["skip_min_seqlen_q"]) and ( + not config["signature"]["logits"] + ) + + +PROFILES: Dict[str, FmhaProfile] = { + "ck_default": FmhaProfile( + "ck_default", lambda c: c["signature"]["data_type"] != "fp32" + ), + "ck_extended": FmhaProfile( + "ck_extended", lambda c: c["signature"]["data_type"] != "fp32" + ), + "flash_fwd": FmhaProfile( + "flash_fwd", + lambda c: ( + _family_is(c, {"fwd", "fwd_splitkv", "fwd_appendkv", "fwd_pagedkv"}) + and _dtype_is(c, {"fp16", "bf16"}) + and _common_row_major_filter(c) + and _bias_is(c, {"no", "alibi"}) + and _qscale_is(c, {"no"}) + and not c["signature"]["skip_min_seqlen_q"] + ), + ), + "flash_bwd": FmhaProfile( + "flash_bwd", + lambda c: ( + _family_is(c, {"bwd_dot_do_o", "bwd_dq_dk_dv", "bwd_convert_dq"}) + and _dtype_is(c, {"fp16", "bf16"}) + ), + ), + "pytorch": FmhaProfile( + "pytorch", + lambda c: ( + _dtype_is(c, {"fp16", "bf16"}) + and _common_row_major_filter(c) + and _bias_is(c, {"no", "bias"}) + and _qscale_is(c, {"no"}) + and _no_skip_or_logits(c) + and not c["signature"].get("sink", False) + ), + ), + "aiter_batch": FmhaProfile( + "aiter_batch", + lambda c: ( + _dtype_is(c, {"fp16", "bf16", "fp8bf16"}) + and _mode_is(c, {"batch"}) + and _common_row_major_filter(c) + and ( + c["signature"]["data_type"] != "fp8bf16" + or c["signature"]["hdim_q"] in {128, 192} + ) + ), + ), + "aiter_group": FmhaProfile( + "aiter_group", + lambda c: ( + _dtype_is(c, {"fp16", "bf16", "fp8bf16"}) + and _mode_is(c, {"group"}) + and _common_row_major_filter(c) + ), + ), + "aiter_bwd_batch": FmhaProfile( + "aiter_bwd_batch", + lambda c: ( + _family_is(c, {"bwd_dot_do_o", "bwd_dq_dk_dv", "bwd_convert_dq"}) + and _dtype_is(c, {"fp16", "bf16"}) + and _mode_is(c, {"batch"}) + ), + ), + "aiter_bwd_group": FmhaProfile( + "aiter_bwd_group", + lambda c: ( + _family_is(c, {"bwd_dot_do_o", "bwd_dq_dk_dv", "bwd_convert_dq"}) + and _dtype_is(c, {"fp16", "bf16"}) + and _mode_is(c, {"group"}) + ), + ), + "aiter_cpp": FmhaProfile( + "aiter_cpp", + lambda c: ( + _dtype_is(c, {"fp16", "bf16", "fp8bf16"}) + and _common_row_major_filter(c) + and ( + c["signature"]["data_type"] != "fp8bf16" + or c["signature"]["hdim_q"] in {128, 192} + ) + ), + ), + "fp32_all": FmhaProfile( + "fp32_all", lambda c: _dtype_is(c, {"fp32"}) and _no_skip_or_logits(c) + ), + "fp32_min": FmhaProfile( + "fp32_min", + lambda c: ( + _dtype_is(c, {"fp32"}) + and _mode_is(c, {"batch"}) + and c["signature"]["hdim_q"] in {48, 128} + and c["signature"]["hdim_v"] in {48, 128} + and canonical_bias(c["signature"]["bias"]) == "no" + and not c["signature"]["lse"] + and not c["signature"]["dropout"] + and canonical_qscale(c["signature"]["qscale"]) == "no" + ), + ), + "fp8_test": FmhaProfile( + "fp8_test", + lambda c: ( + _dtype_is(c, {"fp8bf16", "fp8fp32"}) + and c["signature"]["hdim_q"] in {128, 192} + and _common_row_major_filter(c) + ), + ), + "all": FmhaProfile("all", lambda _: True), +} + + +def normalize_profile( + profile: Optional[str] = None, receipt: Optional[str] = None +) -> str: + if profile: + return PROFILE_ALIASES.get(str(profile), str(profile)) + if receipt is not None: + return PROFILE_ALIASES.get(str(receipt), str(receipt)) + return "ck_default" + + +def get_profile( + profile: Optional[str] = None, receipt: Optional[str] = None +) -> FmhaProfile: + normalized = normalize_profile(profile=profile, receipt=receipt) + if normalized not in PROFILES: + raise KeyError(f"Unknown FMHA profile: {normalized}") + return PROFILES[normalized] + + +def profile_allows( + config: dict, profile: Optional[str] = None, receipt: Optional[str] = None +) -> bool: + return get_profile(profile=profile, receipt=receipt).allows(config) + + +# ============================================================================= +# 8. Validation helpers (for unified_fmha_codegen) +# ============================================================================= + +_DEFAULTS: dict = _load_fmha_specs()["defaults"] +_GLOBAL_RULES: dict = _load_fmha_specs()["global_rules"] + + +def load_arch_specs() -> dict: + """Return arch_specs dict compatible with unified_fmha_codegen. + + Combines FMHA-specific architecture data from fmha_arch_specs.json with + defaults, global rules, and splitkv combine params. + """ + specs = _load_fmha_specs() + return { + "architectures": ARCH_METADATA, + "defaults": _DEFAULTS, + "global_rules": _GLOBAL_RULES, + "splitkv_combine": specs["splitkv_combine"], + } + + +# ============================================================================= +# 9. Config validation (for unified_fmha_codegen) +# ============================================================================= + + +@dataclass +class ValidationResult: + valid: bool = True + errors: List[str] = field(default_factory=list) + warnings: List[str] = field(default_factory=list) + + def add_error(self, msg: str): + self.valid = False + self.errors.append(msg) + + def add_warning(self, msg: str): + self.warnings.append(msg) + + +def validate_config( + config: dict, arch_specs: Optional[dict] = None +) -> "ValidationResult": + """Validate an FMHA kernel config against all rules.""" + arch_specs = arch_specs or load_arch_specs() + result = ValidationResult() + + sig = config["signature"] + alg = config["algorithm"] + arch = config["arch"] + + architectures = arch_specs.get("architectures", ARCH_METADATA) + if arch not in architectures: + result.add_error(f"Unsupported FMHA target architecture: {arch}") + return result + + arch_info = architectures[arch] + global_rules = arch_specs.get("global_rules", _GLOBAL_RULES) + dtype = sig["data_type"] + family = sig["family"] + pipeline = alg["pipeline"] + canonical_mask(sig["mask"]) + bias = canonical_bias(sig["bias"]) + + # Family validation + supported_families = { + "fwd", + "fwd_pagedkv", + "fwd_splitkv", + "fwd_splitkv_combine", + "fwd_appendkv", + "batch_prefill", + "bwd_dot_do_o", + "bwd_dq_dk_dv", + "bwd_convert_dq", + } + if family not in supported_families: + result.add_error(f"Unsupported FMHA family: {family}") + + # Dtype validation + supported_dtypes = set(arch_info["supported_dtypes"]) + if dtype not in supported_dtypes: + result.add_error(f"dtype {dtype} is not supported on {arch}") + + if family.startswith("bwd") and dtype not in BWD_DTYPE_MAP: + result.add_error( + f"Backward family {family} only supports {sorted(BWD_DTYPE_MAP)}" + ) + + if ( + family.startswith("fwd") + and not family.startswith("fwd_append") + and dtype not in FWD_DTYPE_MAP + ): + result.add_error(f"Forward family {family} does not recognize dtype {dtype}") + + # Pipeline validation + if ( + family != "fwd_splitkv_combine" + and pipeline not in arch_info["supported_pipelines"] + ): + result.add_error(f"pipeline {pipeline} is not supported on {arch}") + + if pipeline in {"v3", "qr_async_trload_v3"} and not arch_info.get( + "supports_v3", False + ): + result.add_warning(f"v3 pipeline on {arch} requires supports_v3 in arch specs") + + if pipeline == "qr_async_trload" and not arch_info.get("supports_trload", False): + result.add_error("qr_async_trload requires a trload-capable architecture") + + # Global rules + hdim_q = sig["hdim_q"] + hdim_v = sig["hdim_v"] + divisor = global_rules.get("hdim_divisible_by", 8) + if hdim_q % divisor != 0 or hdim_v % divisor != 0: + result.add_error(f"Head dimensions must be multiples of {divisor}") + + if global_rules.get("hdim_192_128_no_bias_dropout"): + if ( + hdim_q == 192 + and hdim_v == 128 + and (bias != "no" or sig.get("dropout", False)) + ): + result.add_warning( + "hdim (192,128) with bias/dropout has limited tile support" + ) + + if global_rules.get("logits_requires_no_bias"): + if bias != "no" and sig.get("logits", False): + result.add_error("logits_soft_cap cannot be combined with bias") + + if pipeline in {"qr_async_trload", "v3", "qr_async_trload_v3"} and ( + hdim_q != hdim_v or hdim_q not in {64, 128} + ): + result.add_error(f"{pipeline} only supports symmetric head dims 64 or 128") + + # Tile validation + tile = alg["tile"] + expected_tile_len = 9 if family == "bwd_dq_dk_dv" else 6 + if len(tile) != expected_tile_len or len(alg["wave"]) != 9 or len(alg["warp"]) != 9: + result.add_error( + f"tile/wave/warp must have {expected_tile_len}/9/9 elements for {family}" + ) + + # MFMA instruction count check for qr/h256/CDNA + _1d_families = {"bwd_dot_do_o", "bwd_convert_dq"} + if ( + pipeline == "qr" + and hdim_q == 256 + and family not in _1d_families + and arch_info.get("family", "").startswith("cdna") + and len(tile) >= 3 + and len(alg["wave"]) >= 2 + and len(alg["warp"]) >= 3 + ): + wm, wn, wk = alg["warp"][0], alg["warp"][1], alg["warp"][2] + gm, gn = alg["wave"][0], alg["wave"][1] + if wm > 0 and wn > 0 and wk > 0 and gm > 0 and gn > 0: + num_mfma = (tile[0] // wm) * (tile[1] // wn) * (tile[2] // wk) // (gm * gn) + if num_mfma % 8 != 0: + result.add_error( + f"NumMfmaInsts={num_mfma} must be divisible by 8 for qr/h256/CDNA" + ) + + if alg["block_per_cu"] <= 0 and alg["block_per_cu"] != -1: + result.add_error("block_per_cu must be positive or -1 (auto)") + if alg["num_wave_groups"] <= 0: + result.add_error("num_wave_groups must be positive") + + # --- Family-specific rules --- + if family == "batch_prefill": + if sig.get("vlayout", "r") != "r": + result.add_error("batch_prefill only supports row-major V layout") + if not sig.get("paged_kv", False): + result.add_error("batch_prefill requires paged_kv=true") + ps = sig.get("page_size", 0) + if ps <= 0 or (ps & (ps - 1)) != 0: + result.add_error("batch_prefill page_size must be a positive power of two") + if sig.get("mode", "batch") != "group": + result.add_error("batch_prefill requires group mode") + if pipeline != "qr_async": + result.add_error("batch_prefill currently uses qr_async pipeline") + + if family == "fwd_appendkv": + if sig.get("mode", "batch") != "batch": + result.add_error("fwd_appendkv uses batch-mode public API surface") + if pipeline != "appendkv": + result.add_error("fwd_appendkv must use appendkv pipeline") + if sig.get("vlayout", "r") != "r": + result.add_error("fwd_appendkv currently only supports row-major V") + + if family == "fwd_splitkv_combine": + if sig.get("mode", "batch") not in {"batch", "group"}: + result.add_error("fwd_splitkv_combine requires batch or group mode") + combine_bn1 = arch_specs.get("splitkv_combine", {}).get("combine_bn1", 32) + if len(tile) > 3 and tile[3] != combine_bn1: + result.add_error(f"fwd_splitkv_combine requires bn1={combine_bn1}") + if len(tile) > 3 and (hdim_v < tile[3] or hdim_v % tile[3] != 0): + result.add_error("fwd_splitkv_combine requires hdim_v divisible by bn1") + + if family == "fwd_pagedkv": + if pipeline != "qr_pagedkv": + result.add_error("fwd_pagedkv must use qr_pagedkv pipeline") + if not sig.get("paged_kv", False): + result.add_error("fwd_pagedkv requires paged_kv=true") + if sig.get("vlayout", "r") != "r": + result.add_error("fwd_pagedkv currently only supports row-major V") + + if family == "fwd_splitkv": + if pipeline not in {"qr", "qr_nwarp_sshuffle"}: + result.add_error("fwd_splitkv must use qr or qr_nwarp_sshuffle pipeline") + if sig.get("vlayout", "r") != "r": + result.add_error("fwd_splitkv currently only supports row-major V") + + if family == "fwd" and sig.get("vlayout", "r") != "r": + result.add_warning("dispatcher forward examples currently assume row-major V") + + return result diff --git a/dispatcher/examples/CMakeLists.txt b/dispatcher/examples/CMakeLists.txt index ab094e90cf..f95b5e627b 100644 --- a/dispatcher/examples/CMakeLists.txt +++ b/dispatcher/examples/CMakeLists.txt @@ -290,7 +290,7 @@ function(add_declarative_gpu_example NAME SOURCE) COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/example_kernel_builder.py ${EXAMPLE_SOURCE} --output-dir ${EXAMPLE_KERNEL_DIR} - --include-dirs "${CMAKE_CURRENT_SOURCE_DIR}/../../include,${CMAKE_CURRENT_SOURCE_DIR}/../include" + --include-dirs "${CMAKE_CURRENT_SOURCE_DIR}/../../include,${CMAKE_CURRENT_SOURCE_DIR}/../include,${CMAKE_CURRENT_SOURCE_DIR}/../.." --gpu-target ${GPU_TARGET} --jobs ${NPROC} --target-name ${NAME} @@ -456,7 +456,47 @@ add_declarative_gpu_example(grouped_conv_06_bwd_weight grouped_conv/cpp/06_bw add_declarative_gpu_example(grouped_conv_07_benchmark grouped_conv/cpp/07_multi_tile_benchmark.cpp) # ============================================================================= -# Grouped Convolution Python Library - Multi-Kernel (fwd/bwd_data/bwd_weight x 2D/3D) +# FMHA C++ Examples +# ============================================================================= + +add_declarative_gpu_example(fmha_01_basic fmha/cpp/01_basic_fmha.cpp) +add_declarative_gpu_example(fmha_02_splitkv fmha/cpp/02_splitkv_fmha.cpp) +add_declarative_gpu_example(fmha_03_kvcache fmha/cpp/03_kvcache_fmha.cpp) +add_declarative_gpu_example(fmha_04_bwd fmha/cpp/04_bwd_fmha.cpp) +add_declarative_gpu_example(fmha_05_appendkv fmha/cpp/05_appendkv_fmha.cpp) +add_declarative_gpu_example(fmha_06_batch_prefill fmha/cpp/06_batch_prefill_fmha.cpp) +add_declarative_gpu_example(fmha_07_profile_pytorch fmha/cpp/07_profile_pytorch_fmha.cpp) +add_declarative_gpu_example(fmha_08_profile_flash fmha/cpp/08_profile_flash_fmha.cpp) +add_declarative_gpu_example(fmha_09_profile_aiter fmha/cpp/09_profile_aiter_fmha.cpp) +add_declarative_gpu_example(fmha_10_profile_fp32_fp8 fmha/cpp/10_profile_fp32_fp8_fmha.cpp) +add_declarative_gpu_example(fmha_11_receipt_aliases fmha/cpp/11_receipt_aliases_fmha.cpp) +add_declarative_gpu_example(fmha_12_registry_json fmha/cpp/12_registry_json_fmha.cpp) +add_declarative_gpu_example(fmha_13_feature_coverage fmha/cpp/13_feature_coverage_fmha.cpp) +add_declarative_gpu_example(fmha_14_benchmark_validation fmha/cpp/14_benchmark_validation_fmha.cpp) +add_declarative_gpu_example(fmha_15_multi_shape fmha/cpp/15_multi_shape_fmha.cpp) +add_declarative_gpu_example(fmha_16_heuristics fmha/cpp/16_heuristics_fmha.cpp) +add_declarative_gpu_example(fmha_17_autofill_autocorrect fmha/cpp/17_autofill_autocorrect_fmha.cpp) +add_declarative_gpu_example(fmha_18_gpu_splitkv fmha/cpp/18_gpu_splitkv_fmha.cpp) +add_declarative_gpu_example(fmha_19_gpu_masks fmha/cpp/19_gpu_masks_fmha.cpp) +add_declarative_gpu_example(fmha_20_gpu_bias fmha/cpp/20_gpu_bias_fmha.cpp) +add_declarative_gpu_example(fmha_21_gpu_features fmha/cpp/21_gpu_features_fmha.cpp) +add_declarative_gpu_example(fmha_22_gpu_bwd fmha/cpp/22_gpu_bwd_fmha.cpp) +add_declarative_gpu_example(fmha_23_multi_registry fmha/cpp/23_multi_registry_fmha.cpp) +add_declarative_gpu_example(fmha_24_per_receipt_registries fmha/cpp/24_per_receipt_registries_fmha.cpp) +add_declarative_gpu_example(fmha_25_gpu_appendkv_prefill fmha/cpp/25_gpu_appendkv_batchprefill_fmha.cpp) +add_declarative_gpu_example(fmha_26_dtypes_hdims fmha/cpp/26_dtypes_hdims_fmha.cpp) +add_declarative_gpu_example(fmha_27_padding_permutation fmha/cpp/27_padding_permutation_fmha.cpp) +add_declarative_gpu_example(fmha_28_bwd_masks fmha/cpp/28_bwd_masks_fmha.cpp) +add_declarative_gpu_example(fmha_29_bwd_bias_dropout fmha/cpp/29_bwd_bias_dropout_fmha.cpp) +add_declarative_gpu_example(fmha_30_bwd_benchmark fmha/cpp/30_bwd_benchmark_fmha.cpp) +add_declarative_gpu_example(fmha_31_logits_soft_cap fmha/cpp/31_logits_soft_cap_fmha.cpp) +add_declarative_gpu_example(fmha_32_sink_tokens fmha/cpp/32_sink_tokens_fmha.cpp) +add_declarative_gpu_example(fmha_33_bwd_deterministic fmha/cpp/33_bwd_deterministic_fmha.cpp) +add_declarative_gpu_example(fmha_34_bwd_gqa fmha/cpp/34_bwd_gqa_fmha.cpp) +add_declarative_gpu_example(fmha_35_generic_mask fmha/cpp/35_generic_mask_fmha.cpp) + +# ============================================================================= +# Grouped Convolution Python Library - Multi-Kernel (fwd/bwdd/bwdw x 2D/3D) # ============================================================================= # Kernel output directory for the Python conv library @@ -502,13 +542,67 @@ if(hip_FOUND) endif() add_dependencies(dispatcher_conv_lib generate_conv_fallback_kernels) +# ============================================================================= +# FMHA Python Library - Single Fallback Kernel +# ============================================================================= + +set(FMHA_FALLBACK_KERNEL_DIR "${CMAKE_CURRENT_BINARY_DIR}/fmha_python_fallback") +set(FMHA_DISPATCH_HEADER "${FMHA_FALLBACK_KERNEL_DIR}/fmha_python_dispatch.hpp") +set(FMHA_FALLBACK_LIB "${FMHA_FALLBACK_KERNEL_DIR}/libfmha_python_fallback.a") +set(FMHA_FALLBACK_SENTINEL "${FMHA_FALLBACK_KERNEL_DIR}/.generated") + +# Generate the FMHA fallback kernel, compile it, and produce both +# the dispatch header and a static library with the kernel object. +# Uses example_kernel_builder.py with a synthetic source that declares +# a single FMHA kernel set, just like the C++ examples do. +set(FMHA_FALLBACK_SOURCE "${FMHA_FALLBACK_KERNEL_DIR}/fmha_python_fallback.cpp") +add_custom_command( + OUTPUT ${FMHA_DISPATCH_HEADER} ${FMHA_FALLBACK_LIB} ${FMHA_FALLBACK_SENTINEL} + COMMAND ${CMAKE_COMMAND} -E make_directory ${FMHA_FALLBACK_KERNEL_DIR} + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/fmha/generate_fallback.py + --output-dir ${FMHA_FALLBACK_KERNEL_DIR} + --gpu-target ${GPU_TARGET} + --compile + --include-dirs "${CMAKE_CURRENT_SOURCE_DIR}/../../include:${CMAKE_CURRENT_SOURCE_DIR}/../include:${CMAKE_CURRENT_SOURCE_DIR}/../.." + COMMAND ${CMAKE_COMMAND} -E touch ${FMHA_FALLBACK_SENTINEL} + COMMENT "Generating and compiling FMHA fallback kernel for Python library..." + VERBATIM +) + +add_custom_target(generate_fmha_fallback_kernels + DEPENDS ${FMHA_DISPATCH_HEADER} ${FMHA_FALLBACK_LIB}) + +# FMHA dynamic library for Python +add_library(dispatcher_fmha_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/../bindings/ctypes/fmha_ctypes_lib.cpp) +target_link_libraries(dispatcher_fmha_lib PRIVATE ck_tile_dispatcher ${FMHA_FALLBACK_LIB}) +target_include_directories(dispatcher_fmha_lib PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../include + ${CMAKE_CURRENT_SOURCE_DIR}/../include + ${CMAKE_CURRENT_SOURCE_DIR}/../.. + ${FMHA_FALLBACK_KERNEL_DIR} + ${FMHA_FALLBACK_KERNEL_DIR}/dispatcher_wrappers +) +target_compile_options(dispatcher_fmha_lib PRIVATE + -include ${FMHA_DISPATCH_HEADER} + -DGFX_ARCH="${GPU_TARGET}" + -mllvm -enable-noalias-to-md-conversion=0 + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress +) +if(hip_FOUND) + target_link_libraries(dispatcher_fmha_lib PRIVATE hip::device hip::host) +endif() +add_dependencies(dispatcher_fmha_lib generate_fmha_fallback_kernels) + message(STATUS "GEMM examples configured - kernels will be generated during 'make'") message(STATUS "Grouped Conv examples configured - kernels will be generated during 'make'") +message(STATUS "FMHA examples configured - kernels will be generated during 'make'") # Convenience target to build all Python ctypes libraries add_custom_target(python_libs - DEPENDS dispatcher_gemm_lib dispatcher_conv_lib - COMMENT "Building Python ctypes libraries (GEMM + Conv)" + DEPENDS dispatcher_gemm_lib dispatcher_conv_lib dispatcher_fmha_lib + COMMENT "Building Python ctypes libraries (GEMM + Conv + FMHA)" ) # ============================================================================= diff --git a/dispatcher/examples/README.md b/dispatcher/examples/README.md index 24bea821ba..a5a8253558 100644 --- a/dispatcher/examples/README.md +++ b/dispatcher/examples/README.md @@ -59,9 +59,17 @@ python3 examples/gemm/python/08_heuristics.py ``` examples/ |---- gemm/ -| |---- cpp/ # 6 C++ GEMM examples +| |---- cpp/ # 7 C++ GEMM examples | +---- python/ # 11 Python GEMM examples | +|---- grouped_conv/ +| |---- cpp/ # 7 C++ Grouped Conv examples +| +---- python/ # 6 Python Grouped Conv examples +| +|---- fmha/ +| |---- cpp/ # 35 C++ FMHA examples (all variants) +| +---- python/ # 38 Python FMHA examples (JIT-compiled) +| +---- README.md ``` diff --git a/dispatcher/examples/fmha/cpp/01_basic_fmha.cpp b/dispatcher/examples/fmha/cpp/01_basic_fmha.cpp new file mode 100644 index 0000000000..0045da3a0a --- /dev/null +++ b/dispatcher/examples/fmha/cpp/01_basic_fmha.cpp @@ -0,0 +1,371 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Example 01: Basic FMHA Forward with GPU Execution +// +// Demonstrates the full flow: +// 1. Declare kernels via DECL_FMHA_KERNEL_SET +// 2. Register and plan +// 3. Allocate Q, K, V, O GPU buffers +// 4. Run the FMHA forward kernel on GPU +// 5. Copy output to host and validate against CPU reference +// +// Mirrors 01_basic_gemm.cpp for FMHA. + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +// FMHA tile/wave/warp dimensions correspond to TWO GEMM stages: +// Stage 0 (Q * K^T): tile_m0 x tile_n0 x tile_k0 (seqlen_q x seqlen_k x hdim_q) +// Stage 1 (Attn * V): tile_m0 x tile_n1 x tile_k1 (seqlen_q x hdim_v x seqlen_k) +// Wave/warp follow the same stage pattern: *_m0/n0/k0 for stage 0, *_m1/n1/k1 for stage 1. +DECL_FMHA_KERNEL_SET(basic_fmha_kernels, + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") // V row-major + .hdim(128) // hdim_q = hdim_v = 128 + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + // Stage 0 tile: seqlen_q=128, seqlen_k=128, hdim_q=32 + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + // Stage 1 tile: hdim_v=128, seqlen_k=32, alignment=128 + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + // Wave: 4 warps on m, 1 on n, 1 on k (both stages) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + // Warp tile: 32x32x16 (both stages) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) // pad_s, pad_sk, pad_d, pad_dv + .alignments(128, 128) // hdim_q_alignment, hdim_v_alignment + .selection_rank(0), + "gfx950")); + +namespace { + +using FmhaDataType = ck_tile::fp16_t; + +void cpu_attention_fwd(const std::vector& Q, + const std::vector& K, + const std::vector& V, + std::vector& O, + int batch, + int nhead, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + float scale) +{ + for(int b = 0; b < batch; ++b) + { + for(int h = 0; h < nhead; ++h) + { + for(int sq = 0; sq < seqlen_q; ++sq) + { + std::vector scores(seqlen_k, 0.0f); + float max_score = -1e30f; + + for(int sk = 0; sk < seqlen_k; ++sk) + { + float dot = 0.0f; + for(int d = 0; d < hdim_q; ++d) + { + int q_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_q + d; + int k_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_q + d; + dot += Q[q_idx] * K[k_idx]; + } + scores[sk] = dot * scale; + max_score = std::max(max_score, scores[sk]); + } + + float sum_exp = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] = std::exp(scores[sk] - max_score); + sum_exp += scores[sk]; + } + + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] /= sum_exp; + } + + for(int dv = 0; dv < hdim_v; ++dv) + { + float acc = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + int v_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_v + dv; + acc += scores[sk] * V[v_idx]; + } + int o_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_v + dv; + O[o_idx] = acc; + } + } + } + } +} + +} // namespace + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 01: FMHA Forward (GPU Execution)", "FMHA with real GPU data"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--batch", "2", "Batch size"); + args.add_option("--nhead", "4", "Number of heads"); + args.add_option("--seqlen", "64", "Sequence length (Q and K)"); + args.add_option("--hdim", "128", "Head dimension"); + args.add_flag("--validate", "Validate against CPU reference"); + if(!args.parse(argc, argv)) + return 0; + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int batch = args.get_int("--batch", 2); + const int nhead = args.get_int("--nhead", 4); + const int seqlen = args.get_int("--seqlen", 64); + const int hdim = args.get_int("--hdim", 128); + + print_header("Example 01: FMHA Forward (GPU Execution)"); + + // Step 1: Register kernels + std::cout << "\nStep 1: Register Kernels\n"; + FmhaKernelSetRegistry::instance().print(); + + FmhaRegistry registry; + registry.set_name("basic_fmha"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + FmhaDispatcher dispatcher(®istry); + dispatcher.set_benchmarking(true); + dispatcher.set_timing(1, 3); + + // Step 2: Plan + const float scale = 1.0f / std::sqrt(static_cast(hdim)); + + fmha_fwd_traits traits{}; + traits.hdim_q = hdim; + traits.hdim_v = hdim; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.has_logits_soft_cap = false; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + traits.has_lse = false; + traits.has_dropout = false; + traits.qscale_type = quant_scale_enum::no_scale; + + const int64_t q_elems = static_cast(batch) * nhead * seqlen * hdim; + const int64_t k_elems = q_elems; + const int64_t v_elems = q_elems; + const int64_t o_elems = q_elems; + + // Step 3: Allocate GPU buffers + std::cout << "\nStep 2: Allocate GPU Buffers\n"; + std::cout << " Q/K/V/O: [" << batch << ", " << nhead << ", " << seqlen << ", " << hdim + << "]\n"; + + GpuBuffer q_dev(q_elems); + GpuBuffer k_dev(k_elems); + GpuBuffer v_dev(v_elems); + GpuBuffer o_dev(o_elems); + + // Fill Q, K, V with random data + std::mt19937 rng(42); + std::uniform_real_distribution dist(-0.5f, 0.5f); + + std::vector q_host(q_elems); + std::vector k_host(k_elems); + std::vector v_host(v_elems); + for(auto& x : q_host) + x = FmhaDataType(dist(rng)); + for(auto& x : k_host) + x = FmhaDataType(dist(rng)); + for(auto& x : v_host) + x = FmhaDataType(dist(rng)); + + q_dev.copy_from_host(q_host.data()); + k_dev.copy_from_host(k_host.data()); + v_dev.copy_from_host(v_host.data()); + o_dev.zero(); + + // Step 4: Set up args with device pointers and strides + fmha_fwd_args fmha_args{}; + fmha_args.q_ptr = q_dev.get(); + fmha_args.k_ptr = k_dev.get(); + fmha_args.v_ptr = v_dev.get(); + fmha_args.o_ptr = o_dev.get(); + + fmha_args.bias_ptr = nullptr; + fmha_args.q_descale_ptr = nullptr; + fmha_args.k_descale_ptr = nullptr; + fmha_args.v_descale_ptr = nullptr; + fmha_args.rand_val_ptr = nullptr; + fmha_args.lse_ptr = nullptr; + fmha_args.sink_ptr = nullptr; + fmha_args.block_scale_seqstart_q_ptr = nullptr; + fmha_args.block_scale_seqstart_k_ptr = nullptr; + + fmha_args.seqlen_q = seqlen; + fmha_args.seqlen_k = seqlen; + fmha_args.batch = batch; + fmha_args.max_seqlen_q = seqlen; + fmha_args.hdim_q = hdim; + fmha_args.hdim_v = hdim; + fmha_args.nhead_q = nhead; + fmha_args.nhead_k = nhead; + fmha_args.scale_s = scale; + fmha_args.logits_soft_cap = 0.0f; + + // bhsd layout strides + fmha_args.stride_q = hdim; + fmha_args.stride_k = hdim; + fmha_args.stride_v = hdim; + fmha_args.stride_bias = 0; + fmha_args.stride_randval = 0; + fmha_args.stride_o = hdim; + + fmha_args.nhead_stride_q = seqlen * hdim; + fmha_args.nhead_stride_k = seqlen * hdim; + fmha_args.nhead_stride_v = seqlen * hdim; + fmha_args.nhead_stride_bias = 0; + fmha_args.nhead_stride_randval = 0; + fmha_args.nhead_stride_lse = 0; + fmha_args.nhead_stride_o = seqlen * hdim; + fmha_args.nhead_stride_q_descale = 0; + fmha_args.nhead_stride_k_descale = 0; + fmha_args.nhead_stride_v_descale = 0; + + fmha_args.batch_stride_q = nhead * seqlen * hdim; + fmha_args.batch_stride_k = nhead * seqlen * hdim; + fmha_args.batch_stride_v = nhead * seqlen * hdim; + fmha_args.batch_stride_bias = 0; + fmha_args.batch_stride_randval = 0; + fmha_args.batch_stride_lse = 0; + fmha_args.batch_stride_o = nhead * seqlen * hdim; + fmha_args.batch_stride_q_descale = 0; + fmha_args.batch_stride_k_descale = 0; + fmha_args.batch_stride_v_descale = 0; + + fmha_args.window_size_left = -1; + fmha_args.window_size_right = -1; + fmha_args.sink_size = 0; + fmha_args.mask_type = 0; + fmha_args.min_seqlen_q = 0; + fmha_args.p_drop = 0.0f; + fmha_args.s_randval = false; + fmha_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0)); + fmha_args.block_scale_size_q = 0; + fmha_args.block_scale_size_kv = 0; + + // Step 5: Run on GPU + std::cout << "\nStep 3: Run FMHA Forward on GPU\n"; + float time_ms = 0.0f; + try + { + time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr); + } + catch(const std::exception& e) + { + std::cerr << " ERROR: " << e.what() << "\n"; + return 1; + } + + auto problem = FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch); + double tflops = static_cast(problem.num_ops()) / (time_ms * 1e-3) / 1e12; + + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + + // Step 6: Copy output and validate + std::cout << "\nStep 4: Validate\n"; + std::vector o_host(o_elems); + o_dev.copy_to_host(o_host.data()); + + // Quick sanity check: output should be non-zero + int nonzero = 0; + for(int64_t i = 0; i < o_elems; ++i) + { + if(static_cast(o_host[i]) != 0.0f) + ++nonzero; + } + std::cout << " Non-zero outputs: " << nonzero << " / " << o_elems << "\n"; + + bool passed = (nonzero > 0); + + if(args.has("--validate")) + { + // CPU reference + std::vector q_f32(q_elems), k_f32(k_elems), v_f32(v_elems), o_ref(o_elems, 0.0f); + for(int64_t i = 0; i < q_elems; ++i) + q_f32[i] = static_cast(q_host[i]); + for(int64_t i = 0; i < k_elems; ++i) + k_f32[i] = static_cast(k_host[i]); + for(int64_t i = 0; i < v_elems; ++i) + v_f32[i] = static_cast(v_host[i]); + + cpu_attention_fwd( + q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale); + + double max_abs_err = 0.0; + double max_rel_err = 0.0; + int errors = 0; + const double rtol = 1e-2; + const double atol = 1e-2; + + for(int64_t i = 0; i < o_elems; ++i) + { + float gpu_val = static_cast(o_host[i]); + float ref_val = o_ref[i]; + double abs_err = std::abs(gpu_val - ref_val); + double rel_err = abs_err / (std::abs(ref_val) + 1e-6); + max_abs_err = std::max(max_abs_err, abs_err); + max_rel_err = std::max(max_rel_err, rel_err); + if(abs_err > atol + rtol * std::abs(ref_val)) + ++errors; + } + + std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n"; + std::cout << " Max rel error: " << max_rel_err << "\n"; + std::cout << " Errors: " << errors << " / " << o_elems << "\n"; + passed = (errors == 0); + } + + print_separator(); + std::cout << "Status: " << (passed ? "PASS" : "FAIL") << "\n"; + print_separator(); + + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/fmha/cpp/02_splitkv_fmha.cpp b/dispatcher/examples/fmha/cpp/02_splitkv_fmha.cpp new file mode 100644 index 0000000000..d9dc852b6e --- /dev/null +++ b/dispatcher/examples/fmha/cpp/02_splitkv_fmha.cpp @@ -0,0 +1,162 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; + +DECL_FMHA_KERNEL_SET(splitkv_fmha_kernels, + .add(FmhaSignature() + .family("fwd_splitkv") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(true) + .dropout(false) + .qscale("no") + .paged_kv(false), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(64) + .tile_n0(128) + .tile_k0(32) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(16) + .warp_n0(16) + .warp_k0(16) + .warp_m1(16) + .warp_n1(16) + .warp_k1(16) + .pipeline("qr") + .padding(true, true, true, true) + .max_splits_log2(6) + .selection_rank(0), + "gfx950") + .add(FmhaSignature() + .family("fwd_splitkv_combine") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(true) + .dropout(false) + .qscale("no") + .paged_kv(false), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(32) + .tile_n1(32) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(16) + .warp_n0(16) + .warp_k0(16) + .warp_m1(16) + .warp_n1(16) + .warp_k1(16) + .pipeline("qr") + .padding(true, true, true, true) + .max_splits_log2(6) + .selection_rank(0), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 02: FMHA Split-KV", "Declarative FMHA split-KV planning"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--batch", "1", "Batch size"); + args.add_option("--nhead", "16", "Number of heads"); + args.add_option("--seqlen", "128", "Query sequence length"); + args.add_option("--hdim", "128", "Head dimension"); + if(!args.parse(argc, argv)) + { + return 0; + } + + utils::print_header("Example 02: FMHA Split-KV"); + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int batch = args.get_int("--batch", 1); + const int nhead = args.get_int("--nhead", 16); + const int seqlen = args.get_int("--seqlen", 128); + const int hdim = args.get_int("--hdim", 128); + + // Step 1: Register kernels + std::cout << "\nStep 1: Register Kernels\n"; + FmhaRegistry registry; + registry.set_name("splitkv_fmha"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + FmhaDispatcher dispatcher(®istry); + + // Step 2: Plan + std::cout << "\nStep 2: Plan\n"; + + fmha_fwd_splitkv_traits traits{}; + traits.hdim_q = hdim; + traits.hdim_v = hdim; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.has_logits_soft_cap = false; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + traits.has_lse = true; + traits.do_fp8_static_quant = false; + traits.has_sink = false; + + fmha_fwd_splitkv_args fmha_args{}; + fmha_args.seqlen_q = seqlen; + fmha_args.seqlen_k = 2048; + fmha_args.batch = batch; + fmha_args.max_seqlen_q = seqlen; + fmha_args.hdim_q = hdim; + fmha_args.hdim_v = hdim; + fmha_args.nhead_q = nhead; + fmha_args.nhead_k = nhead; + fmha_args.num_splits = 8; + + auto problem = FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch); + auto plan = dispatcher.plan(problem); + + if(!plan.is_valid() || plan.stages.size() != 2) + { + std::cerr << "Expected a two-stage split-KV plan\n"; + return 1; + } + + // Step 3: Results + std::cout << "\nStep 3: Results\n"; + for(const auto& stage : plan.stages) + { + std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n"; + } + + utils::print_separator(); + return 0; +} diff --git a/dispatcher/examples/fmha/cpp/03_kvcache_fmha.cpp b/dispatcher/examples/fmha/cpp/03_kvcache_fmha.cpp new file mode 100644 index 0000000000..c3632a7d2f --- /dev/null +++ b/dispatcher/examples/fmha/cpp/03_kvcache_fmha.cpp @@ -0,0 +1,240 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; + +DECL_FMHA_KERNEL_SET(kvcache_fmha_kernels, + .add(FmhaSignature() + .family("fwd_pagedkv") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no") + .paged_kv(true) + .kv_cache("vectorized", "sglang", 16), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_pagedkv") + .padding(true, true, true, true) + .selection_rank(0), + "gfx950") + .add(FmhaSignature() + .family("fwd_appendkv") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .rope("inter") + .paged_kv(true) + .kv_cache("vectorized", "sglang", 16), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(64) + .tile_k0(128) + .tile_n1(128) + .tile_k1(0) + .tile_k0max(0) + .pipeline("appendkv") + .padding(true, true, true, true) + .selection_rank(0), + "gfx950") + .add(FmhaSignature() + .family("batch_prefill") + .dtype("fp16") + .mode("group") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(true) + .dropout(false) + .qscale("no") + .paged_kv(true) + .kv_cache("vectorized", "sglang", 16), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .selection_rank(0), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 03: FMHA KV-Cache", "Declarative FMHA KV-cache planning"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--batch", "1", "Batch size"); + args.add_option("--nhead", "16", "Number of heads"); + args.add_option("--seqlen", "128", "Prefill query sequence length"); + args.add_option("--hdim", "128", "Head dimension"); + if(!args.parse(argc, argv)) + { + return 0; + } + + utils::print_header("Example 03: FMHA KV-Cache"); + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int batch = args.get_int("--batch", 1); + const int nhead = args.get_int("--nhead", 16); + const int seqlen = args.get_int("--seqlen", 128); + const int hdim = args.get_int("--hdim", 128); + + // Step 1: Register kernels + std::cout << "\nStep 1: Register Kernels\n"; + FmhaRegistry registry; + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + FmhaDispatcher dispatcher(®istry); + + // Step 2: Plan PagedKV (decode) + std::cout << "\nStep 2: Plan PagedKV (decode)\n"; + + fmha_fwd_pagedkv_traits paged_traits{}; + paged_traits.hdim_q = hdim; + paged_traits.hdim_v = hdim; + paged_traits.data_type = "fp16"; + paged_traits.is_group_mode = false; + paged_traits.is_v_rowmajor = true; + paged_traits.mask_type = mask_enum::no_mask; + paged_traits.bias_type = bias_enum::no_bias; + paged_traits.use_pagedkv = true; + + fmha_fwd_pagedkv_args paged_args{}; + paged_args.seqlen_q = 1; + paged_args.seqlen_k = 1024; + paged_args.batch = batch; + paged_args.max_seqlen_q = 1; + paged_args.hdim_q = hdim; + paged_args.hdim_v = hdim; + paged_args.nhead_q = nhead; + paged_args.nhead_k = nhead; + paged_args.block_table_ptr = reinterpret_cast(0x1); + paged_args.page_block_size = 16; + + auto paged_plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(paged_traits, paged_args), gfx_arch)); + + // Step 3: Plan AppendKV + std::cout << "\nStep 3: Plan AppendKV\n"; + + fmha_fwd_appendkv_traits append_traits{}; + append_traits.hdim_q = hdim; + append_traits.hdim_v = hdim; + append_traits.data_type = "fp16"; + append_traits.is_v_rowmajor = true; + append_traits.rope_type = rope_enum::interleaved; + + fmha_fwd_appendkv_args append_args{}; + append_args.seqlen_q = 1; + append_args.seqlen_knew = 1; + append_args.batch = batch; + append_args.hdim_q = hdim; + append_args.hdim_v = hdim; + append_args.nhead_q = nhead; + append_args.nhead_k = nhead; + append_args.rotary_dim = hdim; + append_args.rotary_cos_ptr = reinterpret_cast(0x1); + append_args.rotary_sin_ptr = reinterpret_cast(0x1); + append_args.block_table_ptr = reinterpret_cast(0x1); + append_args.page_block_size = 16; + + auto append_plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(append_traits, append_args), gfx_arch)); + + // Step 4: Plan BatchPrefill + std::cout << "\nStep 4: Plan BatchPrefill\n"; + + fmha_batch_prefill_traits prefill_traits{}; + prefill_traits.hdim_q = hdim; + prefill_traits.hdim_v = hdim; + prefill_traits.data_type = "fp16"; + prefill_traits.is_group_mode = true; + prefill_traits.is_v_rowmajor = true; + prefill_traits.mask_type = mask_enum::no_mask; + prefill_traits.bias_type = bias_enum::no_bias; + prefill_traits.has_lse = true; + prefill_traits.kv_memory_layout = + ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT; + prefill_traits.kv_lookup_table = + ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D; + prefill_traits.page_size = 16; + + fmha_batch_prefill_args prefill_args{}; + prefill_args.batch = batch; + prefill_args.seqlen_q = seqlen; + prefill_args.seqlen_k = 1024; + prefill_args.max_seqlen_q = seqlen; + prefill_args.hdim_q = hdim; + prefill_args.hdim_v = hdim; + prefill_args.nhead_q = nhead; + prefill_args.nhead_k = nhead; + prefill_args.num_total_pages = 64; + prefill_args.page_block_size = 16; + prefill_args.kv_memory_layout = + ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT; + prefill_args.kv_lookup_table = + ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D; + prefill_args.kv_indptr = reinterpret_cast(0x1); + prefill_args.kv_page_indices = reinterpret_cast(0x1); + prefill_args.kv_last_page_lens = reinterpret_cast(0x1); + prefill_args.seqstart_q_ptr = reinterpret_cast(0x1); + + auto prefill_plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(prefill_traits, prefill_args), gfx_arch)); + + // Step 5: Results + std::cout << "\nStep 5: Results\n"; + std::cout << " PagedKV stages: " << paged_plan.stages.size() << "\n"; + std::cout << " AppendKV stages: " << append_plan.stages.size() << "\n"; + std::cout << " BatchPrefill stages: " << prefill_plan.stages.size() << "\n"; + + utils::print_separator(); + return (paged_plan.is_valid() && append_plan.is_valid() && prefill_plan.is_valid()) ? 0 : 1; +} diff --git a/dispatcher/examples/fmha/cpp/04_bwd_fmha.cpp b/dispatcher/examples/fmha/cpp/04_bwd_fmha.cpp new file mode 100644 index 0000000000..05d08f4a0d --- /dev/null +++ b/dispatcher/examples/fmha/cpp/04_bwd_fmha.cpp @@ -0,0 +1,154 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; + +DECL_FMHA_KERNEL_SET(bwd_fmha_kernels, + .add(FmhaSignature() + .family("bwd_dot_do_o") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("no") + .bias("no") + .dropout(false) + .dbias(false) + .store_randval(false) + .deterministic(false), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(64) + .tile_n0(128) + .tile_k0(32) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(0) + .tile_k1(0) + .tile_k0max(0) + .padding(true, true, true, true) + .selection_rank(0), + "gfx950") + .add(FmhaSignature() + .family("bwd_dq_dk_dv") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("no") + .bias("no") + .dropout(false) + .dbias(false) + .store_randval(false) + .deterministic(false), + FmhaAlgorithm() + .tile_m0(16) + .tile_n0(128) + .tile_k0(128) + .tile_n1(16) + .tile_k1(128) + .tile_k0max(32) + .wave(1, 4, 1, 4, 1, 1, 1, 4, 1) + .warp(16, 16, 32, 16, 16, 16, 16, 16, 16) + .padding(true, true, true, true) + .max_seq_len_q(0) + .selection_rank(0), + "gfx950") + .add(FmhaSignature() + .family("bwd_convert_dq") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("no") + .bias("no") + .dropout(false) + .dbias(false) + .store_randval(false) + .deterministic(false), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(0) + .tile_n1(0) + .tile_k1(0) + .tile_k0max(0) + .padding(true, true, true, true) + .selection_rank(0), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 04: FMHA Backward", "Declarative FMHA backward planning"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--batch", "1", "Batch size"); + args.add_option("--nhead", "16", "Number of heads"); + args.add_option("--seqlen", "128", "Sequence length (Q and K)"); + args.add_option("--hdim", "128", "Head dimension"); + if(!args.parse(argc, argv)) + { + return 0; + } + + utils::print_header("Example 04: FMHA Backward"); + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int batch = args.get_int("--batch", 1); + const int nhead = args.get_int("--nhead", 16); + const int seqlen = args.get_int("--seqlen", 128); + const int hdim = args.get_int("--hdim", 128); + + // Step 1: Register kernels + std::cout << "\nStep 1: Register Kernels\n"; + FmhaRegistry registry; + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + FmhaDispatcher dispatcher(®istry); + + // Step 2: Plan + std::cout << "\nStep 2: Plan\n"; + + fmha_bwd_traits traits{}; + traits.hdim_q = hdim; + traits.hdim_v = hdim; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + traits.has_dbias = false; + traits.has_dropout = false; + traits.is_store_randval = false; + traits.is_deterministic = false; + + fmha_bwd_args bwd_args{}; + bwd_args.batch = batch; + bwd_args.seqlen_q = seqlen; + bwd_args.seqlen_k = seqlen; + bwd_args.max_seqlen_q = seqlen; + bwd_args.max_seqlen_k = seqlen; + bwd_args.hdim_q = hdim; + bwd_args.hdim_v = hdim; + bwd_args.nhead_q = nhead; + bwd_args.nhead_k = nhead; + + auto plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(traits, bwd_args), gfx_arch)); + + if(!plan.is_valid() || plan.stages.size() < 2) + { + std::cerr << "Expected a multi-stage backward plan\n"; + return 1; + } + + // Step 3: Results + std::cout << "\nStep 3: Results\n"; + for(const auto& stage : plan.stages) + { + std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n"; + } + + utils::print_separator(); + return 0; +} diff --git a/dispatcher/examples/fmha/cpp/05_appendkv_fmha.cpp b/dispatcher/examples/fmha/cpp/05_appendkv_fmha.cpp new file mode 100644 index 0000000000..7bd95642f0 --- /dev/null +++ b/dispatcher/examples/fmha/cpp/05_appendkv_fmha.cpp @@ -0,0 +1,106 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; + +DECL_FMHA_KERNEL_SET(appendkv_fmha_kernels, + .add(FmhaSignature() + .family("fwd_appendkv") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .rope("inter") + .paged_kv(true) + .kv_cache("vectorized", "sglang", 16), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(64) + .tile_n0(64) + .tile_k0(128) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(128) + .tile_k1(0) + .tile_k0max(0) + .pipeline("appendkv") + .padding(true, true, true, true) + .selection_rank(0), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 05: FMHA AppendKV", "Declarative FMHA append-KV planning"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--batch", "1", "Batch size"); + args.add_option("--nhead", "16", "Number of heads"); + args.add_option("--seqlen", "1", "Sequence length (tokens to append)"); + args.add_option("--hdim", "128", "Head dimension"); + if(!args.parse(argc, argv)) + { + return 0; + } + + utils::print_header("Example 05: FMHA AppendKV"); + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int batch = args.get_int("--batch", 1); + const int nhead = args.get_int("--nhead", 16); + const int seqlen = args.get_int("--seqlen", 1); + const int hdim = args.get_int("--hdim", 128); + + // Step 1: Register kernels + std::cout << "\nStep 1: Register Kernels\n"; + FmhaRegistry registry; + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + FmhaDispatcher dispatcher(®istry); + + // Step 2: Plan + std::cout << "\nStep 2: Plan\n"; + + fmha_fwd_appendkv_traits traits{}; + traits.hdim_q = hdim; + traits.hdim_v = hdim; + traits.data_type = "fp16"; + traits.is_v_rowmajor = true; + traits.rope_type = rope_enum::interleaved; + + fmha_fwd_appendkv_args fmha_args{}; + fmha_args.seqlen_q = seqlen; + fmha_args.seqlen_knew = seqlen; + fmha_args.batch = batch; + fmha_args.hdim_q = hdim; + fmha_args.hdim_v = hdim; + fmha_args.nhead_q = nhead; + fmha_args.nhead_k = nhead; + fmha_args.rotary_dim = hdim; + fmha_args.rotary_cos_ptr = reinterpret_cast(0x1); + fmha_args.rotary_sin_ptr = reinterpret_cast(0x1); + fmha_args.block_table_ptr = reinterpret_cast(0x1); + fmha_args.page_block_size = 16; + + auto plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch)); + + if(!plan.is_valid() || plan.stages.size() != 1) + { + std::cerr << "Expected a single-stage append-KV plan\n"; + return 1; + } + + // Step 3: Results + std::cout << "\nStep 3: Results\n"; + for(const auto& stage : plan.stages) + { + std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n"; + } + + utils::print_separator(); + return 0; +} diff --git a/dispatcher/examples/fmha/cpp/06_batch_prefill_fmha.cpp b/dispatcher/examples/fmha/cpp/06_batch_prefill_fmha.cpp new file mode 100644 index 0000000000..148a6433e9 --- /dev/null +++ b/dispatcher/examples/fmha/cpp/06_batch_prefill_fmha.cpp @@ -0,0 +1,133 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; + +DECL_FMHA_KERNEL_SET(batch_prefill_fmha_kernels, + .add(FmhaSignature() + .family("batch_prefill") + .dtype("fp16") + .mode("group") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(true) + .dropout(false) + .qscale("no") + .paged_kv(true) + .kv_cache("vectorized", "sglang", 16), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .selection_rank(0), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 06: FMHA Batch Prefill", + "Declarative FMHA batch-prefill planning"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--batch", "1", "Batch size"); + args.add_option("--nhead", "16", "Number of heads"); + args.add_option("--seqlen", "128", "Query sequence length"); + args.add_option("--hdim", "128", "Head dimension"); + if(!args.parse(argc, argv)) + { + return 0; + } + + utils::print_header("Example 06: FMHA Batch Prefill"); + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int batch = args.get_int("--batch", 1); + const int nhead = args.get_int("--nhead", 16); + const int seqlen = args.get_int("--seqlen", 128); + const int hdim = args.get_int("--hdim", 128); + + // Step 1: Register kernels + std::cout << "\nStep 1: Register Kernels\n"; + FmhaRegistry registry; + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + FmhaDispatcher dispatcher(®istry); + + // Step 2: Plan + std::cout << "\nStep 2: Plan\n"; + + fmha_batch_prefill_traits traits{}; + traits.hdim_q = hdim; + traits.hdim_v = hdim; + traits.data_type = "fp16"; + traits.is_group_mode = true; + traits.is_v_rowmajor = true; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + traits.has_lse = true; + traits.kv_memory_layout = ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT; + traits.kv_lookup_table = ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D; + traits.page_size = 16; + + fmha_batch_prefill_args fmha_args{}; + fmha_args.batch = batch; + fmha_args.seqlen_q = seqlen; + fmha_args.seqlen_k = 1024; + fmha_args.max_seqlen_q = seqlen; + fmha_args.hdim_q = hdim; + fmha_args.hdim_v = hdim; + fmha_args.nhead_q = nhead; + fmha_args.nhead_k = nhead; + fmha_args.num_total_pages = 64; + fmha_args.page_block_size = 16; + fmha_args.kv_memory_layout = ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT; + fmha_args.kv_lookup_table = ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D; + fmha_args.kv_indptr = reinterpret_cast(0x1); + fmha_args.kv_page_indices = reinterpret_cast(0x1); + fmha_args.kv_last_page_lens = reinterpret_cast(0x1); + fmha_args.seqstart_q_ptr = reinterpret_cast(0x1); + + auto plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch)); + + if(!plan.is_valid() || plan.stages.size() != 1) + { + std::cerr << "Expected a single-stage batch-prefill plan\n"; + return 1; + } + + // Step 3: Results + std::cout << "\nStep 3: Results\n"; + for(const auto& stage : plan.stages) + { + std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n"; + } + + utils::print_separator(); + return 0; +} diff --git a/dispatcher/examples/fmha/cpp/07_profile_pytorch_fmha.cpp b/dispatcher/examples/fmha/cpp/07_profile_pytorch_fmha.cpp new file mode 100644 index 0000000000..3859dc68dd --- /dev/null +++ b/dispatcher/examples/fmha/cpp/07_profile_pytorch_fmha.cpp @@ -0,0 +1,248 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; + +DECL_FMHA_KERNEL_SET(pytorch_profile_fmha_kernels, + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("bias") + .profile("pytorch"), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(64) + .tile_n0(128) + .tile_k0(32) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(16) + .warp_n0(16) + .warp_k0(32) + .warp_m1(16) + .warp_n1(16) + .warp_k1(16) + .pipeline("qr") + .padding(true, true, true, true), + "gfx950") + .add(FmhaSignature() + .family("fwd_splitkv") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(true) + .profile("pytorch"), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(16) + .warp_n0(16) + .warp_k0(16) + .warp_m1(16) + .warp_n1(16) + .warp_k1(16) + .pipeline("qr") + .padding(true, true, true, true) + .max_splits_log2(6), + "gfx950") + .add(FmhaSignature() + .family("fwd_splitkv_combine") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(true) + .profile("pytorch"), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(32) + .tile_n1(32) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(16) + .warp_n0(16) + .warp_k0(16) + .warp_m1(16) + .warp_n1(16) + .warp_k1(16) + .pipeline("qr") + .padding(true, true, true, true) + .max_splits_log2(6), + "gfx950") + .add(FmhaSignature() + .family("fwd_appendkv") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .profile("pytorch"), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(64) + .tile_k0(128) + .tile_n1(128) + .tile_k1(0) + .tile_k0max(0) + .padding(false, true, true, false) + .pipeline("appendkv"), + "gfx950") + .add(FmhaSignature() + .family("bwd_dot_do_o") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("no") + .bias("no") + .profile("pytorch"), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(32) + .tile_n1(0) + .tile_k1(0) + .tile_k0max(0) + .padding(true, true, true, true), + "gfx950") + .add(FmhaSignature() + .family("bwd_dq_dk_dv") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("no") + .bias("no") + .profile("pytorch"), + FmhaAlgorithm() + .tile_m0(16) + .tile_n0(128) + .tile_k0(128) + .tile_n1(16) + .tile_k1(128) + .tile_k0max(32) + .wave(1, 4, 1, 4, 1, 1, 1, 4, 1) + .warp(16, 16, 32, 16, 16, 16, 16, 16, 16) + .padding(true, true, true, true), + "gfx950") + .add(FmhaSignature() + .family("bwd_convert_dq") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("no") + .bias("no") + .profile("pytorch"), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(0) + .tile_n1(0) + .tile_k1(0) + .tile_k0max(0) + .padding(true, true, true, true), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 07: PyTorch-Profile FMHA", + "Declarative FMHA PyTorch-profile planning"); + args.add_option("--arch", "gfx950", "GPU architecture"); + if(!args.parse(argc, argv)) + { + return 0; + } + + const std::string gfx_arch = args.get("--arch", "gfx950"); + + FmhaRegistry registry; + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + FmhaDispatcher dispatcher(®istry); + + std::cout << "PyTorch-profile FMHA kernels: " << registry.size() << "\n"; + + fmha_fwd_traits fwd_traits{}; + fwd_traits.hdim_q = 128; + fwd_traits.hdim_v = 128; + fwd_traits.data_type = "fp16"; + fwd_traits.is_group_mode = false; + fwd_traits.is_v_rowmajor = true; + fwd_traits.mask_type = mask_enum::no_mask; + fwd_traits.bias_type = bias_enum::elementwise_bias; + fwd_traits.qscale_type = quant_scale_enum::no_scale; + + fmha_fwd_args fwd_args{}; + fwd_args.batch = 1; + fwd_args.seqlen_q = 128; + fwd_args.seqlen_k = 128; + fwd_args.max_seqlen_q = 128; + fwd_args.hdim_q = 128; + fwd_args.hdim_v = 128; + fwd_args.nhead_q = 16; + fwd_args.nhead_k = 16; + + auto fwd_plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(fwd_traits, fwd_args), gfx_arch)); + + fmha_bwd_traits bwd_traits{}; + bwd_traits.hdim_q = 128; + bwd_traits.hdim_v = 128; + bwd_traits.data_type = "fp16"; + bwd_traits.is_group_mode = false; + bwd_traits.mask_type = mask_enum::no_mask; + bwd_traits.bias_type = bias_enum::no_bias; + + fmha_bwd_args bwd_args{}; + bwd_args.batch = 1; + bwd_args.seqlen_q = 128; + bwd_args.seqlen_k = 128; + bwd_args.max_seqlen_q = 128; + bwd_args.max_seqlen_k = 128; + bwd_args.hdim_q = 128; + bwd_args.hdim_v = 128; + bwd_args.nhead_q = 16; + bwd_args.nhead_k = 16; + + auto bwd_plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(bwd_traits, bwd_args), gfx_arch)); + + std::cout << "Forward plan stages: " << fwd_plan.stages.size() << "\n"; + std::cout << "Backward plan stages: " << bwd_plan.stages.size() << "\n"; + return (fwd_plan.is_valid() && bwd_plan.is_valid()) ? 0 : 1; +} diff --git a/dispatcher/examples/fmha/cpp/08_profile_flash_fmha.cpp b/dispatcher/examples/fmha/cpp/08_profile_flash_fmha.cpp new file mode 100644 index 0000000000..3b4e3b276d --- /dev/null +++ b/dispatcher/examples/fmha/cpp/08_profile_flash_fmha.cpp @@ -0,0 +1,165 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; + +DECL_FMHA_KERNEL_SET(flash_profile_fmha_kernels, + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("alibi") + .profile("flash_fwd"), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(64) + .tile_n0(128) + .tile_k0(32) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(16) + .warp_n0(16) + .warp_k0(32) + .warp_m1(16) + .warp_n1(16) + .warp_k1(16) + .pipeline("qr") + .padding(true, true, true, true), + "gfx950") + .add(FmhaSignature() + .family("bwd_dot_do_o") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("no") + .bias("no") + .profile("flash_bwd"), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(32) + .tile_n1(0) + .tile_k1(0) + .tile_k0max(0) + .padding(true, true, true, true), + "gfx950") + .add(FmhaSignature() + .family("bwd_dq_dk_dv") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("no") + .bias("no") + .profile("flash_bwd"), + FmhaAlgorithm() + .tile_m0(16) + .tile_n0(128) + .tile_k0(128) + .tile_n1(16) + .tile_k1(128) + .tile_k0max(32) + .wave(1, 4, 1, 4, 1, 1, 1, 4, 1) + .warp(16, 16, 32, 16, 16, 16, 16, 16, 16) + .padding(true, true, true, true), + "gfx950") + .add(FmhaSignature() + .family("bwd_convert_dq") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("no") + .bias("no") + .profile("flash_bwd"), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(0) + .tile_n1(0) + .tile_k1(0) + .tile_k0max(0) + .padding(true, true, true, true), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 08: Flash-Profile FMHA", + "Declarative FMHA Flash-profile planning"); + args.add_option("--arch", "gfx950", "GPU architecture"); + if(!args.parse(argc, argv)) + { + return 0; + } + + const std::string gfx_arch = args.get("--arch", "gfx950"); + + FmhaRegistry registry; + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + FmhaDispatcher dispatcher(®istry); + + std::cout << "Flash-profile FMHA kernels: " << registry.size() << "\n"; + + fmha_fwd_traits fwd_traits{}; + fwd_traits.hdim_q = 128; + fwd_traits.hdim_v = 128; + fwd_traits.data_type = "fp16"; + fwd_traits.is_group_mode = false; + fwd_traits.is_v_rowmajor = true; + fwd_traits.mask_type = mask_enum::no_mask; + fwd_traits.bias_type = bias_enum::alibi; + fwd_traits.qscale_type = quant_scale_enum::no_scale; + + fmha_fwd_args fwd_args{}; + fwd_args.batch = 1; + fwd_args.seqlen_q = 128; + fwd_args.seqlen_k = 128; + fwd_args.max_seqlen_q = 128; + fwd_args.hdim_q = 128; + fwd_args.hdim_v = 128; + fwd_args.nhead_q = 16; + fwd_args.nhead_k = 16; + + auto fwd_plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(fwd_traits, fwd_args), gfx_arch)); + + fmha_bwd_traits bwd_traits{}; + bwd_traits.hdim_q = 128; + bwd_traits.hdim_v = 128; + bwd_traits.data_type = "fp16"; + bwd_traits.is_group_mode = false; + bwd_traits.mask_type = mask_enum::no_mask; + bwd_traits.bias_type = bias_enum::no_bias; + + fmha_bwd_args bwd_args{}; + bwd_args.batch = 1; + bwd_args.seqlen_q = 128; + bwd_args.seqlen_k = 128; + bwd_args.max_seqlen_q = 128; + bwd_args.max_seqlen_k = 128; + bwd_args.hdim_q = 128; + bwd_args.hdim_v = 128; + bwd_args.nhead_q = 16; + bwd_args.nhead_k = 16; + + auto bwd_plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(bwd_traits, bwd_args), gfx_arch)); + + std::cout << "Flash fwd stages: " << fwd_plan.stages.size() << "\n"; + std::cout << "Flash bwd stages: " << bwd_plan.stages.size() << "\n"; + return (fwd_plan.is_valid() && bwd_plan.is_valid()) ? 0 : 1; +} diff --git a/dispatcher/examples/fmha/cpp/09_profile_aiter_fmha.cpp b/dispatcher/examples/fmha/cpp/09_profile_aiter_fmha.cpp new file mode 100644 index 0000000000..7d61e38636 --- /dev/null +++ b/dispatcher/examples/fmha/cpp/09_profile_aiter_fmha.cpp @@ -0,0 +1,212 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; + +DECL_FMHA_KERNEL_SET( + aiter_profile_fmha_kernels, + .add(FmhaSignature().family("fwd").dtype("fp16").mode("batch").vlayout("r").hdim(128).profile( + "aiter_batch"), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true), + "gfx950") + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("group") + .vlayout("r") + .hdim(128) + .profile("aiter_group"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true), + "gfx950") + .add(FmhaSignature() + .family("fwd_pagedkv") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .paged_kv(true) + .profile("aiter_cpp") + .kv_cache("vectorized", "sglang", 16), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_pagedkv") + .padding(true, true, true, true), + "gfx950") + .add(FmhaSignature() + .family("batch_prefill") + .dtype("fp16") + .mode("group") + .vlayout("r") + .hdim(128) + .paged_kv(true) + .profile("aiter_cpp") + .kv_cache("vectorized", "sglang", 16), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 09: AITER-Profile FMHA", + "Declarative FMHA AITER-profile planning"); + args.add_option("--arch", "gfx950", "GPU architecture"); + if(!args.parse(argc, argv)) + { + return 0; + } + + const std::string gfx_arch = args.get("--arch", "gfx950"); + + FmhaRegistry registry; + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + FmhaDispatcher dispatcher(®istry); + + std::cout << "AITER-profile FMHA kernels: " << registry.size() << "\n"; + + fmha_fwd_traits batch_traits{}; + batch_traits.hdim_q = 128; + batch_traits.hdim_v = 128; + batch_traits.data_type = "fp16"; + batch_traits.is_group_mode = false; + batch_traits.is_v_rowmajor = true; + batch_traits.mask_type = mask_enum::no_mask; + batch_traits.bias_type = bias_enum::no_bias; + batch_traits.qscale_type = quant_scale_enum::no_scale; + + fmha_fwd_args batch_args{}; + batch_args.batch = 1; + batch_args.seqlen_q = 128; + batch_args.seqlen_k = 128; + batch_args.max_seqlen_q = 128; + batch_args.hdim_q = 128; + batch_args.hdim_v = 128; + batch_args.nhead_q = 16; + batch_args.nhead_k = 16; + + auto batch_plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(batch_traits, batch_args), gfx_arch)); + + fmha_batch_prefill_traits prefill_traits{}; + prefill_traits.hdim_q = 128; + prefill_traits.hdim_v = 128; + prefill_traits.data_type = "fp16"; + prefill_traits.is_group_mode = true; + prefill_traits.is_v_rowmajor = true; + prefill_traits.mask_type = mask_enum::no_mask; + prefill_traits.bias_type = bias_enum::no_bias; + prefill_traits.kv_memory_layout = + ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT; + prefill_traits.kv_lookup_table = + ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D; + prefill_traits.page_size = 16; + + fmha_batch_prefill_args prefill_args{}; + prefill_args.batch = 1; + prefill_args.seqlen_q = 128; + prefill_args.seqlen_k = 1024; + prefill_args.max_seqlen_q = 128; + prefill_args.hdim_q = 128; + prefill_args.hdim_v = 128; + prefill_args.nhead_q = 16; + prefill_args.nhead_k = 16; + prefill_args.num_total_pages = 64; + prefill_args.page_block_size = 16; + prefill_args.kv_memory_layout = + ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT; + prefill_args.kv_lookup_table = + ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D; + prefill_args.kv_indptr = reinterpret_cast(0x1); + prefill_args.kv_page_indices = reinterpret_cast(0x1); + prefill_args.kv_last_page_lens = reinterpret_cast(0x1); + prefill_args.seqstart_q_ptr = reinterpret_cast(0x1); + + auto prefill_plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(prefill_traits, prefill_args), gfx_arch)); + + std::cout << "AITER batch stages: " << batch_plan.stages.size() << "\n"; + std::cout << "AITER prefill stages: " << prefill_plan.stages.size() << "\n"; + return (batch_plan.is_valid() && prefill_plan.is_valid()) ? 0 : 1; +} diff --git a/dispatcher/examples/fmha/cpp/10_profile_fp32_fp8_fmha.cpp b/dispatcher/examples/fmha/cpp/10_profile_fp32_fp8_fmha.cpp new file mode 100644 index 0000000000..60d476df5f --- /dev/null +++ b/dispatcher/examples/fmha/cpp/10_profile_fp32_fp8_fmha.cpp @@ -0,0 +1,152 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; + +DECL_FMHA_KERNEL_SET(fp32_fp8_profile_fmha_kernels, + .add(FmhaSignature() + .family("fwd") + .dtype("fp32") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .profile("fp32_min"), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(32) + .tile_n0(128) + .tile_k0(32) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(128) + .tile_k1(16) + .tile_k0max(128) + .wave_m0(2) + .wave_n0(1) + .wave_k0(1) + .wave_m1(2) + .wave_n1(1) + .wave_k1(1) + .warp_m0(16) + .warp_n0(16) + .warp_k0(16) + .warp_m1(16) + .warp_n1(16) + .warp_k1(16) + .pipeline("qr") + .padding(true, true, true, true), + "gfx950") + .add(FmhaSignature() + .family("fwd") + .dtype("fp32") + .mode("batch") + .vlayout("r") + .hdim(48) + .mask("no") + .bias("no") + .profile("fp32_all"), + FmhaAlgorithm() + .tile_m0(32) + .tile_n0(128) + .tile_k0(16) + .tile_n1(48) + .tile_k1(16) + .tile_k0max(48) + .wave_m0(2) + .wave_n0(1) + .wave_k0(1) + .wave_m1(2) + .wave_n1(1) + .wave_k1(1) + .warp_m0(16) + .warp_n0(16) + .warp_k0(16) + .warp_m1(16) + .warp_n1(16) + .warp_k1(16) + .pipeline("qr") + .padding(true, true, true, true), + "gfx950") + .add(FmhaSignature() + .family("fwd") + .dtype("fp8bf16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .profile("fp8_test"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(32) + .warp_m1(32) + .warp_n1(32) + .warp_k1(32) + .pipeline("qr_async") + .padding(true, true, true, true), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 10: FP32/FP8-Profile FMHA", + "Declarative FMHA FP32/FP8-profile planning"); + args.add_option("--arch", "gfx950", "GPU architecture"); + if(!args.parse(argc, argv)) + { + return 0; + } + + const std::string gfx_arch = args.get("--arch", "gfx950"); + + FmhaRegistry registry; + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + FmhaDispatcher dispatcher(®istry); + + std::cout << "FP32/FP8-profile FMHA kernels: " << registry.size() << "\n"; + std::cout << registry.export_json(false) << "\n"; + + fmha_fwd_traits traits{}; + traits.hdim_q = 128; + traits.hdim_v = 128; + traits.data_type = "fp32"; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + traits.qscale_type = quant_scale_enum::no_scale; + + fmha_fwd_args fmha_args{}; + fmha_args.batch = 1; + fmha_args.seqlen_q = 128; + fmha_args.seqlen_k = 128; + fmha_args.max_seqlen_q = 128; + fmha_args.hdim_q = 128; + fmha_args.hdim_v = 128; + fmha_args.nhead_q = 16; + fmha_args.nhead_k = 16; + + auto plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch)); + + std::cout << "FP32/FP8-profile plan stages: " << plan.stages.size() << "\n"; + return plan.is_valid() ? 0 : 1; +} diff --git a/dispatcher/examples/fmha/cpp/11_receipt_aliases_fmha.cpp b/dispatcher/examples/fmha/cpp/11_receipt_aliases_fmha.cpp new file mode 100644 index 0000000000..3110e8c851 --- /dev/null +++ b/dispatcher/examples/fmha/cpp/11_receipt_aliases_fmha.cpp @@ -0,0 +1,176 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; + +DECL_FMHA_KERNEL_SET(receipt_alias_fmha_kernels, + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .bias("alibi") + .receipt(2), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(64) + .tile_n0(128) + .tile_k0(32) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(16) + .warp_n0(16) + .warp_k0(32) + .warp_m1(16) + .warp_n1(16) + .warp_k1(16) + .pipeline("qr") + .padding(true, true, true, true), + "gfx950") + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .bias("bias") + .receipt(4), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(16) + .warp_n0(16) + .warp_k0(32) + .warp_m1(16) + .warp_n1(16) + .warp_k1(16) + .pipeline("qr") + .padding(true, true, true, true), + "gfx950") + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .receipt(100), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true), + "gfx950") + .add(FmhaSignature() + .family("fwd") + .dtype("fp32") + .mode("batch") + .vlayout("r") + .hdim(128) + .receipt(800), + FmhaAlgorithm() + .tile_m0(32) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(16) + .tile_k0max(128) + .wave_m0(2) + .wave_n0(1) + .wave_k0(1) + .wave_m1(2) + .wave_n1(1) + .wave_k1(1) + .warp_m0(16) + .warp_n0(16) + .warp_k0(16) + .warp_m1(16) + .warp_n1(16) + .warp_k1(16) + .pipeline("qr") + .padding(true, true, true, true), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 11: Receipt Aliases FMHA", + "Declarative FMHA receipt-alias planning"); + args.add_option("--arch", "gfx950", "GPU architecture"); + if(!args.parse(argc, argv)) + { + return 0; + } + + const std::string gfx_arch = args.get("--arch", "gfx950"); + + FmhaRegistry registry; + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + FmhaDispatcher dispatcher(®istry); + + std::cout << "Receipt-alias FMHA kernels: " << registry.size() << "\n"; + + fmha_fwd_traits traits{}; + traits.hdim_q = 128; + traits.hdim_v = 128; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + traits.qscale_type = quant_scale_enum::no_scale; + + fmha_fwd_args fmha_args{}; + fmha_args.batch = 1; + fmha_args.seqlen_q = 128; + fmha_args.seqlen_k = 128; + fmha_args.max_seqlen_q = 128; + fmha_args.hdim_q = 128; + fmha_args.hdim_v = 128; + fmha_args.nhead_q = 16; + fmha_args.nhead_k = 16; + + auto plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch)); + + std::cout << "Receipt-alias plan stages: " << plan.stages.size() << "\n"; + return plan.is_valid() ? 0 : 1; +} diff --git a/dispatcher/examples/fmha/cpp/12_registry_json_fmha.cpp b/dispatcher/examples/fmha/cpp/12_registry_json_fmha.cpp new file mode 100644 index 0000000000..a1c27efd2c --- /dev/null +++ b/dispatcher/examples/fmha/cpp/12_registry_json_fmha.cpp @@ -0,0 +1,129 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; + +DECL_FMHA_KERNEL_SET( + registry_json_fmha_kernels, + .add(FmhaSignature().family("fwd").dtype("fp16").mode("batch").vlayout("r").hdim(128), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true), + "gfx950") + .add(FmhaSignature() + .family("fwd_pagedkv") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .paged_kv(true) + .kv_cache("vectorized", "sglang", 16), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_pagedkv") + .padding(true, true, true, true), + "gfx950") + .add(FmhaSignature().family("bwd_dq_dk_dv").dtype("fp16").mode("batch").hdim(128), + FmhaAlgorithm() + .tile_m0(16) + .tile_n0(128) + .tile_k0(128) + .tile_n1(16) + .tile_k1(128) + .tile_k0max(32) + .wave(1, 4, 1, 4, 1, 1, 1, 4, 1) + .warp(16, 16, 32, 16, 16, 16, 16, 16, 16) + .padding(true, true, true, true), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 12: Registry JSON FMHA", + "Declarative FMHA registry JSON export"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--output", "", "Write JSON to file (optional)"); + if(!args.parse(argc, argv)) + { + return 0; + } + + utils::print_header("Example 12: Registry JSON FMHA"); + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const std::string output_path = args.get("--output", ""); + + // Step 1: Register kernels + std::cout << "\nStep 1: Register Kernels\n"; + FmhaRegistry registry; + registry.set_name("registry_json_fmha"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + // Step 2: Export JSON + std::cout << "\nStep 2: Export JSON\n"; + std::string json = registry.export_json(true); + std::cout << " JSON size: " << json.size() << " bytes\n"; + std::cout << json.substr(0, std::min(json.size(), 240)) << "\n"; + + // Step 3: Write to file (if --output specified) + if(!output_path.empty()) + { + std::cout << "\nStep 3: Write to File\n"; + std::ofstream ofs(output_path); + if(!ofs.is_open()) + { + std::cerr << " ERROR: Cannot open " << output_path << " for writing\n"; + return 1; + } + ofs << json; + ofs.close(); + std::cout << " Written to: " << output_path << "\n"; + std::cout << " File size: " << json.size() << " bytes\n"; + } + + utils::print_separator(); + return registry.size() > 0 ? 0 : 1; +} diff --git a/dispatcher/examples/fmha/cpp/13_feature_coverage_fmha.cpp b/dispatcher/examples/fmha/cpp/13_feature_coverage_fmha.cpp new file mode 100644 index 0000000000..53e66db609 --- /dev/null +++ b/dispatcher/examples/fmha/cpp/13_feature_coverage_fmha.cpp @@ -0,0 +1,499 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Example 13: FMHA Feature Coverage +// Exercises every feature dimension from the 01_fmha smoke test: +// bf16, masks (top-left, bottom-right, window_generic), GQA, dropout, +// multiple hdims (64, 256), group mode, col-major V. + +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; + +DECL_FMHA_KERNEL_SET(feature_coverage_kernels, + // fp16 forward (basic, needed for GQA and other fp16 tests) + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950") + + // bf16 forward + .add(FmhaSignature() + .family("fwd") + .dtype("bf16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950") + + // hdim 64 + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(64) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(64) + .tile_k0(32) + .tile_n1(64) + .tile_k1(32) + .tile_k0max(64) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(16) + .warp_n1(16) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(64, 64) + .selection_rank(0), + "gfx950") + + // hdim 256 + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(256) + .mask("no") + .bias("no") + .lse(true) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(256) + .tile_k1(32) + .tile_k0max(256) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr") + .padding(false, false, false, false) + .alignments(256, 256) + .selection_rank(0), + "gfx950") + + // Mask: causal (top-left and bottom-right share the same compiled kernel; + // the mask type is resolved at runtime via the args, not the template) + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("top_left") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950") + + // Dropout + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(true) + .dropout(true) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950") + + // GQA (nhead_q != nhead_k) - same kernel, GQA is a runtime concern + // Bias: elementwise + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("bias") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950") + + // Bias: alibi + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("alibi") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950") + + // Group mode + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("group") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950") + + // Sink tokens + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("top_left") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no") + .sink(true), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950")); + +namespace { + +struct FeatureTest +{ + std::string name; + FmhaProblem problem; +}; + +FeatureTest make_test(const std::string& name, + const std::string& dtype, + int hdim_q, + int hdim_v, + int mask, + int bias, + bool lse, + bool dropout, + bool group, + bool logits, + bool sink, + int nhead_q = 16, + int nhead_k = 16, + const std::string& arch = "gfx950") +{ + auto p = FmhaProblemBuilder() + .api_family(FmhaApiFamily::Fwd) + .kernel_family(FmhaKernelFamily::Fwd) + .gfx_arch(arch) + .data_type(dtype) + .dims(hdim_q, hdim_v, 2, 128, 256) + .nheads(nhead_q, nhead_k) + .mask_type(mask) + .bias_type(bias) + .lse(lse) + .dropout(dropout) + .group_mode(group) + .logits_soft_cap(logits) + .sink(sink) + .build(); + return {name, p}; +} + +} // namespace + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 13: FMHA Feature Coverage", + "Tests all 01_fmha smoke test features"); + args.add_option("--arch", "gfx950", "GPU architecture"); + if(!args.parse(argc, argv)) + return 0; + + utils::print_header("Example 13: FMHA Feature Coverage"); + + const std::string gfx_arch = args.get("--arch", "gfx950"); + + // Step 1: Register kernels + std::cout << "\nStep 1: Register Kernels\n"; + FmhaKernelSetRegistry::instance().print(); + + FmhaRegistry registry; + registry.set_name("feature_coverage"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + FmhaDispatcher dispatcher(®istry); + + // Step 2: Run feature tests + std::cout << "\nStep 2: Run Feature Tests\n"; + std::vector tests = { + make_test("bf16_basic", "bf16", 128, 128, 0, 0, false, false, false, false, false), + make_test("fp16_hdim64", "fp16", 64, 64, 0, 0, false, false, false, false, false), + make_test("fp16_hdim256", "fp16", 256, 256, 0, 0, true, false, false, false, false), + make_test("mask_top_left", "fp16", 128, 128, 1, 0, false, false, false, false, false), + make_test("mask_bottom_right", "fp16", 128, 128, 2, 0, false, false, false, false, false), + make_test("dropout", "fp16", 128, 128, 0, 0, true, true, false, false, false), + make_test("gqa_h16_hk4", "fp16", 128, 128, 0, 0, false, false, false, false, false, 16, 4), + make_test("bias_elementwise", "fp16", 128, 128, 0, 1, false, false, false, false, false), + make_test("bias_alibi", "fp16", 128, 128, 0, 2, false, false, false, false, false), + make_test("group_mode", "fp16", 128, 128, 0, 0, false, false, true, false, false), + make_test("sink_tokens", "fp16", 128, 128, 1, 0, false, false, false, false, true), + }; + + int pass = 0; + int fail = 0; + for(const auto& test : tests) + { + auto plan = dispatcher.plan(test.problem); + bool ok = plan.is_valid(); + std::cout << (ok ? "[PASS]" : "[FAIL]") << " " << test.name; + if(ok) + { + std::cout << " -> " << plan.stages[0].kernel_id; + ++pass; + } + else + { + ++fail; + } + std::cout << "\n"; + } + + // Step 3: Summary + std::cout << "\nStep 3: Summary\n"; + std::cout << " " << pass << " passed, " << fail << " failed out of " << tests.size() << "\n"; + + utils::print_separator(); + return fail > 0 ? 1 : 0; +} diff --git a/dispatcher/examples/fmha/cpp/14_benchmark_validation_fmha.cpp b/dispatcher/examples/fmha/cpp/14_benchmark_validation_fmha.cpp new file mode 100644 index 0000000000..412ede3979 --- /dev/null +++ b/dispatcher/examples/fmha/cpp/14_benchmark_validation_fmha.cpp @@ -0,0 +1,404 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Example 14: FMHA Benchmark with Validation +// +// Demonstrates: +// 1. Warmup runs to stabilize GPU clocks +// 2. Repeated benchmark runs with statistics (min/avg/max/median) +// 3. Optional CPU reference validation via --verify flag +// +// Usage: +// ./14_benchmark_validation_fmha +// ./14_benchmark_validation_fmha --seqlen 256 --batch 4 --repeat 20 +// ./14_benchmark_validation_fmha --verify + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +using FmhaDataType = ck_tile::fp16_t; + +DECL_FMHA_KERNEL_SET(benchmark_fmha_kernels, + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950")); + +namespace { + +void cpu_attention_fwd(const std::vector& Q, + const std::vector& K, + const std::vector& V, + std::vector& O, + int batch, + int nhead, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + float scale) +{ + for(int b = 0; b < batch; ++b) + { + for(int h = 0; h < nhead; ++h) + { + for(int sq = 0; sq < seqlen_q; ++sq) + { + std::vector scores(seqlen_k, 0.0f); + float max_score = -1e30f; + + for(int sk = 0; sk < seqlen_k; ++sk) + { + float dot = 0.0f; + for(int d = 0; d < hdim_q; ++d) + { + int q_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_q + d; + int k_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_q + d; + dot += Q[q_idx] * K[k_idx]; + } + scores[sk] = dot * scale; + max_score = std::max(max_score, scores[sk]); + } + + float sum_exp = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] = std::exp(scores[sk] - max_score); + sum_exp += scores[sk]; + } + + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] /= sum_exp; + } + + for(int dv = 0; dv < hdim_v; ++dv) + { + float acc = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + int v_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_v + dv; + acc += scores[sk] * V[v_idx]; + } + int o_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_v + dv; + O[o_idx] = acc; + } + } + } + } +} + +} // namespace + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 14: FMHA Benchmark + Validation", + "Warmup, repeated benchmark, optional verification"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--batch", "2", "Batch size"); + args.add_option("--nhead", "8", "Number of heads"); + args.add_option("--seqlen", "128", "Sequence length (Q and K)"); + args.add_option("--hdim", "128", "Head dimension"); + args.add_option("--warmup", "3", "Warmup iterations"); + args.add_option("--repeat", "10", "Benchmark repetitions"); + args.add_flag("--verify", "Validate against CPU reference"); + if(!args.parse(argc, argv)) + return 0; + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int batch = args.get_int("--batch", 2); + const int nhead = args.get_int("--nhead", 8); + const int seqlen = args.get_int("--seqlen", 128); + const int hdim = args.get_int("--hdim", 128); + const int warmup = args.get_int("--warmup", 3); + const int repeat = args.get_int("--repeat", 10); + + print_header("Example 14: FMHA Benchmark + Validation"); + + // Step 1: Register kernels + std::cout << "\nStep 1: Register Kernels\n"; + FmhaKernelSetRegistry::instance().print(); + + FmhaRegistry registry; + registry.set_name("benchmark_fmha"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + const float scale = 1.0f / std::sqrt(static_cast(hdim)); + + fmha_fwd_traits traits{}; + traits.hdim_q = hdim; + traits.hdim_v = hdim; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.has_logits_soft_cap = false; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + traits.has_lse = false; + traits.has_dropout = false; + traits.qscale_type = quant_scale_enum::no_scale; + + const int64_t q_elems = static_cast(batch) * nhead * seqlen * hdim; + const int64_t o_elems = q_elems; + + // Step 2: Allocate GPU buffers + std::cout << "\nStep 2: Allocate GPU Buffers\n"; + std::cout << " Q/K/V/O: [" << batch << ", " << nhead << ", " << seqlen << ", " << hdim + << "]\n"; + + GpuBuffer q_dev(q_elems); + GpuBuffer k_dev(q_elems); + GpuBuffer v_dev(q_elems); + GpuBuffer o_dev(o_elems); + + std::mt19937 rng(42); + std::uniform_real_distribution dist(-0.5f, 0.5f); + + std::vector q_host(q_elems); + std::vector k_host(q_elems); + std::vector v_host(q_elems); + for(auto& x : q_host) + x = FmhaDataType(dist(rng)); + for(auto& x : k_host) + x = FmhaDataType(dist(rng)); + for(auto& x : v_host) + x = FmhaDataType(dist(rng)); + + q_dev.copy_from_host(q_host.data()); + k_dev.copy_from_host(k_host.data()); + v_dev.copy_from_host(v_host.data()); + + fmha_fwd_args fmha_args{}; + fmha_args.q_ptr = q_dev.get(); + fmha_args.k_ptr = k_dev.get(); + fmha_args.v_ptr = v_dev.get(); + fmha_args.o_ptr = o_dev.get(); + + fmha_args.bias_ptr = nullptr; + fmha_args.q_descale_ptr = nullptr; + fmha_args.k_descale_ptr = nullptr; + fmha_args.v_descale_ptr = nullptr; + fmha_args.rand_val_ptr = nullptr; + fmha_args.lse_ptr = nullptr; + fmha_args.sink_ptr = nullptr; + fmha_args.block_scale_seqstart_q_ptr = nullptr; + fmha_args.block_scale_seqstart_k_ptr = nullptr; + + fmha_args.seqlen_q = seqlen; + fmha_args.seqlen_k = seqlen; + fmha_args.batch = batch; + fmha_args.max_seqlen_q = seqlen; + fmha_args.hdim_q = hdim; + fmha_args.hdim_v = hdim; + fmha_args.nhead_q = nhead; + fmha_args.nhead_k = nhead; + fmha_args.scale_s = scale; + fmha_args.logits_soft_cap = 0.0f; + + fmha_args.stride_q = hdim; + fmha_args.stride_k = hdim; + fmha_args.stride_v = hdim; + fmha_args.stride_bias = 0; + fmha_args.stride_randval = 0; + fmha_args.stride_o = hdim; + + fmha_args.nhead_stride_q = seqlen * hdim; + fmha_args.nhead_stride_k = seqlen * hdim; + fmha_args.nhead_stride_v = seqlen * hdim; + fmha_args.nhead_stride_bias = 0; + fmha_args.nhead_stride_randval = 0; + fmha_args.nhead_stride_lse = 0; + fmha_args.nhead_stride_o = seqlen * hdim; + fmha_args.nhead_stride_q_descale = 0; + fmha_args.nhead_stride_k_descale = 0; + fmha_args.nhead_stride_v_descale = 0; + + fmha_args.batch_stride_q = nhead * seqlen * hdim; + fmha_args.batch_stride_k = nhead * seqlen * hdim; + fmha_args.batch_stride_v = nhead * seqlen * hdim; + fmha_args.batch_stride_bias = 0; + fmha_args.batch_stride_randval = 0; + fmha_args.batch_stride_lse = 0; + fmha_args.batch_stride_o = nhead * seqlen * hdim; + fmha_args.batch_stride_q_descale = 0; + fmha_args.batch_stride_k_descale = 0; + fmha_args.batch_stride_v_descale = 0; + + fmha_args.window_size_left = -1; + fmha_args.window_size_right = -1; + fmha_args.sink_size = 0; + fmha_args.mask_type = 0; + fmha_args.min_seqlen_q = 0; + fmha_args.p_drop = 0.0f; + fmha_args.s_randval = false; + fmha_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0)); + fmha_args.block_scale_size_q = 0; + fmha_args.block_scale_size_kv = 0; + + FmhaDispatcher dispatcher(®istry); + + // Step 3: Warmup runs + std::cout << "\nStep 3: Warmup (" << warmup << " iterations)\n"; + dispatcher.set_benchmarking(true); + dispatcher.set_timing(1, 1); + for(int i = 0; i < warmup; ++i) + { + o_dev.zero(); + float t = dispatcher.run_fwd(traits, fmha_args, nullptr); + std::cout << " Warmup " << (i + 1) << ": " << std::fixed << std::setprecision(4) << t + << " ms\n"; + } + + // Step 4: Benchmark runs + std::cout << "\nStep 4: Benchmark (" << repeat << " iterations)\n"; + dispatcher.set_timing(0, 1); + std::vector times; + times.reserve(repeat); + + for(int i = 0; i < repeat; ++i) + { + o_dev.zero(); + float t = dispatcher.run_fwd(traits, fmha_args, nullptr); + times.push_back(t); + } + + std::sort(times.begin(), times.end()); + float t_min = times.front(); + float t_max = times.back(); + float t_avg = std::accumulate(times.begin(), times.end(), 0.0f) / static_cast(repeat); + float t_med = + (repeat % 2 == 0) ? (times[repeat / 2 - 1] + times[repeat / 2]) / 2.0f : times[repeat / 2]; + + auto problem = FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch); + double ops = static_cast(problem.num_ops()); + double tflops_min = ops / (t_max * 1e-3) / 1e12; + double tflops_max = ops / (t_min * 1e-3) / 1e12; + double tflops_avg = ops / (t_avg * 1e-3) / 1e12; + double tflops_med = ops / (t_med * 1e-3) / 1e12; + + std::cout << "\n " << std::setw(10) << "Metric" << " | " << std::setw(12) << "Time(ms)" + << " | " << std::setw(12) << "TFLOPS" << "\n"; + std::cout << " " << std::string(40, '-') << "\n"; + std::cout << std::fixed << std::setprecision(4); + std::cout << " " << std::setw(10) << "Min" << " | " << std::setw(12) << t_min << " | " + << std::setprecision(2) << std::setw(12) << tflops_max << "\n"; + std::cout << std::setprecision(4); + std::cout << " " << std::setw(10) << "Avg" << " | " << std::setw(12) << t_avg << " | " + << std::setprecision(2) << std::setw(12) << tflops_avg << "\n"; + std::cout << std::setprecision(4); + std::cout << " " << std::setw(10) << "Median" << " | " << std::setw(12) << t_med << " | " + << std::setprecision(2) << std::setw(12) << tflops_med << "\n"; + std::cout << std::setprecision(4); + std::cout << " " << std::setw(10) << "Max" << " | " << std::setw(12) << t_max << " | " + << std::setprecision(2) << std::setw(12) << tflops_min << "\n"; + + bool passed = true; + + // Step 5: Optional validation + if(args.has("--verify")) + { + std::cout << "\nStep 5: CPU Reference Validation\n"; + + std::vector o_host(o_elems); + o_dev.copy_to_host(o_host.data()); + + std::vector q_f32(q_elems), k_f32(q_elems), v_f32(q_elems), o_ref(o_elems, 0.0f); + for(int64_t i = 0; i < q_elems; ++i) + q_f32[i] = static_cast(q_host[i]); + for(int64_t i = 0; i < q_elems; ++i) + k_f32[i] = static_cast(k_host[i]); + for(int64_t i = 0; i < q_elems; ++i) + v_f32[i] = static_cast(v_host[i]); + + cpu_attention_fwd( + q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale); + + double max_abs_err = 0.0; + double max_rel_err = 0.0; + int errors = 0; + const double rtol = 1e-2; + const double atol = 1e-2; + + for(int64_t i = 0; i < o_elems; ++i) + { + float gpu_val = static_cast(o_host[i]); + float ref_val = o_ref[i]; + double abs_err = std::abs(gpu_val - ref_val); + double rel_err = abs_err / (std::abs(ref_val) + 1e-6); + max_abs_err = std::max(max_abs_err, abs_err); + max_rel_err = std::max(max_rel_err, rel_err); + if(abs_err > atol + rtol * std::abs(ref_val)) + ++errors; + } + + std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n"; + std::cout << " Max rel error: " << max_rel_err << "\n"; + std::cout << " Errors: " << errors << " / " << o_elems << "\n"; + passed = (errors == 0); + } + else + { + std::vector o_host(o_elems); + o_dev.copy_to_host(o_host.data()); + int nonzero = 0; + for(int64_t i = 0; i < o_elems; ++i) + { + if(static_cast(o_host[i]) != 0.0f) + ++nonzero; + } + std::cout << "\n Sanity: " << nonzero << " / " << o_elems << " non-zero outputs\n"; + passed = (nonzero > 0); + } + + print_separator(); + std::cout << "Status: " << (passed ? "PASS" : "FAIL") << "\n"; + print_separator(); + + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/fmha/cpp/15_multi_shape_fmha.cpp b/dispatcher/examples/fmha/cpp/15_multi_shape_fmha.cpp new file mode 100644 index 0000000000..99b4974f08 --- /dev/null +++ b/dispatcher/examples/fmha/cpp/15_multi_shape_fmha.cpp @@ -0,0 +1,282 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Example 15: Multi-Shape FMHA Sweep +// +// Demonstrates running a single FMHA kernel across multiple (batch, seqlen) +// combinations, producing a performance table. This pattern is useful for +// characterizing kernel behavior across the parameter space. +// +// Usage: +// ./15_multi_shape_fmha +// ./15_multi_shape_fmha --arch gfx942 + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +using FmhaDataType = ck_tile::fp16_t; + +DECL_FMHA_KERNEL_SET(multi_shape_fmha_kernels, + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950")); + +namespace { + +struct ShapeConfig +{ + int batch; + int seqlen; +}; + +const ShapeConfig SHAPES[] = { + {1, 64}, + {1, 128}, + {1, 256}, + {1, 512}, + {2, 64}, + {2, 128}, + {2, 256}, + {4, 64}, + {4, 128}, +}; + +} // namespace + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 15: Multi-Shape FMHA", + "Sweep (batch, seqlen) combos with a single kernel"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--nhead", "8", "Number of heads"); + args.add_option("--hdim", "128", "Head dimension"); + if(!args.parse(argc, argv)) + return 0; + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int nhead = args.get_int("--nhead", 8); + const int hdim = args.get_int("--hdim", 128); + + print_header("Example 15: Multi-Shape FMHA Sweep"); + + // Step 1: Register kernels + std::cout << "\nStep 1: Register Kernels\n"; + FmhaKernelSetRegistry::instance().print(); + + FmhaRegistry registry; + registry.set_name("multi_shape_fmha"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + FmhaDispatcher dispatcher(®istry); + dispatcher.set_benchmarking(true); + dispatcher.set_timing(1, 3); + + // Step 2: Sweep shapes + std::cout << "\nStep 2: Shape Sweep (nhead=" << nhead << ", hdim=" << hdim << ")\n\n"; + + std::cout << " " << std::setw(6) << "Batch" << " | " << std::setw(8) << "SeqLen" << " | " + << std::setw(12) << "Elements" << " | " << std::setw(10) << "Time(ms)" << " | " + << std::setw(10) << "TFLOPS" << " | " << std::setw(8) << "Status" << "\n"; + std::cout << " " << std::string(66, '-') << "\n"; + + std::mt19937 rng(42); + std::uniform_real_distribution dist(-0.5f, 0.5f); + + int pass_count = 0; + int total = 0; + const int num_shapes = sizeof(SHAPES) / sizeof(SHAPES[0]); + + for(int si = 0; si < num_shapes; ++si) + { + const auto& shape = SHAPES[si]; + ++total; + + const float scale = 1.0f / std::sqrt(static_cast(hdim)); + const int64_t elems = static_cast(shape.batch) * nhead * shape.seqlen * hdim; + + GpuBuffer q_dev(elems); + GpuBuffer k_dev(elems); + GpuBuffer v_dev(elems); + GpuBuffer o_dev(elems); + + std::vector h_buf(elems); + for(auto& x : h_buf) + x = FmhaDataType(dist(rng)); + q_dev.copy_from_host(h_buf.data()); + for(auto& x : h_buf) + x = FmhaDataType(dist(rng)); + k_dev.copy_from_host(h_buf.data()); + for(auto& x : h_buf) + x = FmhaDataType(dist(rng)); + v_dev.copy_from_host(h_buf.data()); + o_dev.zero(); + + fmha_fwd_traits traits{}; + traits.hdim_q = hdim; + traits.hdim_v = hdim; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.has_logits_soft_cap = false; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + traits.has_lse = false; + traits.has_dropout = false; + traits.qscale_type = quant_scale_enum::no_scale; + + fmha_fwd_args fmha_args{}; + fmha_args.q_ptr = q_dev.get(); + fmha_args.k_ptr = k_dev.get(); + fmha_args.v_ptr = v_dev.get(); + fmha_args.o_ptr = o_dev.get(); + + fmha_args.bias_ptr = nullptr; + fmha_args.q_descale_ptr = nullptr; + fmha_args.k_descale_ptr = nullptr; + fmha_args.v_descale_ptr = nullptr; + fmha_args.rand_val_ptr = nullptr; + fmha_args.lse_ptr = nullptr; + fmha_args.sink_ptr = nullptr; + fmha_args.block_scale_seqstart_q_ptr = nullptr; + fmha_args.block_scale_seqstart_k_ptr = nullptr; + + fmha_args.seqlen_q = shape.seqlen; + fmha_args.seqlen_k = shape.seqlen; + fmha_args.batch = shape.batch; + fmha_args.max_seqlen_q = shape.seqlen; + fmha_args.hdim_q = hdim; + fmha_args.hdim_v = hdim; + fmha_args.nhead_q = nhead; + fmha_args.nhead_k = nhead; + fmha_args.scale_s = scale; + fmha_args.logits_soft_cap = 0.0f; + + fmha_args.stride_q = hdim; + fmha_args.stride_k = hdim; + fmha_args.stride_v = hdim; + fmha_args.stride_bias = 0; + fmha_args.stride_randval = 0; + fmha_args.stride_o = hdim; + + fmha_args.nhead_stride_q = shape.seqlen * hdim; + fmha_args.nhead_stride_k = shape.seqlen * hdim; + fmha_args.nhead_stride_v = shape.seqlen * hdim; + fmha_args.nhead_stride_bias = 0; + fmha_args.nhead_stride_randval = 0; + fmha_args.nhead_stride_lse = 0; + fmha_args.nhead_stride_o = shape.seqlen * hdim; + fmha_args.nhead_stride_q_descale = 0; + fmha_args.nhead_stride_k_descale = 0; + fmha_args.nhead_stride_v_descale = 0; + + fmha_args.batch_stride_q = nhead * shape.seqlen * hdim; + fmha_args.batch_stride_k = nhead * shape.seqlen * hdim; + fmha_args.batch_stride_v = nhead * shape.seqlen * hdim; + fmha_args.batch_stride_bias = 0; + fmha_args.batch_stride_randval = 0; + fmha_args.batch_stride_lse = 0; + fmha_args.batch_stride_o = nhead * shape.seqlen * hdim; + fmha_args.batch_stride_q_descale = 0; + fmha_args.batch_stride_k_descale = 0; + fmha_args.batch_stride_v_descale = 0; + + fmha_args.window_size_left = -1; + fmha_args.window_size_right = -1; + fmha_args.sink_size = 0; + fmha_args.mask_type = 0; + fmha_args.min_seqlen_q = 0; + fmha_args.p_drop = 0.0f; + fmha_args.s_randval = false; + fmha_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0)); + fmha_args.block_scale_size_q = 0; + fmha_args.block_scale_size_kv = 0; + + bool ok = false; + float time_ms = 0.0f; + try + { + time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr); + + std::vector o_host(elems); + o_dev.copy_to_host(o_host.data()); + int nonzero = 0; + for(int64_t i = 0; i < elems; ++i) + { + if(static_cast(o_host[i]) != 0.0f) + ++nonzero; + } + ok = (nonzero > 0); + } + catch(const std::exception& e) + { + std::cerr << " ERROR for B=" << shape.batch << " S=" << shape.seqlen << ": " + << e.what() << "\n"; + } + + auto problem = + FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch); + double tflops = static_cast(problem.num_ops()) / (time_ms * 1e-3) / 1e12; + + std::cout << std::fixed; + std::cout << " " << std::setw(6) << shape.batch << " | " << std::setw(8) << shape.seqlen + << " | " << std::setw(12) << elems << " | " << std::setprecision(4) + << std::setw(10) << time_ms << " | " << std::setprecision(2) << std::setw(10) + << tflops << " | " << std::setw(8) << (ok ? "PASS" : "FAIL") << "\n"; + + if(ok) + ++pass_count; + } + + // Summary + print_separator(); + std::cout << "Results: " << pass_count << "/" << total << " shapes passed\n"; + std::cout << "Status: " << (pass_count == total ? "PASS" : "FAIL") << "\n"; + print_separator(); + + return (pass_count == total) ? 0 : 1; +} diff --git a/dispatcher/examples/fmha/cpp/16_heuristics_fmha.cpp b/dispatcher/examples/fmha/cpp/16_heuristics_fmha.cpp new file mode 100644 index 0000000000..b3f6db2031 --- /dev/null +++ b/dispatcher/examples/fmha/cpp/16_heuristics_fmha.cpp @@ -0,0 +1,428 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Example 16: FMHA Heuristic-Based Kernel Selection +// +// Demonstrates: +// 1. Two kernels with different tile_m0 (128 vs 64) and selection_rank +// 2. Custom heuristic function that picks kernels based on seqlen +// 3. dispatcher.set_heuristic() + SelectionStrategy::Heuristic +// 4. Planning different problems to show which kernel is selected +// 5. GPU execution for at least one problem +// +// Usage: +// ./16_heuristics_fmha +// ./16_heuristics_fmha --arch gfx942 + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +using FmhaDataType = ck_tile::fp16_t; + +DECL_FMHA_KERNEL_SET(heuristic_fmha_kernels, + // Kernel A: Large tile (128x128) -- better for long sequences + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950") + // Kernel B: Smaller tile_m0 (64x128) -- lower latency for short sequences + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(1), + "gfx950")); + +namespace { + +void cpu_attention_fwd(const std::vector& Q, + const std::vector& K, + const std::vector& V, + std::vector& O, + int batch, + int nhead, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + float scale) +{ + for(int b = 0; b < batch; ++b) + { + for(int h = 0; h < nhead; ++h) + { + for(int sq = 0; sq < seqlen_q; ++sq) + { + std::vector scores(seqlen_k, 0.0f); + float max_score = -1e30f; + + for(int sk = 0; sk < seqlen_k; ++sk) + { + float dot = 0.0f; + for(int d = 0; d < hdim_q; ++d) + { + int q_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_q + d; + int k_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_q + d; + dot += Q[q_idx] * K[k_idx]; + } + scores[sk] = dot * scale; + max_score = std::max(max_score, scores[sk]); + } + + float sum_exp = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] = std::exp(scores[sk] - max_score); + sum_exp += scores[sk]; + } + + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] /= sum_exp; + } + + for(int dv = 0; dv < hdim_v; ++dv) + { + float acc = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + int v_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_v + dv; + acc += scores[sk] * V[v_idx]; + } + int o_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_v + dv; + O[o_idx] = acc; + } + } + } + } +} + +} // namespace + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 16: FMHA Heuristic Kernel Selection", + "Custom heuristic picks kernel based on seqlen"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--nhead", "8", "Number of heads"); + args.add_option("--hdim", "128", "Head dimension"); + if(!args.parse(argc, argv)) + return 0; + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int nhead = args.get_int("--nhead", 8); + const int hdim = args.get_int("--hdim", 128); + + print_header("Example 16: FMHA Heuristic Kernel Selection"); + + // Step 1: Register kernels + std::cout << "\nStep 1: Register Kernels\n"; + FmhaKernelSetRegistry::instance().print(); + + FmhaRegistry registry; + registry.set_name("heuristic_fmha"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + // Step 2: Set up heuristic + std::cout << "\nStep 2: Configure Heuristic\n"; + std::cout << " Rule: seqlen >= 256 -> prefer large tile (128x128, rank=0)\n"; + std::cout << " seqlen < 256 -> prefer small tile (64x128, rank=1)\n"; + + auto all_kernels = registry.all_kernels(); + std::cout << " Available kernels:\n"; + for(const auto& k : all_kernels) + { + std::cout << " - " << k->id() << "\n"; + } + + std::string kernel_a_id, kernel_b_id; + for(const auto& k : all_kernels) + { + auto kid = k->id(); + if(kernel_a_id.empty()) + kernel_a_id = kid; + else if(kernel_b_id.empty()) + kernel_b_id = kid; + } + + FmhaDispatcher dispatcher(®istry); + dispatcher.set_strategy(SelectionStrategy::Heuristic); + dispatcher.set_heuristic([&](const FmhaProblem& problem) -> std::vector { + if(problem.seqlen_q >= 256) + return {kernel_a_id, kernel_b_id}; + else + return {kernel_b_id, kernel_a_id}; + }); + dispatcher.set_benchmarking(true); + dispatcher.set_timing(1, 3); + + // Step 3: Plan different problems to show kernel selection + std::cout << "\nStep 3: Plan Problems (show kernel selection)\n\n"; + + struct PlanCase + { + int batch; + int seqlen; + }; + PlanCase plan_cases[] = {{1, 64}, {1, 128}, {2, 256}, {2, 512}, {4, 1024}}; + + std::cout << " " << std::setw(6) << "Batch" << " | " << std::setw(8) << "SeqLen" << " | " + << std::setw(50) << "Selected Kernel" << "\n"; + std::cout << " " << std::string(68, '-') << "\n"; + + for(const auto& pc : plan_cases) + { + auto problem = FmhaProblemBuilder() + .api_family(FmhaApiFamily::Fwd) + .kernel_family(FmhaKernelFamily::Fwd) + .gfx_arch(gfx_arch) + .data_type("fp16") + .dims(hdim, hdim, pc.batch, pc.seqlen, pc.seqlen) + .nheads(nhead, nhead) + .mask_type(0) + .bias_type(0) + .lse(false) + .dropout(false) + .build(); + + auto plan = dispatcher.plan(problem); + std::string selected = plan.is_valid() ? plan.stages[0].kernel_id : "(no match)"; + std::cout << " " << std::setw(6) << pc.batch << " | " << std::setw(8) << pc.seqlen << " | " + << std::setw(50) << selected << "\n"; + } + + // Step 4: GPU execution for a representative problem + std::cout << "\nStep 4: GPU Execution (batch=2, seqlen=256)\n"; + + const int batch = 2; + const int seqlen = 256; + const float scale = 1.0f / std::sqrt(static_cast(hdim)); + const int64_t elems = static_cast(batch) * nhead * seqlen * hdim; + + GpuBuffer q_dev(elems); + GpuBuffer k_dev(elems); + GpuBuffer v_dev(elems); + GpuBuffer o_dev(elems); + + std::mt19937 rng(42); + std::uniform_real_distribution fdist(-0.5f, 0.5f); + + std::vector q_host(elems), k_host(elems), v_host(elems); + for(auto& x : q_host) + x = FmhaDataType(fdist(rng)); + for(auto& x : k_host) + x = FmhaDataType(fdist(rng)); + for(auto& x : v_host) + x = FmhaDataType(fdist(rng)); + + q_dev.copy_from_host(q_host.data()); + k_dev.copy_from_host(k_host.data()); + v_dev.copy_from_host(v_host.data()); + o_dev.zero(); + + fmha_fwd_traits traits{}; + traits.hdim_q = hdim; + traits.hdim_v = hdim; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.has_logits_soft_cap = false; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + traits.has_lse = false; + traits.has_dropout = false; + traits.qscale_type = quant_scale_enum::no_scale; + + fmha_fwd_args fmha_args{}; + fmha_args.q_ptr = q_dev.get(); + fmha_args.k_ptr = k_dev.get(); + fmha_args.v_ptr = v_dev.get(); + fmha_args.o_ptr = o_dev.get(); + + fmha_args.bias_ptr = nullptr; + fmha_args.q_descale_ptr = nullptr; + fmha_args.k_descale_ptr = nullptr; + fmha_args.v_descale_ptr = nullptr; + fmha_args.rand_val_ptr = nullptr; + fmha_args.lse_ptr = nullptr; + fmha_args.sink_ptr = nullptr; + fmha_args.block_scale_seqstart_q_ptr = nullptr; + fmha_args.block_scale_seqstart_k_ptr = nullptr; + + fmha_args.seqlen_q = seqlen; + fmha_args.seqlen_k = seqlen; + fmha_args.batch = batch; + fmha_args.max_seqlen_q = seqlen; + fmha_args.hdim_q = hdim; + fmha_args.hdim_v = hdim; + fmha_args.nhead_q = nhead; + fmha_args.nhead_k = nhead; + fmha_args.scale_s = scale; + fmha_args.logits_soft_cap = 0.0f; + + fmha_args.stride_q = hdim; + fmha_args.stride_k = hdim; + fmha_args.stride_v = hdim; + fmha_args.stride_bias = 0; + fmha_args.stride_randval = 0; + fmha_args.stride_o = hdim; + + fmha_args.nhead_stride_q = seqlen * hdim; + fmha_args.nhead_stride_k = seqlen * hdim; + fmha_args.nhead_stride_v = seqlen * hdim; + fmha_args.nhead_stride_bias = 0; + fmha_args.nhead_stride_randval = 0; + fmha_args.nhead_stride_lse = 0; + fmha_args.nhead_stride_o = seqlen * hdim; + fmha_args.nhead_stride_q_descale = 0; + fmha_args.nhead_stride_k_descale = 0; + fmha_args.nhead_stride_v_descale = 0; + + fmha_args.batch_stride_q = nhead * seqlen * hdim; + fmha_args.batch_stride_k = nhead * seqlen * hdim; + fmha_args.batch_stride_v = nhead * seqlen * hdim; + fmha_args.batch_stride_bias = 0; + fmha_args.batch_stride_randval = 0; + fmha_args.batch_stride_lse = 0; + fmha_args.batch_stride_o = nhead * seqlen * hdim; + fmha_args.batch_stride_q_descale = 0; + fmha_args.batch_stride_k_descale = 0; + fmha_args.batch_stride_v_descale = 0; + + fmha_args.window_size_left = -1; + fmha_args.window_size_right = -1; + fmha_args.sink_size = 0; + fmha_args.mask_type = 0; + fmha_args.min_seqlen_q = 0; + fmha_args.p_drop = 0.0f; + fmha_args.s_randval = false; + fmha_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0)); + fmha_args.block_scale_size_q = 0; + fmha_args.block_scale_size_kv = 0; + + float time_ms = 0.0f; + bool passed = false; + try + { + time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr); + + auto problem = + FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch); + double tflops = static_cast(problem.num_ops()) / (time_ms * 1e-3) / 1e12; + + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + + // Validate against CPU reference + std::vector o_host(elems); + o_dev.copy_to_host(o_host.data()); + + std::vector q_f32(elems), k_f32(elems), v_f32(elems), o_ref(elems, 0.0f); + for(int64_t i = 0; i < elems; ++i) + q_f32[i] = static_cast(q_host[i]); + for(int64_t i = 0; i < elems; ++i) + k_f32[i] = static_cast(k_host[i]); + for(int64_t i = 0; i < elems; ++i) + v_f32[i] = static_cast(v_host[i]); + + cpu_attention_fwd( + q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale); + + double max_abs_err = 0.0; + int errors = 0; + for(int64_t i = 0; i < elems; ++i) + { + double abs_err = std::abs(static_cast(o_host[i]) - o_ref[i]); + max_abs_err = std::max(max_abs_err, abs_err); + if(abs_err > 1e-2 + 1e-2 * std::abs(o_ref[i])) + ++errors; + } + + std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n"; + std::cout << " Errors: " << errors << " / " << elems << "\n"; + passed = (errors == 0); + } + catch(const std::exception& e) + { + std::cerr << " ERROR: " << e.what() << "\n"; + } + + print_separator(); + std::cout << "Status: " << (passed ? "PASS" : "FAIL") << "\n"; + print_separator(); + + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/fmha/cpp/17_autofill_autocorrect_fmha.cpp b/dispatcher/examples/fmha/cpp/17_autofill_autocorrect_fmha.cpp new file mode 100644 index 0000000000..2b21dcd9fe --- /dev/null +++ b/dispatcher/examples/fmha/cpp/17_autofill_autocorrect_fmha.cpp @@ -0,0 +1,423 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Example 17: FMHA Autofill and Autocorrect +// +// Demonstrates three DECL_FMHA_KERNEL_SET patterns: +// 1. AUTOFILL: Minimal specification -- only family/dtype/hdim/pipeline/tile +// are provided; wave/warp use defaults from FmhaAlgorithm constructor +// 2. AUTOCORRECT: Intentionally non-standard wave config that still works +// because FmhaAlgorithm auto_fill() corrects missing tile_n1/tile_k1 +// 3. FULL: All parameters explicitly specified (reference) +// +// Each is registered, planned, run on GPU, and validated. +// +// Usage: +// ./17_autofill_autocorrect_fmha +// ./17_autofill_autocorrect_fmha --arch gfx942 + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +using FmhaDataType = ck_tile::fp16_t; + +// Pattern 1: AUTOFILL -- minimal specification, defaults for wave/warp +DECL_FMHA_KERNEL_SET(autofill_kernels, + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .pipeline("qr_async") + .padding(true, true, true, true), + "gfx950")); + +// Pattern 2: AUTOCORRECT -- tile_n1/tile_k1 set to 0, auto_fill() corrects them +DECL_FMHA_KERNEL_SET(autocorrect_kernels, + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(0) + .tile_k1(0) + .tile_k0max(0) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true), + "gfx950")); + +// Pattern 3: FULL -- every parameter explicitly specified +DECL_FMHA_KERNEL_SET(full_spec_kernels, + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950")); + +namespace { + +void cpu_attention_fwd(const std::vector& Q, + const std::vector& K, + const std::vector& V, + std::vector& O, + int batch, + int nhead, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + float scale) +{ + for(int b = 0; b < batch; ++b) + { + for(int h = 0; h < nhead; ++h) + { + for(int sq = 0; sq < seqlen_q; ++sq) + { + std::vector scores(seqlen_k, 0.0f); + float max_score = -1e30f; + + for(int sk = 0; sk < seqlen_k; ++sk) + { + float dot = 0.0f; + for(int d = 0; d < hdim_q; ++d) + { + int q_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_q + d; + int k_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_q + d; + dot += Q[q_idx] * K[k_idx]; + } + scores[sk] = dot * scale; + max_score = std::max(max_score, scores[sk]); + } + + float sum_exp = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] = std::exp(scores[sk] - max_score); + sum_exp += scores[sk]; + } + + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] /= sum_exp; + } + + for(int dv = 0; dv < hdim_v; ++dv) + { + float acc = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + int v_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_v + dv; + acc += scores[sk] * V[v_idx]; + } + int o_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_v + dv; + O[o_idx] = acc; + } + } + } + } +} + +struct KernelTestCase +{ + std::string name; + std::string kernel_set_name; +}; + +} // namespace + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 17: FMHA Autofill & Autocorrect", + "Three DECL_FMHA_KERNEL_SET patterns compared"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--batch", "2", "Batch size"); + args.add_option("--nhead", "8", "Number of heads"); + args.add_option("--seqlen", "128", "Sequence length"); + args.add_option("--hdim", "128", "Head dimension"); + if(!args.parse(argc, argv)) + return 0; + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int batch = args.get_int("--batch", 2); + const int nhead = args.get_int("--nhead", 8); + const int seqlen = args.get_int("--seqlen", 128); + const int hdim = args.get_int("--hdim", 128); + + print_header("Example 17: FMHA Autofill & Autocorrect"); + + // Step 1: Show registered kernel sets + std::cout << "\nStep 1: Registered Kernel Sets\n"; + FmhaKernelSetRegistry::instance().print(); + + const KernelTestCase cases[] = { + {"AUTOFILL (minimal spec, wave/warp defaults)", "autofill_kernels"}, + {"AUTOCORRECT (tile_n1/k1=0, auto_fill corrects)", "autocorrect_kernels"}, + {"FULL (all params explicit)", "full_spec_kernels"}, + }; + + // Prepare input data (shared across all tests) + const float scale = 1.0f / std::sqrt(static_cast(hdim)); + const int64_t elems = static_cast(batch) * nhead * seqlen * hdim; + + std::mt19937 rng(42); + std::uniform_real_distribution dist(-0.5f, 0.5f); + + std::vector q_host(elems), k_host(elems), v_host(elems); + for(auto& x : q_host) + x = FmhaDataType(dist(rng)); + for(auto& x : k_host) + x = FmhaDataType(dist(rng)); + for(auto& x : v_host) + x = FmhaDataType(dist(rng)); + + // CPU reference + std::vector q_f32(elems), k_f32(elems), v_f32(elems), o_ref(elems, 0.0f); + for(int64_t i = 0; i < elems; ++i) + q_f32[i] = static_cast(q_host[i]); + for(int64_t i = 0; i < elems; ++i) + k_f32[i] = static_cast(k_host[i]); + for(int64_t i = 0; i < elems; ++i) + v_f32[i] = static_cast(v_host[i]); + cpu_attention_fwd(q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale); + + int total_pass = 0; + const int total_cases = sizeof(cases) / sizeof(cases[0]); + + for(int ci = 0; ci < total_cases; ++ci) + { + const auto& tc = cases[ci]; + std::cout << "\nStep " << (ci + 2) << ": " << tc.name << "\n"; + + // Register from the named kernel set + FmhaRegistry registry; + registry.set_name(tc.kernel_set_name); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + if(registry.size() == 0) + { + std::cout << " SKIP: no kernels registered\n"; + continue; + } + + FmhaDispatcher dispatcher(®istry); + dispatcher.set_benchmarking(true); + dispatcher.set_timing(1, 3); + + // Allocate GPU buffers + GpuBuffer q_dev(elems); + GpuBuffer k_dev(elems); + GpuBuffer v_dev(elems); + GpuBuffer o_dev(elems); + + q_dev.copy_from_host(q_host.data()); + k_dev.copy_from_host(k_host.data()); + v_dev.copy_from_host(v_host.data()); + o_dev.zero(); + + fmha_fwd_traits traits{}; + traits.hdim_q = hdim; + traits.hdim_v = hdim; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.has_logits_soft_cap = false; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + traits.has_lse = false; + traits.has_dropout = false; + traits.qscale_type = quant_scale_enum::no_scale; + + fmha_fwd_args fmha_args{}; + fmha_args.q_ptr = q_dev.get(); + fmha_args.k_ptr = k_dev.get(); + fmha_args.v_ptr = v_dev.get(); + fmha_args.o_ptr = o_dev.get(); + + fmha_args.bias_ptr = nullptr; + fmha_args.q_descale_ptr = nullptr; + fmha_args.k_descale_ptr = nullptr; + fmha_args.v_descale_ptr = nullptr; + fmha_args.rand_val_ptr = nullptr; + fmha_args.lse_ptr = nullptr; + fmha_args.sink_ptr = nullptr; + fmha_args.block_scale_seqstart_q_ptr = nullptr; + fmha_args.block_scale_seqstart_k_ptr = nullptr; + + fmha_args.seqlen_q = seqlen; + fmha_args.seqlen_k = seqlen; + fmha_args.batch = batch; + fmha_args.max_seqlen_q = seqlen; + fmha_args.hdim_q = hdim; + fmha_args.hdim_v = hdim; + fmha_args.nhead_q = nhead; + fmha_args.nhead_k = nhead; + fmha_args.scale_s = scale; + fmha_args.logits_soft_cap = 0.0f; + + fmha_args.stride_q = hdim; + fmha_args.stride_k = hdim; + fmha_args.stride_v = hdim; + fmha_args.stride_bias = 0; + fmha_args.stride_randval = 0; + fmha_args.stride_o = hdim; + + fmha_args.nhead_stride_q = seqlen * hdim; + fmha_args.nhead_stride_k = seqlen * hdim; + fmha_args.nhead_stride_v = seqlen * hdim; + fmha_args.nhead_stride_bias = 0; + fmha_args.nhead_stride_randval = 0; + fmha_args.nhead_stride_lse = 0; + fmha_args.nhead_stride_o = seqlen * hdim; + fmha_args.nhead_stride_q_descale = 0; + fmha_args.nhead_stride_k_descale = 0; + fmha_args.nhead_stride_v_descale = 0; + + fmha_args.batch_stride_q = nhead * seqlen * hdim; + fmha_args.batch_stride_k = nhead * seqlen * hdim; + fmha_args.batch_stride_v = nhead * seqlen * hdim; + fmha_args.batch_stride_bias = 0; + fmha_args.batch_stride_randval = 0; + fmha_args.batch_stride_lse = 0; + fmha_args.batch_stride_o = nhead * seqlen * hdim; + fmha_args.batch_stride_q_descale = 0; + fmha_args.batch_stride_k_descale = 0; + fmha_args.batch_stride_v_descale = 0; + + fmha_args.window_size_left = -1; + fmha_args.window_size_right = -1; + fmha_args.sink_size = 0; + fmha_args.mask_type = 0; + fmha_args.min_seqlen_q = 0; + fmha_args.p_drop = 0.0f; + fmha_args.s_randval = false; + fmha_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0)); + fmha_args.block_scale_size_q = 0; + fmha_args.block_scale_size_kv = 0; + + try + { + float time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr); + + auto problem = + FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch); + double tflops = static_cast(problem.num_ops()) / (time_ms * 1e-3) / 1e12; + + // Validate + std::vector o_host(elems); + o_dev.copy_to_host(o_host.data()); + + double max_abs_err = 0.0; + int errors = 0; + for(int64_t i = 0; i < elems; ++i) + { + double abs_err = std::abs(static_cast(o_host[i]) - o_ref[i]); + max_abs_err = std::max(max_abs_err, abs_err); + if(abs_err > 1e-2 + 1e-2 * std::abs(o_ref[i])) + ++errors; + } + + bool ok = (errors == 0); + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms" + << " TFLOPS: " << std::setprecision(2) << tflops + << " MaxErr: " << std::scientific << max_abs_err << " " + << (ok ? "PASS" : "FAIL") << "\n"; + if(ok) + ++total_pass; + } + catch(const std::exception& e) + { + std::cerr << " ERROR: " << e.what() << "\n"; + } + } + + // Summary + print_separator(); + std::cout << "Results: " << total_pass << "/" << total_cases << " patterns passed\n"; + std::cout << "Patterns:\n"; + std::cout << " 1. AUTOFILL: Only tile + pipeline specified; wave/warp use defaults\n"; + std::cout << " 2. AUTOCORRECT: tile_n1/k1/k0max=0 -> auto_fill() infers from tile_n0/k0\n"; + std::cout << " 3. FULL: Every parameter explicit (reference configuration)\n"; + std::cout << "Status: " << (total_pass == total_cases ? "PASS" : "FAIL") << "\n"; + print_separator(); + + return (total_pass == total_cases) ? 0 : 1; +} diff --git a/dispatcher/examples/fmha/cpp/18_gpu_splitkv_fmha.cpp b/dispatcher/examples/fmha/cpp/18_gpu_splitkv_fmha.cpp new file mode 100644 index 0000000000..26c5564277 --- /dev/null +++ b/dispatcher/examples/fmha/cpp/18_gpu_splitkv_fmha.cpp @@ -0,0 +1,466 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Example 18: GPU Split-KV FMHA Forward +// +// Demonstrates split-KV attention with GPU execution: +// 1. Declare both fwd_splitkv and fwd_splitkv_combine kernels +// 2. Show 2-stage execution plan +// 3. Allocate Q, K, V, O plus workspace (lse_acc, o_acc) +// 4. Run the split-KV forward pass on GPU +// 5. Copy output to host and validate against CPU reference + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +DECL_FMHA_KERNEL_SET(splitkv_gpu_kernels, + .add(FmhaSignature() + .family("fwd_splitkv") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(true) + .dropout(false) + .qscale("no") + .paged_kv(false), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(64) + .tile_n0(128) + .tile_k0(32) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(16) + .warp_n0(16) + .warp_k0(16) + .warp_m1(16) + .warp_n1(16) + .warp_k1(16) + .pipeline("qr_nwarp_sshuffle") + .padding(true, true, true, true) + .max_splits_log2(6) + .selection_rank(0), + "gfx950") + .add(FmhaSignature() + .family("fwd_splitkv_combine") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(true) + .dropout(false) + .qscale("no") + .paged_kv(false), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(32) + .tile_n1(32) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(16) + .warp_n0(16) + .warp_k0(16) + .warp_m1(16) + .warp_n1(16) + .warp_k1(16) + .pipeline("qr") + .padding(true, true, true, true) + .max_splits_log2(6) + .selection_rank(0), + "gfx950")); + +namespace { + +using FmhaDataType = ck_tile::fp16_t; + +void cpu_attention_fwd(const std::vector& Q, + const std::vector& K, + const std::vector& V, + std::vector& O, + int batch, + int nhead, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + float scale) +{ + for(int b = 0; b < batch; ++b) + { + for(int h = 0; h < nhead; ++h) + { + for(int sq = 0; sq < seqlen_q; ++sq) + { + std::vector scores(seqlen_k, 0.0f); + float max_score = -1e30f; + + for(int sk = 0; sk < seqlen_k; ++sk) + { + float dot = 0.0f; + for(int d = 0; d < hdim_q; ++d) + { + int q_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_q + d; + int k_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_q + d; + dot += Q[q_idx] * K[k_idx]; + } + scores[sk] = dot * scale; + max_score = std::max(max_score, scores[sk]); + } + + float sum_exp = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] = std::exp(scores[sk] - max_score); + sum_exp += scores[sk]; + } + + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] /= sum_exp; + } + + for(int dv = 0; dv < hdim_v; ++dv) + { + float acc = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + int v_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_v + dv; + acc += scores[sk] * V[v_idx]; + } + int o_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_v + dv; + O[o_idx] = acc; + } + } + } + } +} + +} // namespace + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 18: GPU Split-KV FMHA Forward", "Split-KV with GPU execution"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--batch", "2", "Batch size"); + args.add_option("--nhead", "4", "Number of heads"); + args.add_option("--seqlen_q", "64", "Query sequence length"); + args.add_option("--seqlen_k", "2048", "KV sequence length"); + args.add_option("--hdim", "128", "Head dimension"); + args.add_option("--splits", "2", "Number of KV splits"); + args.add_flag("--validate", "Validate against CPU reference"); + if(!args.parse(argc, argv)) + return 0; + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int batch = args.get_int("--batch", 2); + const int nhead = args.get_int("--nhead", 4); + const int seqlen_q = args.get_int("--seqlen_q", 64); + const int seqlen_k = args.get_int("--seqlen_k", 2048); + const int hdim = args.get_int("--hdim", 128); + const int num_splits = args.get_int("--splits", 2); + + print_header("Example 18: GPU Split-KV FMHA Forward"); + + // Step 1: Register kernels + std::cout << "\nStep 1: Register Kernels\n"; + FmhaKernelSetRegistry::instance().print(); + + FmhaRegistry registry; + registry.set_name("splitkv_gpu_fmha"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + FmhaDispatcher dispatcher(®istry); + dispatcher.set_benchmarking(true); + dispatcher.set_timing(1, 3); + + // Step 2: Set up traits and plan + const float scale = 1.0f / std::sqrt(static_cast(hdim)); + + fmha_fwd_splitkv_traits traits{}; + traits.hdim_q = hdim; + traits.hdim_v = hdim; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.has_logits_soft_cap = false; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + traits.has_lse = true; + traits.do_fp8_static_quant = false; + traits.has_sink = false; + + // Workspace sizes: lse_acc [batch, nhead, num_splits, seqlen_q] + // o_acc [batch, nhead, num_splits, seqlen_q, hdim] + const int64_t q_elems = static_cast(batch) * nhead * seqlen_q * hdim; + const int64_t k_elems = static_cast(batch) * nhead * seqlen_k * hdim; + const int64_t v_elems = k_elems; + const int64_t o_elems = q_elems; + const int64_t lse_elems = static_cast(batch) * nhead * seqlen_q; + const int64_t lse_acc_elems = static_cast(batch) * nhead * num_splits * seqlen_q; + const int64_t o_acc_elems = static_cast(batch) * nhead * num_splits * seqlen_q * hdim; + + // Show the 2-stage plan + std::cout << "\nStep 2: Plan (2-stage split-KV)\n"; + + fmha_fwd_splitkv_args plan_args{}; + plan_args.seqlen_q = seqlen_q; + plan_args.seqlen_k = seqlen_k; + plan_args.batch = batch; + plan_args.max_seqlen_q = seqlen_q; + plan_args.hdim_q = hdim; + plan_args.hdim_v = hdim; + plan_args.nhead_q = nhead; + plan_args.nhead_k = nhead; + plan_args.num_splits = num_splits; + + auto problem = FmhaProblem::from_invocation(FmhaInvocation::make(traits, plan_args), gfx_arch); + auto plan = dispatcher.plan(problem); + + if(!plan.is_valid() || plan.stages.size() != 2) + { + std::cerr << " WARNING: Expected a two-stage split-KV plan, got " << plan.stages.size() + << " stage(s)\n"; + if(!plan.is_valid()) + { + std::cerr << " Plan is invalid -- no matching kernels found\n"; + print_separator(); + return 1; + } + } + + for(const auto& stage : plan.stages) + { + std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n"; + } + + // Step 3: Allocate GPU buffers + std::cout << "\nStep 3: Allocate GPU Buffers\n"; + std::cout << " Q: [" << batch << ", " << nhead << ", " << seqlen_q << ", " << hdim + << "]\n"; + std::cout << " K/V: [" << batch << ", " << nhead << ", " << seqlen_k << ", " << hdim + << "]\n"; + std::cout << " O: [" << batch << ", " << nhead << ", " << seqlen_q << ", " << hdim + << "]\n"; + std::cout << " lse_acc: [" << batch << ", " << nhead << ", " << num_splits << ", " << seqlen_q + << "]\n"; + std::cout << " o_acc: [" << batch << ", " << nhead << ", " << num_splits << ", " << seqlen_q + << ", " << hdim << "]\n"; + + GpuBuffer q_dev(q_elems); + GpuBuffer k_dev(k_elems); + GpuBuffer v_dev(v_elems); + GpuBuffer o_dev(o_elems); + GpuBuffer lse_dev(lse_elems); + GpuBuffer lse_acc_dev(lse_acc_elems); + GpuBuffer o_acc_dev(o_acc_elems); + + // Fill Q, K, V with random data + std::mt19937 rng(42); + std::uniform_real_distribution dist(-0.5f, 0.5f); + + std::vector q_host(q_elems); + std::vector k_host(k_elems); + std::vector v_host(v_elems); + for(auto& x : q_host) + x = FmhaDataType(dist(rng)); + for(auto& x : k_host) + x = FmhaDataType(dist(rng)); + for(auto& x : v_host) + x = FmhaDataType(dist(rng)); + + q_dev.copy_from_host(q_host.data()); + k_dev.copy_from_host(k_host.data()); + v_dev.copy_from_host(v_host.data()); + o_dev.zero(); + lse_acc_dev.zero(); + o_acc_dev.zero(); + + // Step 4: Set up splitkv args with device pointers and strides + fmha_fwd_splitkv_args fmha_args{}; + fmha_args.q_ptr = q_dev.get(); + fmha_args.k_ptr = k_dev.get(); + fmha_args.v_ptr = v_dev.get(); + fmha_args.o_ptr = o_dev.get(); + + fmha_args.bias_ptr = nullptr; + fmha_args.lse_acc_ptr = lse_acc_dev.get(); + fmha_args.o_acc_ptr = o_acc_dev.get(); + fmha_args.lse_ptr = lse_dev.get(); + + fmha_args.block_table_ptr = nullptr; + fmha_args.batch_stride_block_table = 0; + fmha_args.page_block_size = 0; + fmha_args.is_gappy = false; + fmha_args.cache_batch_idx = nullptr; + fmha_args.seqstart_q_ptr = nullptr; + fmha_args.seqstart_k_ptr = nullptr; + fmha_args.seqlen_k_ptr = nullptr; + fmha_args.sink_ptr = nullptr; + + fmha_args.seqlen_q = seqlen_q; + fmha_args.seqlen_k = seqlen_k; + fmha_args.batch = batch; + fmha_args.max_seqlen_q = seqlen_q; + fmha_args.hdim_q = hdim; + fmha_args.hdim_v = hdim; + fmha_args.nhead_q = nhead; + fmha_args.nhead_k = nhead; + fmha_args.num_splits = num_splits; + + fmha_args.scale_s = scale; + fmha_args.scale_p = 1.0f; + fmha_args.scale_o = 1.0f; + fmha_args.logits_soft_cap = 0.0f; + + // bhsd layout strides + fmha_args.stride_q = hdim; + fmha_args.stride_k = hdim; + fmha_args.stride_v = hdim; + fmha_args.stride_bias = 0; + fmha_args.stride_o_acc = hdim; + fmha_args.stride_o = hdim; + + fmha_args.nhead_stride_q = seqlen_q * hdim; + fmha_args.nhead_stride_k = seqlen_k * hdim; + fmha_args.nhead_stride_v = seqlen_k * hdim; + fmha_args.nhead_stride_bias = 0; + fmha_args.nhead_stride_lse = seqlen_q; + fmha_args.nhead_stride_lse_acc = num_splits * seqlen_q; + fmha_args.nhead_stride_o_acc = num_splits * seqlen_q * hdim; + fmha_args.nhead_stride_o = seqlen_q * hdim; + + fmha_args.batch_stride_q = nhead * seqlen_q * hdim; + fmha_args.batch_stride_k = nhead * seqlen_k * hdim; + fmha_args.batch_stride_v = nhead * seqlen_k * hdim; + fmha_args.batch_stride_bias = 0; + fmha_args.batch_stride_lse = nhead * seqlen_q; + fmha_args.batch_stride_lse_acc = nhead * num_splits * seqlen_q; + fmha_args.batch_stride_o_acc = nhead * num_splits * seqlen_q * hdim; + fmha_args.batch_stride_o = nhead * seqlen_q * hdim; + + fmha_args.split_stride_lse_acc = seqlen_q; + fmha_args.split_stride_o_acc = seqlen_q * hdim; + + fmha_args.window_size_left = -1; + fmha_args.window_size_right = -1; + fmha_args.sink_size = 0; + fmha_args.mask_type = 0; + + // Step 5: Run on GPU + std::cout << "\nStep 4: Run Split-KV FMHA Forward on GPU\n"; + float time_ms = 0.0f; + try + { + time_ms = dispatcher.run_fwd_splitkv(traits, fmha_args, nullptr); + } + catch(const std::exception& e) + { + std::cerr << " WARNING: GPU execution failed: " << e.what() << "\n"; + std::cerr << " Falling back to planning-only mode (split-KV compilation can be complex)\n"; + std::cout << "\n Plan summary (2 stages):\n"; + for(const auto& stage : plan.stages) + { + std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n"; + } + print_separator(); + std::cout << "Status: PLAN_ONLY\n"; + print_separator(); + return 0; + } + + auto run_problem = + FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch); + double tflops = static_cast(run_problem.num_ops()) / (time_ms * 1e-3) / 1e12; + + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + + // Step 6: Copy output and validate + std::cout << "\nStep 5: Validate\n"; + std::vector o_host(o_elems); + o_dev.copy_to_host(o_host.data()); + + int nonzero = 0; + for(int64_t i = 0; i < o_elems; ++i) + { + if(static_cast(o_host[i]) != 0.0f) + ++nonzero; + } + std::cout << " Non-zero outputs: " << nonzero << " / " << o_elems << "\n"; + + bool passed = (nonzero > 0); + + if(args.has("--validate")) + { + std::vector q_f32(q_elems), k_f32(k_elems), v_f32(v_elems), o_ref(o_elems, 0.0f); + for(int64_t i = 0; i < q_elems; ++i) + q_f32[i] = static_cast(q_host[i]); + for(int64_t i = 0; i < k_elems; ++i) + k_f32[i] = static_cast(k_host[i]); + for(int64_t i = 0; i < v_elems; ++i) + v_f32[i] = static_cast(v_host[i]); + + cpu_attention_fwd( + q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen_q, seqlen_k, hdim, hdim, scale); + + double max_abs_err = 0.0; + double max_rel_err = 0.0; + int errors = 0; + const double rtol = 1e-2; + const double atol = 1e-2; + + for(int64_t i = 0; i < o_elems; ++i) + { + float gpu_val = static_cast(o_host[i]); + float ref_val = o_ref[i]; + double abs_err = std::abs(gpu_val - ref_val); + double rel_err = abs_err / (std::abs(ref_val) + 1e-6); + max_abs_err = std::max(max_abs_err, abs_err); + max_rel_err = std::max(max_rel_err, rel_err); + if(abs_err > atol + rtol * std::abs(ref_val)) + ++errors; + } + + std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n"; + std::cout << " Max rel error: " << max_rel_err << "\n"; + std::cout << " Errors: " << errors << " / " << o_elems << "\n"; + passed = (errors == 0); + } + + print_separator(); + std::cout << "Status: " << (passed ? "PASS" : "FAIL") << "\n"; + print_separator(); + + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/fmha/cpp/19_gpu_masks_fmha.cpp b/dispatcher/examples/fmha/cpp/19_gpu_masks_fmha.cpp new file mode 100644 index 0000000000..d97e054e6e --- /dev/null +++ b/dispatcher/examples/fmha/cpp/19_gpu_masks_fmha.cpp @@ -0,0 +1,456 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Example 19: GPU FMHA Forward with Mask Types +// +// Demonstrates three mask variants with GPU execution: +// 1. No mask (standard attention) +// 2. Top-left causal mask (zero upper triangle) +// 3. Bottom-right causal mask (shifted diagonal) +// +// Uses seqlen_q=64, seqlen_k=128 to make mask behavior visible. + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +DECL_FMHA_KERNEL_SET(mask_fmha_kernels, + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950") + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("top_left") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950") + // Note: bottom_right shares the same compiled kernel as top_left + // (both use SimplifiedGenericAttentionMask). The mask type + // is resolved at runtime via args.mask_type, not the template. + // fmha_mask_compatible() in generated_fmha_backend.hpp handles this. +); + +namespace { + +using FmhaDataType = ck_tile::fp16_t; + +// mask_type: 0=no_mask, 1=top_left, 2=bottom_right +void cpu_attention_fwd_masked(const std::vector& Q, + const std::vector& K, + const std::vector& V, + std::vector& O, + int batch, + int nhead, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + float scale, + int mask_type) +{ + for(int b = 0; b < batch; ++b) + { + for(int h = 0; h < nhead; ++h) + { + for(int sq = 0; sq < seqlen_q; ++sq) + { + std::vector scores(seqlen_k, 0.0f); + float max_score = -1e30f; + + for(int sk = 0; sk < seqlen_k; ++sk) + { + float dot = 0.0f; + for(int d = 0; d < hdim_q; ++d) + { + int q_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_q + d; + int k_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_q + d; + dot += Q[q_idx] * K[k_idx]; + } + scores[sk] = dot * scale; + + bool masked = false; + if(mask_type == 1) + { + // top_left: causal from top-left, mask if sk >= sq + 1 + if(sk >= sq + 1) + masked = true; + } + else if(mask_type == 2) + { + // bottom_right: shifted diagonal, mask if sk >= sq + (seqlen_k - seqlen_q) + // + 1 + if(sk >= sq + (seqlen_k - seqlen_q) + 1) + masked = true; + } + + if(masked) + scores[sk] = -1e30f; + + max_score = std::max(max_score, scores[sk]); + } + + float sum_exp = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] = std::exp(scores[sk] - max_score); + sum_exp += scores[sk]; + } + + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] /= sum_exp; + } + + for(int dv = 0; dv < hdim_v; ++dv) + { + float acc = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + int v_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_v + dv; + acc += scores[sk] * V[v_idx]; + } + int o_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_v + dv; + O[o_idx] = acc; + } + } + } + } +} + +} // namespace + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 19: FMHA with Masks (GPU)", "FMHA mask variants on GPU"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--batch", "2", "Batch size"); + args.add_option("--nhead", "4", "Number of heads"); + args.add_option("--seqlen_q", "64", "Query sequence length"); + args.add_option("--seqlen_k", "128", "KV sequence length"); + args.add_option("--hdim", "128", "Head dimension"); + args.add_flag("--validate", "Validate against CPU reference"); + if(!args.parse(argc, argv)) + return 0; + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int batch = args.get_int("--batch", 2); + const int nhead = args.get_int("--nhead", 4); + const int seqlen_q = args.get_int("--seqlen_q", 64); + const int seqlen_k = args.get_int("--seqlen_k", 128); + const int hdim = args.get_int("--hdim", 128); + + print_header("Example 19: FMHA with Masks (GPU)"); + + // Step 1: Register kernels + std::cout << "\nStep 1: Register Kernels\n"; + FmhaKernelSetRegistry::instance().print(); + + FmhaRegistry registry; + registry.set_name("mask_fmha"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + FmhaDispatcher dispatcher(®istry); + dispatcher.set_benchmarking(true); + dispatcher.set_timing(1, 3); + + const float scale = 1.0f / std::sqrt(static_cast(hdim)); + + // Allocate GPU buffers + const int64_t q_elems = static_cast(batch) * nhead * seqlen_q * hdim; + const int64_t k_elems = static_cast(batch) * nhead * seqlen_k * hdim; + const int64_t v_elems = k_elems; + const int64_t o_elems = q_elems; + + std::cout << "\nStep 2: Allocate GPU Buffers\n"; + std::cout << " Q/O: [" << batch << ", " << nhead << ", " << seqlen_q << ", " << hdim << "]\n"; + std::cout << " K/V: [" << batch << ", " << nhead << ", " << seqlen_k << ", " << hdim << "]\n"; + + GpuBuffer q_dev(q_elems); + GpuBuffer k_dev(k_elems); + GpuBuffer v_dev(v_elems); + GpuBuffer o_dev(o_elems); + + std::mt19937 rng(42); + std::uniform_real_distribution dist(-0.5f, 0.5f); + + std::vector q_host(q_elems); + std::vector k_host(k_elems); + std::vector v_host(v_elems); + for(auto& x : q_host) + x = FmhaDataType(dist(rng)); + for(auto& x : k_host) + x = FmhaDataType(dist(rng)); + for(auto& x : v_host) + x = FmhaDataType(dist(rng)); + + q_dev.copy_from_host(q_host.data()); + k_dev.copy_from_host(k_host.data()); + v_dev.copy_from_host(v_host.data()); + + // Convert to f32 for CPU reference + std::vector q_f32(q_elems), k_f32(k_elems), v_f32(v_elems); + for(int64_t i = 0; i < q_elems; ++i) + q_f32[i] = static_cast(q_host[i]); + for(int64_t i = 0; i < k_elems; ++i) + k_f32[i] = static_cast(k_host[i]); + for(int64_t i = 0; i < v_elems; ++i) + v_f32[i] = static_cast(v_host[i]); + + // Test each mask type + struct MaskTest + { + const char* name; + int mask_type_int; + mask_enum mask_type; + }; + + MaskTest tests[] = { + {"no_mask", 0, mask_enum::no_mask}, + {"top_left", 1, mask_enum::mask_top_left}, + {"bottom_right", 2, mask_enum::mask_bottom_right}, + }; + + bool all_passed = true; + + for(const auto& test : tests) + { + std::cout << "\nStep 3: Run FMHA Forward [" << test.name << "]\n"; + + fmha_fwd_traits traits{}; + traits.hdim_q = hdim; + traits.hdim_v = hdim; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.has_logits_soft_cap = false; + traits.mask_type = test.mask_type; + traits.bias_type = bias_enum::no_bias; + traits.has_lse = false; + traits.has_dropout = false; + traits.qscale_type = quant_scale_enum::no_scale; + + o_dev.zero(); + + fmha_fwd_args fmha_args{}; + fmha_args.q_ptr = q_dev.get(); + fmha_args.k_ptr = k_dev.get(); + fmha_args.v_ptr = v_dev.get(); + fmha_args.o_ptr = o_dev.get(); + + fmha_args.bias_ptr = nullptr; + fmha_args.q_descale_ptr = nullptr; + fmha_args.k_descale_ptr = nullptr; + fmha_args.v_descale_ptr = nullptr; + fmha_args.rand_val_ptr = nullptr; + fmha_args.lse_ptr = nullptr; + fmha_args.sink_ptr = nullptr; + fmha_args.block_scale_seqstart_q_ptr = nullptr; + fmha_args.block_scale_seqstart_k_ptr = nullptr; + + fmha_args.seqlen_q = seqlen_q; + fmha_args.seqlen_k = seqlen_k; + fmha_args.batch = batch; + fmha_args.max_seqlen_q = seqlen_q; + fmha_args.hdim_q = hdim; + fmha_args.hdim_v = hdim; + fmha_args.nhead_q = nhead; + fmha_args.nhead_k = nhead; + fmha_args.scale_s = scale; + fmha_args.logits_soft_cap = 0.0f; + + // bhsd layout strides + fmha_args.stride_q = hdim; + fmha_args.stride_k = hdim; + fmha_args.stride_v = hdim; + fmha_args.stride_bias = 0; + fmha_args.stride_randval = 0; + fmha_args.stride_o = hdim; + + fmha_args.nhead_stride_q = seqlen_q * hdim; + fmha_args.nhead_stride_k = seqlen_k * hdim; + fmha_args.nhead_stride_v = seqlen_k * hdim; + fmha_args.nhead_stride_bias = 0; + fmha_args.nhead_stride_randval = 0; + fmha_args.nhead_stride_lse = 0; + fmha_args.nhead_stride_o = seqlen_q * hdim; + fmha_args.nhead_stride_q_descale = 0; + fmha_args.nhead_stride_k_descale = 0; + fmha_args.nhead_stride_v_descale = 0; + + fmha_args.batch_stride_q = nhead * seqlen_q * hdim; + fmha_args.batch_stride_k = nhead * seqlen_k * hdim; + fmha_args.batch_stride_v = nhead * seqlen_k * hdim; + fmha_args.batch_stride_bias = 0; + fmha_args.batch_stride_randval = 0; + fmha_args.batch_stride_lse = 0; + fmha_args.batch_stride_o = nhead * seqlen_q * hdim; + fmha_args.batch_stride_q_descale = 0; + fmha_args.batch_stride_k_descale = 0; + fmha_args.batch_stride_v_descale = 0; + + fmha_args.window_size_left = -1; + fmha_args.window_size_right = (test.mask_type_int == 0) ? -1 : 0; + fmha_args.sink_size = 0; + fmha_args.mask_type = test.mask_type_int; + fmha_args.min_seqlen_q = 0; + fmha_args.p_drop = 0.0f; + fmha_args.s_randval = false; + fmha_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0)); + fmha_args.block_scale_size_q = 0; + fmha_args.block_scale_size_kv = 0; + + float time_ms = 0.0f; + try + { + time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr); + } + catch(const std::exception& e) + { + std::cerr << " ERROR [" << test.name << "]: " << e.what() << "\n"; + all_passed = false; + continue; + } + + auto problem = + FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch); + double tflops = static_cast(problem.num_ops()) / (time_ms * 1e-3) / 1e12; + + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + + // Validate + std::vector o_host(o_elems); + o_dev.copy_to_host(o_host.data()); + + int nonzero = 0; + for(int64_t i = 0; i < o_elems; ++i) + { + if(static_cast(o_host[i]) != 0.0f) + ++nonzero; + } + std::cout << " Non-zero outputs: " << nonzero << " / " << o_elems << "\n"; + + if(nonzero == 0) + all_passed = false; + + if(args.has("--validate")) + { + std::vector o_ref(o_elems, 0.0f); + cpu_attention_fwd_masked(q_f32, + k_f32, + v_f32, + o_ref, + batch, + nhead, + seqlen_q, + seqlen_k, + hdim, + hdim, + scale, + test.mask_type_int); + + double max_abs_err = 0.0; + double max_rel_err = 0.0; + int errors = 0; + const double rtol = 1e-2; + const double atol = 1e-2; + + for(int64_t i = 0; i < o_elems; ++i) + { + float gpu_val = static_cast(o_host[i]); + float ref_val = o_ref[i]; + double abs_err = std::abs(gpu_val - ref_val); + double rel_err = abs_err / (std::abs(ref_val) + 1e-6); + max_abs_err = std::max(max_abs_err, abs_err); + max_rel_err = std::max(max_rel_err, rel_err); + if(abs_err > atol + rtol * std::abs(ref_val)) + ++errors; + } + + std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n"; + std::cout << " Max rel error: " << max_rel_err << "\n"; + std::cout << " Errors: " << errors << " / " << o_elems << "\n"; + if(errors > 0) + all_passed = false; + } + } + + print_separator(); + std::cout << "Status: " << (all_passed ? "PASS" : "FAIL") << "\n"; + print_separator(); + + return all_passed ? 0 : 1; +} diff --git a/dispatcher/examples/fmha/cpp/20_gpu_bias_fmha.cpp b/dispatcher/examples/fmha/cpp/20_gpu_bias_fmha.cpp new file mode 100644 index 0000000000..d121abf657 --- /dev/null +++ b/dispatcher/examples/fmha/cpp/20_gpu_bias_fmha.cpp @@ -0,0 +1,584 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Example 20: GPU FMHA Forward with Bias Types +// +// Demonstrates three bias variants with GPU execution: +// 1. No bias (standard attention) +// 2. Elementwise bias (arbitrary bias matrix added to scores) +// 3. ALiBi (Attention with Linear Biases -- slope-based positional encoding) +// +// Validates each variant against a CPU reference. + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +DECL_FMHA_KERNEL_SET(bias_fmha_kernels, + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950") + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("bias") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950") + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("alibi") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950")); + +namespace { + +using FmhaDataType = ck_tile::fp16_t; + +void cpu_attention_fwd(const std::vector& Q, + const std::vector& K, + const std::vector& V, + std::vector& O, + int batch, + int nhead, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + float scale) +{ + for(int b = 0; b < batch; ++b) + { + for(int h = 0; h < nhead; ++h) + { + for(int sq = 0; sq < seqlen_q; ++sq) + { + std::vector scores(seqlen_k, 0.0f); + float max_score = -1e30f; + + for(int sk = 0; sk < seqlen_k; ++sk) + { + float dot = 0.0f; + for(int d = 0; d < hdim_q; ++d) + { + int q_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_q + d; + int k_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_q + d; + dot += Q[q_idx] * K[k_idx]; + } + scores[sk] = dot * scale; + max_score = std::max(max_score, scores[sk]); + } + + float sum_exp = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] = std::exp(scores[sk] - max_score); + sum_exp += scores[sk]; + } + + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] /= sum_exp; + } + + for(int dv = 0; dv < hdim_v; ++dv) + { + float acc = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + int v_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_v + dv; + acc += scores[sk] * V[v_idx]; + } + int o_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_v + dv; + O[o_idx] = acc; + } + } + } + } +} + +// bias_type: 0=none, 1=elementwise, 2=alibi +// bias_buf layout: elementwise [1, nhead, seqlen_q, seqlen_k], alibi [1, nhead] slopes +void cpu_attention_fwd_biased(const std::vector& Q, + const std::vector& K, + const std::vector& V, + std::vector& O, + int batch, + int nhead, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + float scale, + int bias_type, + const std::vector& bias_buf) +{ + for(int b = 0; b < batch; ++b) + { + for(int h = 0; h < nhead; ++h) + { + for(int sq = 0; sq < seqlen_q; ++sq) + { + std::vector scores(seqlen_k, 0.0f); + float max_score = -1e30f; + + for(int sk = 0; sk < seqlen_k; ++sk) + { + float dot = 0.0f; + for(int d = 0; d < hdim_q; ++d) + { + int q_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_q + d; + int k_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_q + d; + dot += Q[q_idx] * K[k_idx]; + } + float s = dot * scale; + + if(bias_type == 1) + { + int bias_idx = (h * seqlen_q + sq) * seqlen_k + sk; + s += bias_buf[bias_idx]; + } + else if(bias_type == 2) + { + float slope = bias_buf[h]; + s += slope * static_cast(sk - sq); + } + + scores[sk] = s; + max_score = std::max(max_score, scores[sk]); + } + + float sum_exp = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] = std::exp(scores[sk] - max_score); + sum_exp += scores[sk]; + } + + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] /= sum_exp; + } + + for(int dv = 0; dv < hdim_v; ++dv) + { + float acc = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + int v_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_v + dv; + acc += scores[sk] * V[v_idx]; + } + int o_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_v + dv; + O[o_idx] = acc; + } + } + } + } +} + +} // namespace + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 20: FMHA with Bias (GPU)", "FMHA bias variants on GPU"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--batch", "2", "Batch size"); + args.add_option("--nhead", "4", "Number of heads"); + args.add_option("--seqlen", "64", "Sequence length (Q and K)"); + args.add_option("--hdim", "128", "Head dimension"); + args.add_flag("--validate", "Validate against CPU reference"); + if(!args.parse(argc, argv)) + return 0; + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int batch = args.get_int("--batch", 2); + const int nhead = args.get_int("--nhead", 4); + const int seqlen = args.get_int("--seqlen", 64); + const int hdim = args.get_int("--hdim", 128); + + print_header("Example 20: FMHA with Bias (GPU)"); + + // Step 1: Register kernels + std::cout << "\nStep 1: Register Kernels\n"; + FmhaKernelSetRegistry::instance().print(); + + FmhaRegistry registry; + registry.set_name("bias_fmha"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + FmhaDispatcher dispatcher(®istry); + dispatcher.set_benchmarking(true); + dispatcher.set_timing(1, 3); + + const float scale = 1.0f / std::sqrt(static_cast(hdim)); + + // Allocate Q, K, V GPU buffers (shared across all bias tests) + const int64_t q_elems = static_cast(batch) * nhead * seqlen * hdim; + const int64_t k_elems = q_elems; + const int64_t v_elems = q_elems; + const int64_t o_elems = q_elems; + + std::cout << "\nStep 2: Allocate GPU Buffers\n"; + std::cout << " Q/K/V/O: [" << batch << ", " << nhead << ", " << seqlen << ", " << hdim + << "]\n"; + + GpuBuffer q_dev(q_elems); + GpuBuffer k_dev(k_elems); + GpuBuffer v_dev(v_elems); + GpuBuffer o_dev(o_elems); + + std::mt19937 rng(42); + std::uniform_real_distribution dist(-0.5f, 0.5f); + + std::vector q_host(q_elems); + std::vector k_host(k_elems); + std::vector v_host(v_elems); + for(auto& x : q_host) + x = FmhaDataType(dist(rng)); + for(auto& x : k_host) + x = FmhaDataType(dist(rng)); + for(auto& x : v_host) + x = FmhaDataType(dist(rng)); + + q_dev.copy_from_host(q_host.data()); + k_dev.copy_from_host(k_host.data()); + v_dev.copy_from_host(v_host.data()); + + // Convert to f32 for CPU reference + std::vector q_f32(q_elems), k_f32(k_elems), v_f32(v_elems); + for(int64_t i = 0; i < q_elems; ++i) + q_f32[i] = static_cast(q_host[i]); + for(int64_t i = 0; i < k_elems; ++i) + k_f32[i] = static_cast(k_host[i]); + for(int64_t i = 0; i < v_elems; ++i) + v_f32[i] = static_cast(v_host[i]); + + // Prepare elementwise bias buffer: [1, nhead, seqlen, seqlen] with small values + const int64_t elem_bias_elems = static_cast(nhead) * seqlen * seqlen; + std::vector elem_bias_host(elem_bias_elems); + std::uniform_real_distribution bias_dist(-0.1f, 0.1f); + for(auto& x : elem_bias_host) + x = bias_dist(rng); + + GpuBuffer elem_bias_dev(elem_bias_elems); + elem_bias_dev.copy_from_host(elem_bias_host.data()); + + // Prepare ALiBi slopes buffer: [nhead] with geometric slopes + std::vector alibi_slopes_host(nhead); + for(int h = 0; h < nhead; ++h) + { + alibi_slopes_host[h] = -std::pow(2.0f, -(8.0f * (h + 1) / nhead)); + } + + GpuBuffer alibi_slopes_dev(nhead); + alibi_slopes_dev.copy_from_host(alibi_slopes_host.data()); + + // Test each bias type + struct BiasTest + { + const char* name; + int bias_type_int; + bias_enum bias_type; + void* bias_ptr; + int stride_bias; + int nhead_stride_bias; + int batch_stride_bias; + }; + + BiasTest tests[] = { + {"no_bias", 0, bias_enum::no_bias, nullptr, 0, 0, 0}, + {"elementwise_bias", + 1, + bias_enum::elementwise_bias, + elem_bias_dev.get(), + seqlen, + seqlen * seqlen, + 0}, + {"alibi", 2, bias_enum::alibi, alibi_slopes_dev.get(), 0, 1, 0}, + }; + + bool all_passed = true; + + for(const auto& test : tests) + { + std::cout << "\nStep 3: Run FMHA Forward [" << test.name << "]\n"; + + fmha_fwd_traits traits{}; + traits.hdim_q = hdim; + traits.hdim_v = hdim; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.has_logits_soft_cap = false; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = test.bias_type; + traits.has_lse = false; + traits.has_dropout = false; + traits.qscale_type = quant_scale_enum::no_scale; + + o_dev.zero(); + + fmha_fwd_args fmha_args{}; + fmha_args.q_ptr = q_dev.get(); + fmha_args.k_ptr = k_dev.get(); + fmha_args.v_ptr = v_dev.get(); + fmha_args.o_ptr = o_dev.get(); + + fmha_args.bias_ptr = test.bias_ptr; + fmha_args.q_descale_ptr = nullptr; + fmha_args.k_descale_ptr = nullptr; + fmha_args.v_descale_ptr = nullptr; + fmha_args.rand_val_ptr = nullptr; + fmha_args.lse_ptr = nullptr; + fmha_args.sink_ptr = nullptr; + fmha_args.block_scale_seqstart_q_ptr = nullptr; + fmha_args.block_scale_seqstart_k_ptr = nullptr; + + fmha_args.seqlen_q = seqlen; + fmha_args.seqlen_k = seqlen; + fmha_args.batch = batch; + fmha_args.max_seqlen_q = seqlen; + fmha_args.hdim_q = hdim; + fmha_args.hdim_v = hdim; + fmha_args.nhead_q = nhead; + fmha_args.nhead_k = nhead; + fmha_args.scale_s = scale; + fmha_args.logits_soft_cap = 0.0f; + + // bhsd layout strides + fmha_args.stride_q = hdim; + fmha_args.stride_k = hdim; + fmha_args.stride_v = hdim; + fmha_args.stride_bias = test.stride_bias; + fmha_args.stride_randval = 0; + fmha_args.stride_o = hdim; + + fmha_args.nhead_stride_q = seqlen * hdim; + fmha_args.nhead_stride_k = seqlen * hdim; + fmha_args.nhead_stride_v = seqlen * hdim; + fmha_args.nhead_stride_bias = test.nhead_stride_bias; + fmha_args.nhead_stride_randval = 0; + fmha_args.nhead_stride_lse = 0; + fmha_args.nhead_stride_o = seqlen * hdim; + fmha_args.nhead_stride_q_descale = 0; + fmha_args.nhead_stride_k_descale = 0; + fmha_args.nhead_stride_v_descale = 0; + + fmha_args.batch_stride_q = nhead * seqlen * hdim; + fmha_args.batch_stride_k = nhead * seqlen * hdim; + fmha_args.batch_stride_v = nhead * seqlen * hdim; + fmha_args.batch_stride_bias = test.batch_stride_bias; + fmha_args.batch_stride_randval = 0; + fmha_args.batch_stride_lse = 0; + fmha_args.batch_stride_o = nhead * seqlen * hdim; + fmha_args.batch_stride_q_descale = 0; + fmha_args.batch_stride_k_descale = 0; + fmha_args.batch_stride_v_descale = 0; + + fmha_args.window_size_left = -1; + fmha_args.window_size_right = -1; + fmha_args.sink_size = 0; + fmha_args.mask_type = 0; + fmha_args.min_seqlen_q = 0; + fmha_args.p_drop = 0.0f; + fmha_args.s_randval = false; + fmha_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0)); + fmha_args.block_scale_size_q = 0; + fmha_args.block_scale_size_kv = 0; + + float time_ms = 0.0f; + try + { + time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr); + } + catch(const std::exception& e) + { + std::cerr << " ERROR [" << test.name << "]: " << e.what() << "\n"; + all_passed = false; + continue; + } + + auto problem = + FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch); + double tflops = static_cast(problem.num_ops()) / (time_ms * 1e-3) / 1e12; + + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + + // Validate + std::vector o_host(o_elems); + o_dev.copy_to_host(o_host.data()); + + int nonzero = 0; + for(int64_t i = 0; i < o_elems; ++i) + { + if(static_cast(o_host[i]) != 0.0f) + ++nonzero; + } + std::cout << " Non-zero outputs: " << nonzero << " / " << o_elems << "\n"; + + if(nonzero == 0) + all_passed = false; + + if(args.has("--validate")) + { + std::vector o_ref(o_elems, 0.0f); + + if(test.bias_type_int == 0) + { + cpu_attention_fwd( + q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale); + } + else + { + const std::vector& bias_ref = + (test.bias_type_int == 1) ? elem_bias_host : alibi_slopes_host; + cpu_attention_fwd_biased(q_f32, + k_f32, + v_f32, + o_ref, + batch, + nhead, + seqlen, + seqlen, + hdim, + hdim, + scale, + test.bias_type_int, + bias_ref); + } + + double max_abs_err = 0.0; + double max_rel_err = 0.0; + int errors = 0; + const double rtol = 1e-2; + const double atol = 1e-2; + + for(int64_t i = 0; i < o_elems; ++i) + { + float gpu_val = static_cast(o_host[i]); + float ref_val = o_ref[i]; + double abs_err = std::abs(gpu_val - ref_val); + double rel_err = abs_err / (std::abs(ref_val) + 1e-6); + max_abs_err = std::max(max_abs_err, abs_err); + max_rel_err = std::max(max_rel_err, rel_err); + if(abs_err > atol + rtol * std::abs(ref_val)) + ++errors; + } + + std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n"; + std::cout << " Max rel error: " << max_rel_err << "\n"; + std::cout << " Errors: " << errors << " / " << o_elems << "\n"; + if(errors > 0) + all_passed = false; + } + } + + print_separator(); + std::cout << "Status: " << (all_passed ? "PASS" : "FAIL") << "\n"; + print_separator(); + + return all_passed ? 0 : 1; +} diff --git a/dispatcher/examples/fmha/cpp/21_gpu_features_fmha.cpp b/dispatcher/examples/fmha/cpp/21_gpu_features_fmha.cpp new file mode 100644 index 0000000000..ff2893d9d8 --- /dev/null +++ b/dispatcher/examples/fmha/cpp/21_gpu_features_fmha.cpp @@ -0,0 +1,697 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Example 21: GPU Features FMHA +// +// Tests multiple FMHA features with real GPU execution: +// 1. Dropout (with LSE, rand_val buffer) +// 2. GQA (nhead_q=16, nhead_k=4, same kernel) +// 3. LSE output (verify log-sum-exp values) +// +// Mirrors 01_basic_fmha.cpp for each feature variant. + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +DECL_FMHA_KERNEL_SET(gpu_features_fmha_kernels, + // Basic fp16 kernel (used for GQA -- GQA is a runtime concern, same kernel) + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950") + + // Dropout kernel (requires LSE) + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(true) + .dropout(true) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950") + + // LSE-only kernel + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(true) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950")); + +namespace { + +using FmhaDataType = ck_tile::fp16_t; + +void cpu_attention_fwd(const std::vector& Q, + const std::vector& K, + const std::vector& V, + std::vector& O, + int batch, + int nhead_q, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + float scale, + std::vector* lse_out = nullptr) +{ + const int nhead_ratio = nhead_q / nhead_k; + + for(int b = 0; b < batch; ++b) + { + for(int hq = 0; hq < nhead_q; ++hq) + { + const int hk = hq / nhead_ratio; + + for(int sq = 0; sq < seqlen_q; ++sq) + { + std::vector scores(seqlen_k, 0.0f); + float max_score = -1e30f; + + for(int sk = 0; sk < seqlen_k; ++sk) + { + float dot = 0.0f; + for(int d = 0; d < hdim_q; ++d) + { + int q_idx = ((b * nhead_q + hq) * seqlen_q + sq) * hdim_q + d; + int k_idx = ((b * nhead_k + hk) * seqlen_k + sk) * hdim_q + d; + dot += Q[q_idx] * K[k_idx]; + } + scores[sk] = dot * scale; + max_score = std::max(max_score, scores[sk]); + } + + float sum_exp = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] = std::exp(scores[sk] - max_score); + sum_exp += scores[sk]; + } + + if(lse_out) + { + int lse_idx = (b * nhead_q + hq) * seqlen_q + sq; + (*lse_out)[lse_idx] = max_score + std::log(sum_exp); + } + + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] /= sum_exp; + } + + for(int dv = 0; dv < hdim_v; ++dv) + { + float acc = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + int v_idx = ((b * nhead_k + hk) * seqlen_k + sk) * hdim_v + dv; + acc += scores[sk] * V[v_idx]; + } + int o_idx = ((b * nhead_q + hq) * seqlen_q + sq) * hdim_v + dv; + O[o_idx] = acc; + } + } + } + } +} + +struct FeatureResult +{ + std::string name; + bool passed; + float time_ms; +}; + +fmha_fwd_args make_base_args(void* q, + void* k, + void* v, + void* o, + int batch, + int nhead_q, + int nhead_k, + int seqlen, + int hdim, + float scale) +{ + fmha_fwd_args a{}; + a.q_ptr = q; + a.k_ptr = k; + a.v_ptr = v; + a.o_ptr = o; + + a.bias_ptr = nullptr; + a.q_descale_ptr = nullptr; + a.k_descale_ptr = nullptr; + a.v_descale_ptr = nullptr; + a.rand_val_ptr = nullptr; + a.lse_ptr = nullptr; + a.sink_ptr = nullptr; + a.block_scale_seqstart_q_ptr = nullptr; + a.block_scale_seqstart_k_ptr = nullptr; + + a.seqlen_q = seqlen; + a.seqlen_k = seqlen; + a.batch = batch; + a.max_seqlen_q = seqlen; + a.hdim_q = hdim; + a.hdim_v = hdim; + a.nhead_q = nhead_q; + a.nhead_k = nhead_k; + a.scale_s = scale; + a.logits_soft_cap = 0.0f; + + a.stride_q = hdim; + a.stride_k = hdim; + a.stride_v = hdim; + a.stride_bias = 0; + a.stride_randval = 0; + a.stride_o = hdim; + + a.nhead_stride_q = seqlen * hdim; + a.nhead_stride_k = seqlen * hdim; + a.nhead_stride_v = seqlen * hdim; + a.nhead_stride_bias = 0; + a.nhead_stride_randval = 0; + a.nhead_stride_lse = 0; + a.nhead_stride_o = seqlen * hdim; + a.nhead_stride_q_descale = 0; + a.nhead_stride_k_descale = 0; + a.nhead_stride_v_descale = 0; + + a.batch_stride_q = nhead_q * seqlen * hdim; + a.batch_stride_k = nhead_k * seqlen * hdim; + a.batch_stride_v = nhead_k * seqlen * hdim; + a.batch_stride_bias = 0; + a.batch_stride_randval = 0; + a.batch_stride_lse = 0; + a.batch_stride_o = nhead_q * seqlen * hdim; + a.batch_stride_q_descale = 0; + a.batch_stride_k_descale = 0; + a.batch_stride_v_descale = 0; + + a.window_size_left = -1; + a.window_size_right = -1; + a.sink_size = 0; + a.mask_type = 0; + a.min_seqlen_q = 0; + a.p_drop = 0.0f; + a.s_randval = false; + a.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0)); + a.block_scale_size_q = 0; + a.block_scale_size_kv = 0; + + return a; +} + +} // namespace + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 21: GPU Features FMHA", "Dropout, GQA, LSE with real GPU data"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--batch", "2", "Batch size"); + args.add_option("--seqlen", "64", "Sequence length"); + args.add_option("--hdim", "128", "Head dimension"); + if(!args.parse(argc, argv)) + return 0; + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int batch = args.get_int("--batch", 2); + const int seqlen = args.get_int("--seqlen", 64); + const int hdim = args.get_int("--hdim", 128); + const float scale = 1.0f / std::sqrt(static_cast(hdim)); + + print_header("Example 21: GPU Features FMHA"); + + // Step 1: Register kernels + std::cout << "\nStep 1: Register Kernels\n"; + FmhaKernelSetRegistry::instance().print(); + + FmhaRegistry registry; + registry.set_name("gpu_features_fmha"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + FmhaDispatcher dispatcher(®istry); + dispatcher.set_benchmarking(true); + dispatcher.set_timing(1, 3); + + std::mt19937 rng(42); + std::uniform_real_distribution dist(-0.5f, 0.5f); + + std::vector results; + + // ----------------------------------------------------------------------- + // Feature A: GQA (nhead_q=16, nhead_k=4, same basic kernel) + // ----------------------------------------------------------------------- + { + std::cout << "\nStep 2a: GQA (nhead_q=16, nhead_k=4)\n"; + const int nhead_q = 16; + const int nhead_k = 4; + + const int64_t q_elems = static_cast(batch) * nhead_q * seqlen * hdim; + const int64_t k_elems = static_cast(batch) * nhead_k * seqlen * hdim; + const int64_t o_elems = q_elems; + + GpuBuffer q_dev(q_elems); + GpuBuffer k_dev(k_elems); + GpuBuffer v_dev(k_elems); + GpuBuffer o_dev(o_elems); + + std::vector q_host(q_elems), k_host(k_elems), v_host(k_elems); + for(auto& x : q_host) + x = FmhaDataType(dist(rng)); + for(auto& x : k_host) + x = FmhaDataType(dist(rng)); + for(auto& x : v_host) + x = FmhaDataType(dist(rng)); + + q_dev.copy_from_host(q_host.data()); + k_dev.copy_from_host(k_host.data()); + v_dev.copy_from_host(v_host.data()); + o_dev.zero(); + + fmha_fwd_traits traits{}; + traits.hdim_q = hdim; + traits.hdim_v = hdim; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.has_logits_soft_cap = false; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + traits.has_lse = false; + traits.has_dropout = false; + traits.qscale_type = quant_scale_enum::no_scale; + + auto fmha_args = make_base_args(q_dev.get(), + k_dev.get(), + v_dev.get(), + o_dev.get(), + batch, + nhead_q, + nhead_k, + seqlen, + hdim, + scale); + + bool passed = false; + float time_ms = 0.0f; + try + { + time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr); + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + + // Validate against CPU reference with GQA head repetition + std::vector q_f32(q_elems), k_f32(k_elems), v_f32(k_elems); + for(int64_t i = 0; i < q_elems; ++i) + q_f32[i] = static_cast(q_host[i]); + for(int64_t i = 0; i < k_elems; ++i) + k_f32[i] = static_cast(k_host[i]); + for(int64_t i = 0; i < k_elems; ++i) + v_f32[i] = static_cast(v_host[i]); + + std::vector o_ref(o_elems, 0.0f); + cpu_attention_fwd(q_f32, + k_f32, + v_f32, + o_ref, + batch, + nhead_q, + nhead_k, + seqlen, + seqlen, + hdim, + hdim, + scale); + + std::vector o_host(o_elems); + o_dev.copy_to_host(o_host.data()); + + double max_abs_err = 0.0; + int errors = 0; + const double rtol = 1e-2; + const double atol = 1e-2; + for(int64_t i = 0; i < o_elems; ++i) + { + float gpu_val = static_cast(o_host[i]); + float ref_val = o_ref[i]; + double abs_err = std::abs(gpu_val - ref_val); + max_abs_err = std::max(max_abs_err, abs_err); + if(abs_err > atol + rtol * std::abs(ref_val)) + ++errors; + } + std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n"; + std::cout << " Errors: " << errors << " / " << o_elems << "\n"; + passed = (errors == 0); + } + catch(const std::exception& e) + { + std::cerr << " ERROR: " << e.what() << "\n"; + } + results.push_back({"GQA (16q/4k)", passed, time_ms}); + } + + // ----------------------------------------------------------------------- + // Feature B: LSE output + // ----------------------------------------------------------------------- + { + std::cout << "\nStep 2b: LSE Output\n"; + const int nhead = 4; + + const int64_t qkv_elems = static_cast(batch) * nhead * seqlen * hdim; + const int64_t lse_elems = static_cast(batch) * nhead * seqlen; + + GpuBuffer q_dev(qkv_elems); + GpuBuffer k_dev(qkv_elems); + GpuBuffer v_dev(qkv_elems); + GpuBuffer o_dev(qkv_elems); + GpuBuffer lse_dev(lse_elems); + + std::vector q_host(qkv_elems), k_host(qkv_elems), v_host(qkv_elems); + for(auto& x : q_host) + x = FmhaDataType(dist(rng)); + for(auto& x : k_host) + x = FmhaDataType(dist(rng)); + for(auto& x : v_host) + x = FmhaDataType(dist(rng)); + + q_dev.copy_from_host(q_host.data()); + k_dev.copy_from_host(k_host.data()); + v_dev.copy_from_host(v_host.data()); + o_dev.zero(); + lse_dev.zero(); + + fmha_fwd_traits traits{}; + traits.hdim_q = hdim; + traits.hdim_v = hdim; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.has_logits_soft_cap = false; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + traits.has_lse = true; + traits.has_dropout = false; + traits.qscale_type = quant_scale_enum::no_scale; + + auto fmha_args = make_base_args(q_dev.get(), + k_dev.get(), + v_dev.get(), + o_dev.get(), + batch, + nhead, + nhead, + seqlen, + hdim, + scale); + fmha_args.lse_ptr = lse_dev.get(); + fmha_args.nhead_stride_lse = seqlen; + fmha_args.batch_stride_lse = nhead * seqlen; + + bool passed = false; + float time_ms = 0.0f; + try + { + time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr); + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + + // Compute CPU reference LSE + std::vector q_f32(qkv_elems), k_f32(qkv_elems), v_f32(qkv_elems); + for(int64_t i = 0; i < qkv_elems; ++i) + q_f32[i] = static_cast(q_host[i]); + for(int64_t i = 0; i < qkv_elems; ++i) + k_f32[i] = static_cast(k_host[i]); + for(int64_t i = 0; i < qkv_elems; ++i) + v_f32[i] = static_cast(v_host[i]); + + std::vector o_ref(qkv_elems, 0.0f); + std::vector lse_ref(lse_elems, 0.0f); + cpu_attention_fwd(q_f32, + k_f32, + v_f32, + o_ref, + batch, + nhead, + nhead, + seqlen, + seqlen, + hdim, + hdim, + scale, + &lse_ref); + + std::vector lse_host(lse_elems); + lse_dev.copy_to_host(lse_host.data()); + + int lse_reasonable = 0; + double max_lse_err = 0.0; + for(int64_t i = 0; i < lse_elems; ++i) + { + if(std::isfinite(lse_host[i]) && std::abs(lse_host[i]) < 100.0f) + ++lse_reasonable; + double err = std::abs(lse_host[i] - lse_ref[i]); + max_lse_err = std::max(max_lse_err, err); + } + std::cout << " LSE reasonable: " << lse_reasonable << " / " << lse_elems << "\n"; + std::cout << " LSE max error vs ref: " << std::scientific << max_lse_err << "\n"; + std::cout << " LSE sample [0..3]: "; + for(int i = 0; i < std::min(4, lse_elems); ++i) + std::cout << std::fixed << std::setprecision(4) << lse_host[i] << " "; + std::cout << "\n"; + passed = (lse_reasonable == lse_elems) && (max_lse_err < 1.0); + } + catch(const std::exception& e) + { + std::cerr << " ERROR: " << e.what() << "\n"; + } + results.push_back({"LSE", passed, time_ms}); + } + + // ----------------------------------------------------------------------- + // Feature C: Dropout + // ----------------------------------------------------------------------- + { + std::cout << "\nStep 2c: Dropout (p_drop=0.2)\n"; + const int nhead = 4; + + const int64_t qkv_elems = static_cast(batch) * nhead * seqlen * hdim; + const int64_t lse_elems = static_cast(batch) * nhead * seqlen; + const int64_t randval_elems = static_cast(batch) * nhead * seqlen * seqlen; + + GpuBuffer q_dev(qkv_elems); + GpuBuffer k_dev(qkv_elems); + GpuBuffer v_dev(qkv_elems); + GpuBuffer o_dev(qkv_elems); + GpuBuffer lse_dev(lse_elems); + GpuBuffer rand_val_dev(randval_elems); + + std::vector q_host(qkv_elems), k_host(qkv_elems), v_host(qkv_elems); + for(auto& x : q_host) + x = FmhaDataType(dist(rng)); + for(auto& x : k_host) + x = FmhaDataType(dist(rng)); + for(auto& x : v_host) + x = FmhaDataType(dist(rng)); + + q_dev.copy_from_host(q_host.data()); + k_dev.copy_from_host(k_host.data()); + v_dev.copy_from_host(v_host.data()); + o_dev.zero(); + lse_dev.zero(); + rand_val_dev.zero(); + + fmha_fwd_traits traits{}; + traits.hdim_q = hdim; + traits.hdim_v = hdim; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.has_logits_soft_cap = false; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + traits.has_lse = true; + traits.has_dropout = true; + traits.qscale_type = quant_scale_enum::no_scale; + + auto fmha_args = make_base_args(q_dev.get(), + k_dev.get(), + v_dev.get(), + o_dev.get(), + batch, + nhead, + nhead, + seqlen, + hdim, + scale); + fmha_args.lse_ptr = lse_dev.get(); + fmha_args.rand_val_ptr = rand_val_dev.get(); + fmha_args.nhead_stride_lse = seqlen; + fmha_args.batch_stride_lse = nhead * seqlen; + fmha_args.stride_randval = seqlen; + fmha_args.nhead_stride_randval = seqlen * seqlen; + fmha_args.batch_stride_randval = nhead * seqlen * seqlen; + fmha_args.p_drop = 0.2f; + fmha_args.s_randval = true; + fmha_args.drop_seed_offset = std::make_pair(uint64_t(42), uint64_t(0)); + + bool passed = false; + float time_ms = 0.0f; + try + { + time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr); + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + + std::vector o_host(qkv_elems); + o_dev.copy_to_host(o_host.data()); + + int nonzero = 0; + for(int64_t i = 0; i < qkv_elems; ++i) + { + if(static_cast(o_host[i]) != 0.0f) + ++nonzero; + } + std::cout << " Non-zero outputs: " << nonzero << " / " << qkv_elems << "\n"; + + std::vector lse_host(lse_elems); + lse_dev.copy_to_host(lse_host.data()); + int lse_reasonable = 0; + for(int64_t i = 0; i < lse_elems; ++i) + { + if(std::isfinite(lse_host[i]) && std::abs(lse_host[i]) < 100.0f) + ++lse_reasonable; + } + std::cout << " LSE reasonable: " << lse_reasonable << " / " << lse_elems << "\n"; + passed = (nonzero > 0) && (lse_reasonable == lse_elems); + } + catch(const std::exception& e) + { + std::cerr << " ERROR: " << e.what() << "\n"; + } + results.push_back({"Dropout", passed, time_ms}); + } + + // ----------------------------------------------------------------------- + // Summary + // ----------------------------------------------------------------------- + std::cout << "\nStep 3: Summary\n"; + std::cout << " " << std::setw(16) << "Feature" << " | " << std::setw(10) << "Time(ms)" << " | " + << std::setw(8) << "Status" << "\n"; + std::cout << " " << std::string(42, '-') << "\n"; + + bool all_passed = true; + for(const auto& r : results) + { + std::cout << " " << std::setw(16) << r.name << " | " << std::fixed << std::setprecision(4) + << std::setw(10) << r.time_ms << " | " << std::setw(8) + << (r.passed ? "PASS" : "FAIL") << "\n"; + if(!r.passed) + all_passed = false; + } + + print_separator(); + std::cout << "Status: " << (all_passed ? "PASS" : "FAIL") << "\n"; + print_separator(); + + return all_passed ? 0 : 1; +} diff --git a/dispatcher/examples/fmha/cpp/22_gpu_bwd_fmha.cpp b/dispatcher/examples/fmha/cpp/22_gpu_bwd_fmha.cpp new file mode 100644 index 0000000000..4699346c5a --- /dev/null +++ b/dispatcher/examples/fmha/cpp/22_gpu_bwd_fmha.cpp @@ -0,0 +1,553 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Example 22: FMHA Backward with GPU Execution +// +// Demonstrates: +// 1. Declare 3 backward kernel families (bwd_dot_do_o, bwd_dq_dk_dv, bwd_convert_dq) +// 2. Run forward to get O and LSE +// 3. Run backward to compute dQ, dK, dV +// 4. Validate gradients are non-zero +// +// Falls back to planning only if backward kernels fail to compile on gfx950. + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +DECL_FMHA_KERNEL_SET(gpu_bwd_fmha_kernels, + // Forward kernel (to produce O and LSE for backward) + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(true) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950") + + // Backward: dot(dO, O) to compute d scalar + .add(FmhaSignature() + .family("bwd_dot_do_o") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("no") + .bias("no") + .dropout(false) + .dbias(false) + .store_randval(false) + .deterministic(false), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(32) + .tile_n1(0) + .tile_k1(0) + .tile_k0max(0) + .padding(true, true, true, true) + .selection_rank(0), + "gfx950") + + // Backward: compute dQ, dK, dV + .add(FmhaSignature() + .family("bwd_dq_dk_dv") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("no") + .bias("no") + .dropout(false) + .dbias(false) + .store_randval(false) + .deterministic(false), + FmhaAlgorithm() + .tile_m0(16) + .tile_n0(128) + .tile_k0(128) + .tile_n1(16) + .tile_k1(128) + .tile_k0max(32) + .wave(1, 4, 1, 4, 1, 1, 1, 4, 1) + .warp(16, 16, 32, 16, 16, 16, 16, 16, 16) + .padding(true, true, true, true) + .max_seq_len_q(0) + .selection_rank(0), + "gfx950") + + // Backward: convert accumulated dQ from fp32 to fp16 + .add(FmhaSignature() + .family("bwd_convert_dq") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("no") + .bias("no") + .dropout(false) + .dbias(false) + .store_randval(false) + .deterministic(false), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(0) + .tile_n1(0) + .tile_k1(0) + .tile_k0max(0) + .padding(true, true, true, true) + .selection_rank(0), + "gfx950")); + +namespace { + +using FmhaDataType = ck_tile::fp16_t; + +void cpu_attention_fwd(const std::vector& Q, + const std::vector& K, + const std::vector& V, + std::vector& O, + std::vector& LSE, + int batch, + int nhead, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + float scale) +{ + for(int b = 0; b < batch; ++b) + { + for(int h = 0; h < nhead; ++h) + { + for(int sq = 0; sq < seqlen_q; ++sq) + { + std::vector scores(seqlen_k, 0.0f); + float max_score = -1e30f; + + for(int sk = 0; sk < seqlen_k; ++sk) + { + float dot = 0.0f; + for(int d = 0; d < hdim_q; ++d) + { + int q_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_q + d; + int k_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_q + d; + dot += Q[q_idx] * K[k_idx]; + } + scores[sk] = dot * scale; + max_score = std::max(max_score, scores[sk]); + } + + float sum_exp = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] = std::exp(scores[sk] - max_score); + sum_exp += scores[sk]; + } + + int lse_idx = (b * nhead + h) * seqlen_q + sq; + LSE[lse_idx] = max_score + std::log(sum_exp); + + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] /= sum_exp; + } + + for(int dv = 0; dv < hdim_v; ++dv) + { + float acc = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + int v_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_v + dv; + acc += scores[sk] * V[v_idx]; + } + int o_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_v + dv; + O[o_idx] = acc; + } + } + } + } +} + +} // namespace + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 22: FMHA Backward (GPU)", "Forward + backward with GPU validation"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--batch", "1", "Batch size"); + args.add_option("--nhead", "4", "Number of heads"); + args.add_option("--seqlen", "64", "Sequence length"); + args.add_option("--hdim", "128", "Head dimension"); + if(!args.parse(argc, argv)) + return 0; + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int batch = args.get_int("--batch", 1); + const int nhead = args.get_int("--nhead", 4); + const int seqlen = args.get_int("--seqlen", 64); + const int hdim = args.get_int("--hdim", 128); + const float scale = 1.0f / std::sqrt(static_cast(hdim)); + + print_header("Example 22: FMHA Backward (GPU Execution)"); + + // Step 1: Register kernels + std::cout << "\nStep 1: Register Kernels\n"; + FmhaKernelSetRegistry::instance().print(); + + FmhaRegistry registry; + registry.set_name("gpu_bwd_fmha"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + FmhaDispatcher dispatcher(®istry); + dispatcher.set_benchmarking(true); + dispatcher.set_timing(1, 3); + + // Step 2: Plan backward to verify all 3 stages resolve + std::cout << "\nStep 2: Plan Backward\n"; + + fmha_bwd_traits bwd_traits{}; + bwd_traits.hdim_q = hdim; + bwd_traits.hdim_v = hdim; + bwd_traits.data_type = "fp16"; + bwd_traits.is_group_mode = false; + bwd_traits.mask_type = mask_enum::no_mask; + bwd_traits.bias_type = bias_enum::no_bias; + bwd_traits.has_dbias = false; + bwd_traits.has_dropout = false; + bwd_traits.is_store_randval = false; + bwd_traits.is_deterministic = false; + + fmha_bwd_args bwd_args{}; + bwd_args.batch = batch; + bwd_args.seqlen_q = seqlen; + bwd_args.seqlen_k = seqlen; + bwd_args.max_seqlen_q = seqlen; + bwd_args.max_seqlen_k = seqlen; + bwd_args.hdim_q = hdim; + bwd_args.hdim_v = hdim; + bwd_args.nhead_q = nhead; + bwd_args.nhead_k = nhead; + + auto bwd_plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(bwd_traits, bwd_args), gfx_arch)); + + if(!bwd_plan.is_valid() || bwd_plan.stages.size() < 2) + { + std::cout << " Backward plan: INVALID (expected multi-stage)\n"; + std::cout << " Falling back to planning-only mode (like 04_bwd_fmha.cpp)\n"; + print_separator(); + std::cout << "Status: PLAN_ONLY\n"; + print_separator(); + return 0; + } + + std::cout << " Backward plan stages:\n"; + for(const auto& stage : bwd_plan.stages) + { + std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n"; + } + + // Step 3: Allocate buffers + std::cout << "\nStep 3: Allocate GPU Buffers\n"; + const int64_t qkv_elems = static_cast(batch) * nhead * seqlen * hdim; + const int64_t lse_elems = static_cast(batch) * nhead * seqlen; + const int64_t dq_acc_elems = static_cast(batch) * nhead * seqlen * hdim; + + std::cout << " Q/K/V/O: [" << batch << ", " << nhead << ", " << seqlen << ", " << hdim + << "]\n"; + std::cout << " LSE/d: [" << batch << ", " << nhead << ", " << seqlen << "]\n"; + + GpuBuffer q_dev(qkv_elems); + GpuBuffer k_dev(qkv_elems); + GpuBuffer v_dev(qkv_elems); + GpuBuffer o_dev(qkv_elems); + GpuBuffer lse_dev(lse_elems); + GpuBuffer do_dev(qkv_elems); + GpuBuffer d_dev(lse_elems); + GpuBuffer dq_dev(qkv_elems); + GpuBuffer dk_dev(qkv_elems); + GpuBuffer dv_dev(qkv_elems); + GpuBuffer dq_acc_dev(dq_acc_elems); + + std::mt19937 rng(42); + std::uniform_real_distribution dist(-0.5f, 0.5f); + + std::vector q_host(qkv_elems), k_host(qkv_elems), v_host(qkv_elems); + std::vector do_host(qkv_elems); + for(auto& x : q_host) + x = FmhaDataType(dist(rng)); + for(auto& x : k_host) + x = FmhaDataType(dist(rng)); + for(auto& x : v_host) + x = FmhaDataType(dist(rng)); + for(auto& x : do_host) + x = FmhaDataType(dist(rng)); + + q_dev.copy_from_host(q_host.data()); + k_dev.copy_from_host(k_host.data()); + v_dev.copy_from_host(v_host.data()); + do_dev.copy_from_host(do_host.data()); + o_dev.zero(); + lse_dev.zero(); + d_dev.zero(); + dq_dev.zero(); + dk_dev.zero(); + dv_dev.zero(); + dq_acc_dev.zero(); + + // Step 4: Run forward to produce O and LSE + std::cout << "\nStep 4: Run Forward (to produce O and LSE)\n"; + { + fmha_fwd_traits fwd_traits{}; + fwd_traits.hdim_q = hdim; + fwd_traits.hdim_v = hdim; + fwd_traits.data_type = "fp16"; + fwd_traits.is_group_mode = false; + fwd_traits.is_v_rowmajor = true; + fwd_traits.has_logits_soft_cap = false; + fwd_traits.mask_type = mask_enum::no_mask; + fwd_traits.bias_type = bias_enum::no_bias; + fwd_traits.has_lse = true; + fwd_traits.has_dropout = false; + fwd_traits.qscale_type = quant_scale_enum::no_scale; + + fmha_fwd_args fwd_args{}; + fwd_args.q_ptr = q_dev.get(); + fwd_args.k_ptr = k_dev.get(); + fwd_args.v_ptr = v_dev.get(); + fwd_args.o_ptr = o_dev.get(); + fwd_args.lse_ptr = lse_dev.get(); + + fwd_args.bias_ptr = nullptr; + fwd_args.q_descale_ptr = nullptr; + fwd_args.k_descale_ptr = nullptr; + fwd_args.v_descale_ptr = nullptr; + fwd_args.rand_val_ptr = nullptr; + fwd_args.sink_ptr = nullptr; + fwd_args.block_scale_seqstart_q_ptr = nullptr; + fwd_args.block_scale_seqstart_k_ptr = nullptr; + + fwd_args.seqlen_q = seqlen; + fwd_args.seqlen_k = seqlen; + fwd_args.batch = batch; + fwd_args.max_seqlen_q = seqlen; + fwd_args.hdim_q = hdim; + fwd_args.hdim_v = hdim; + fwd_args.nhead_q = nhead; + fwd_args.nhead_k = nhead; + fwd_args.scale_s = scale; + fwd_args.logits_soft_cap = 0.0f; + + fwd_args.stride_q = hdim; + fwd_args.stride_k = hdim; + fwd_args.stride_v = hdim; + fwd_args.stride_bias = 0; + fwd_args.stride_randval = 0; + fwd_args.stride_o = hdim; + + fwd_args.nhead_stride_q = seqlen * hdim; + fwd_args.nhead_stride_k = seqlen * hdim; + fwd_args.nhead_stride_v = seqlen * hdim; + fwd_args.nhead_stride_bias = 0; + fwd_args.nhead_stride_randval = 0; + fwd_args.nhead_stride_lse = seqlen; + fwd_args.nhead_stride_o = seqlen * hdim; + fwd_args.nhead_stride_q_descale = 0; + fwd_args.nhead_stride_k_descale = 0; + fwd_args.nhead_stride_v_descale = 0; + + fwd_args.batch_stride_q = nhead * seqlen * hdim; + fwd_args.batch_stride_k = nhead * seqlen * hdim; + fwd_args.batch_stride_v = nhead * seqlen * hdim; + fwd_args.batch_stride_bias = 0; + fwd_args.batch_stride_randval = 0; + fwd_args.batch_stride_lse = nhead * seqlen; + fwd_args.batch_stride_o = nhead * seqlen * hdim; + fwd_args.batch_stride_q_descale = 0; + fwd_args.batch_stride_k_descale = 0; + fwd_args.batch_stride_v_descale = 0; + + fwd_args.window_size_left = -1; + fwd_args.window_size_right = -1; + fwd_args.sink_size = 0; + fwd_args.mask_type = 0; + fwd_args.min_seqlen_q = 0; + fwd_args.p_drop = 0.0f; + fwd_args.s_randval = false; + fwd_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0)); + fwd_args.block_scale_size_q = 0; + fwd_args.block_scale_size_kv = 0; + + try + { + float fwd_time = dispatcher.run_fwd(fwd_traits, fwd_args, nullptr); + std::cout << " Forward time: " << std::fixed << std::setprecision(4) << fwd_time + << " ms\n"; + } + catch(const std::exception& e) + { + std::cerr << " Forward ERROR: " << e.what() << "\n"; + print_separator(); + std::cout << "Status: FAIL (forward failed)\n"; + print_separator(); + return 1; + } + } + + // Step 5: Run backward + std::cout << "\nStep 5: Run Backward\n"; + + bwd_args.q_ptr = q_dev.get(); + bwd_args.k_ptr = k_dev.get(); + bwd_args.v_ptr = v_dev.get(); + bwd_args.bias_ptr = nullptr; + bwd_args.o_ptr = o_dev.get(); + bwd_args.lse_ptr = lse_dev.get(); + bwd_args.do_ptr = do_dev.get(); + bwd_args.d_ptr = d_dev.get(); + bwd_args.rand_val_ptr = nullptr; + bwd_args.dq_ptr = dq_dev.get(); + bwd_args.dk_ptr = dk_dev.get(); + bwd_args.dv_ptr = dv_dev.get(); + bwd_args.dbias_ptr = nullptr; + bwd_args.dq_acc_ptr = dq_acc_dev.get(); + bwd_args.scale = scale; + + bwd_args.stride_q = hdim; + bwd_args.stride_k = hdim; + bwd_args.stride_v = hdim; + bwd_args.stride_bias = 0; + bwd_args.stride_o = hdim; + bwd_args.stride_randval = 0; + bwd_args.stride_do = hdim; + bwd_args.stride_dq_acc = hdim; + bwd_args.stride_dq = hdim; + bwd_args.stride_dk = hdim; + bwd_args.stride_dv = hdim; + bwd_args.stride_dbias = 0; + + bwd_args.nhead_stride_q = seqlen * hdim; + bwd_args.nhead_stride_k = seqlen * hdim; + bwd_args.nhead_stride_v = seqlen * hdim; + bwd_args.nhead_stride_bias = 0; + bwd_args.nhead_stride_o = seqlen * hdim; + bwd_args.nhead_stride_randval = 0; + bwd_args.nhead_stride_do = seqlen * hdim; + bwd_args.nhead_stride_lsed = seqlen; + bwd_args.nhead_stride_dq_acc = static_cast(seqlen) * hdim; + bwd_args.nhead_stride_dq = seqlen * hdim; + bwd_args.nhead_stride_dk = seqlen * hdim; + bwd_args.nhead_stride_dv = seqlen * hdim; + bwd_args.nhead_stride_dbias = 0; + + bwd_args.batch_stride_q = nhead * seqlen * hdim; + bwd_args.batch_stride_k = nhead * seqlen * hdim; + bwd_args.batch_stride_v = nhead * seqlen * hdim; + bwd_args.batch_stride_bias = 0; + bwd_args.batch_stride_o = nhead * seqlen * hdim; + bwd_args.batch_stride_randval = 0; + bwd_args.batch_stride_do = nhead * seqlen * hdim; + bwd_args.batch_stride_lsed = nhead * seqlen; + bwd_args.batch_stride_dq_acc = static_cast(nhead) * seqlen * hdim; + bwd_args.batch_stride_dq = nhead * seqlen * hdim; + bwd_args.batch_stride_dk = nhead * seqlen * hdim; + bwd_args.batch_stride_dv = nhead * seqlen * hdim; + bwd_args.batch_stride_dbias = 0; + bwd_args.split_stride_dq_acc = 0; + + bwd_args.window_size_left = -1; + bwd_args.window_size_right = -1; + bwd_args.mask_type = 0; + bwd_args.p_drop = 0.0f; + bwd_args.p_undrop = 1.0f; + bwd_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0)); + + bool bwd_passed = false; + try + { + float bwd_time = dispatcher.run_bwd(bwd_traits, bwd_args, nullptr); + std::cout << " Backward time: " << std::fixed << std::setprecision(4) << bwd_time + << " ms\n"; + + // Validate: dQ, dK, dV should be non-zero + std::vector dq_host(qkv_elems), dk_host(qkv_elems), dv_host(qkv_elems); + dq_dev.copy_to_host(dq_host.data()); + dk_dev.copy_to_host(dk_host.data()); + dv_dev.copy_to_host(dv_host.data()); + + auto count_nonzero = [](const std::vector& buf) { + int nz = 0; + for(const auto& x : buf) + { + if(static_cast(x) != 0.0f) + ++nz; + } + return nz; + }; + + int dq_nz = count_nonzero(dq_host); + int dk_nz = count_nonzero(dk_host); + int dv_nz = count_nonzero(dv_host); + + std::cout << " dQ non-zero: " << dq_nz << " / " << qkv_elems << "\n"; + std::cout << " dK non-zero: " << dk_nz << " / " << qkv_elems << "\n"; + std::cout << " dV non-zero: " << dv_nz << " / " << qkv_elems << "\n"; + + bwd_passed = (dq_nz > 0) && (dk_nz > 0) && (dv_nz > 0); + } + catch(const std::exception& e) + { + std::cerr << " Backward ERROR: " << e.what() << "\n"; + std::cout << " Falling back to planning-only mode (like 04_bwd_fmha.cpp)\n"; + std::cout << " Backward plan was valid with " << bwd_plan.stages.size() << " stages\n"; + print_separator(); + std::cout << "Status: PLAN_ONLY\n"; + print_separator(); + return 0; + } + + print_separator(); + std::cout << "Status: " << (bwd_passed ? "PASS" : "FAIL") << "\n"; + print_separator(); + + return bwd_passed ? 0 : 1; +} diff --git a/dispatcher/examples/fmha/cpp/23_multi_registry_fmha.cpp b/dispatcher/examples/fmha/cpp/23_multi_registry_fmha.cpp new file mode 100644 index 0000000000..0bc045078a --- /dev/null +++ b/dispatcher/examples/fmha/cpp/23_multi_registry_fmha.cpp @@ -0,0 +1,595 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Example 23: Multiple Registries for Different Frameworks +// +// Demonstrates: +// 1. Three separate FmhaRegistry instances (pytorch, flash, aiter) +// 2. Each with its own DECL_FMHA_KERNEL_SET using different configs +// 3. Registry introspection: size(), filter(), export_json() +// 4. Planning the same problem from each registry +// 5. GPU execution from the basic kernel registry +// +// Key idea: separate registries let each framework recipient own its +// kernel population independently. + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +// Three DECL_FMHA_KERNEL_SETs with distinct names and configurations. +// All register into the global FmhaKernelSetRegistry at static init time. + +DECL_FMHA_KERNEL_SET(pytorch_reg_kernels, + // PyTorch: basic fp16, elementwise bias + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("bias") + .lse(false) + .dropout(false) + .qscale("no") + .profile("pytorch"), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950")); + +DECL_FMHA_KERNEL_SET(flash_reg_kernels, + // Flash: fp16, alibi bias + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("alibi") + .lse(false) + .dropout(false) + .qscale("no") + .profile("flash_fwd"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950")); + +DECL_FMHA_KERNEL_SET(aiter_reg_kernels, + // AITER: batch mode basic + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no") + .profile("aiter_batch"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950") + // AITER: group mode + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("group") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no") + .profile("aiter_group"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950")); + +namespace { + +using FmhaDataType = ck_tile::fp16_t; + +void cpu_attention_fwd(const std::vector& Q, + const std::vector& K, + const std::vector& V, + std::vector& O, + int batch, + int nhead, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + float scale) +{ + for(int b = 0; b < batch; ++b) + { + for(int h = 0; h < nhead; ++h) + { + for(int sq = 0; sq < seqlen_q; ++sq) + { + std::vector scores(seqlen_k, 0.0f); + float max_score = -1e30f; + + for(int sk = 0; sk < seqlen_k; ++sk) + { + float dot = 0.0f; + for(int d = 0; d < hdim_q; ++d) + { + int q_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_q + d; + int k_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_q + d; + dot += Q[q_idx] * K[k_idx]; + } + scores[sk] = dot * scale; + max_score = std::max(max_score, scores[sk]); + } + + float sum_exp = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] = std::exp(scores[sk] - max_score); + sum_exp += scores[sk]; + } + for(int sk = 0; sk < seqlen_k; ++sk) + scores[sk] /= sum_exp; + + for(int dv = 0; dv < hdim_v; ++dv) + { + float acc = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + int v_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_v + dv; + acc += scores[sk] * V[v_idx]; + } + int o_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_v + dv; + O[o_idx] = acc; + } + } + } + } +} + +struct RegistryInfo +{ + std::string name; + FmhaRegistry* reg; + FmhaDispatcher* disp; +}; + +} // namespace + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 23: Multi-Registry FMHA", + "Separate registries per framework recipient"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--batch", "2", "Batch size"); + args.add_option("--nhead", "4", "Number of heads"); + args.add_option("--seqlen", "64", "Sequence length"); + args.add_option("--hdim", "128", "Head dimension"); + if(!args.parse(argc, argv)) + return 0; + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int batch = args.get_int("--batch", 2); + const int nhead = args.get_int("--nhead", 4); + const int seqlen = args.get_int("--seqlen", 64); + const int hdim = args.get_int("--hdim", 128); + const float scale = 1.0f / std::sqrt(static_cast(hdim)); + + print_header("Example 23: Multi-Registry FMHA"); + + // Step 1: Create 3 separate registries + std::cout << "\nStep 1: Create Separate Registries\n"; + std::cout << " Global kernel sets declared: " << FmhaKernelSetRegistry::instance().size() + << "\n"; + FmhaKernelSetRegistry::instance().print(); + + FmhaRegistry pytorch_reg; + pytorch_reg.set_name("pytorch"); + REGISTER_GENERATED_KERNELS(pytorch_reg, gfx_arch); + + FmhaRegistry flash_reg; + flash_reg.set_name("flash"); + REGISTER_GENERATED_KERNELS(flash_reg, gfx_arch); + + FmhaRegistry aiter_reg; + aiter_reg.set_name("aiter"); + REGISTER_GENERATED_KERNELS(aiter_reg, gfx_arch); + + FmhaDispatcher pytorch_disp(&pytorch_reg); + FmhaDispatcher flash_disp(&flash_reg); + FmhaDispatcher aiter_disp(&aiter_reg); + + std::vector registries = { + {"pytorch", &pytorch_reg, &pytorch_disp}, + {"flash", &flash_reg, &flash_disp}, + {"aiter", &aiter_reg, &aiter_disp}, + }; + + // Step 2: Registry introspection + std::cout << "\nStep 2: Registry Introspection\n"; + for(const auto& ri : registries) + { + std::cout << "\n Registry: " << ri.name << "\n"; + std::cout << " Kernel count: " << ri.reg->size() << "\n"; + + auto all_kernels = ri.reg->get_all(); + for(const auto& k : all_kernels) + { + std::cout << " Kernel: " << k->get_name() << "\n"; + } + + auto fwd_kernels = ri.reg->filter([](const FmhaKernelInstance& inst) { + return inst.get_key().signature.family == FmhaKernelFamily::Fwd; + }); + std::cout << " Forward kernels: " << fwd_kernels.size() << "\n"; + + std::string json = ri.reg->export_json(false); + std::cout << " JSON size: " << json.size() << " bytes\n"; + } + + // Step 3: Plan the same problem from each registry + std::cout << "\nStep 3: Plan from Each Registry\n"; + + // Problem A: basic fp16 no-bias (matches aiter_batch) + { + fmha_fwd_traits traits{}; + traits.hdim_q = hdim; + traits.hdim_v = hdim; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.has_logits_soft_cap = false; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + traits.has_lse = false; + traits.has_dropout = false; + traits.qscale_type = quant_scale_enum::no_scale; + + fmha_fwd_args fmha_args{}; + fmha_args.batch = batch; + fmha_args.seqlen_q = seqlen; + fmha_args.seqlen_k = seqlen; + fmha_args.max_seqlen_q = seqlen; + fmha_args.hdim_q = hdim; + fmha_args.hdim_v = hdim; + fmha_args.nhead_q = nhead; + fmha_args.nhead_k = nhead; + + std::cout << "\n Problem: fp16 batch no-bias\n"; + for(const auto& ri : registries) + { + auto plan = ri.disp->plan( + FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch)); + std::cout << " " << ri.name << ": " + << (plan.is_valid() ? plan.stages[0].kernel_id : "NO MATCH") << "\n"; + } + } + + // Problem B: fp16 with alibi bias (matches flash) + { + fmha_fwd_traits traits{}; + traits.hdim_q = hdim; + traits.hdim_v = hdim; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.has_logits_soft_cap = false; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::alibi; + traits.has_lse = false; + traits.has_dropout = false; + traits.qscale_type = quant_scale_enum::no_scale; + + fmha_fwd_args fmha_args{}; + fmha_args.batch = batch; + fmha_args.seqlen_q = seqlen; + fmha_args.seqlen_k = seqlen; + fmha_args.max_seqlen_q = seqlen; + fmha_args.hdim_q = hdim; + fmha_args.hdim_v = hdim; + fmha_args.nhead_q = nhead; + fmha_args.nhead_k = nhead; + + std::cout << "\n Problem: fp16 batch alibi-bias\n"; + for(const auto& ri : registries) + { + auto plan = ri.disp->plan( + FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch)); + std::cout << " " << ri.name << ": " + << (plan.is_valid() ? plan.stages[0].kernel_id : "NO MATCH") << "\n"; + } + } + + // Problem C: fp16 with elementwise bias (matches pytorch) + { + fmha_fwd_traits traits{}; + traits.hdim_q = hdim; + traits.hdim_v = hdim; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.has_logits_soft_cap = false; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::elementwise_bias; + traits.has_lse = false; + traits.has_dropout = false; + traits.qscale_type = quant_scale_enum::no_scale; + + fmha_fwd_args fmha_args{}; + fmha_args.batch = batch; + fmha_args.seqlen_q = seqlen; + fmha_args.seqlen_k = seqlen; + fmha_args.max_seqlen_q = seqlen; + fmha_args.hdim_q = hdim; + fmha_args.hdim_v = hdim; + fmha_args.nhead_q = nhead; + fmha_args.nhead_k = nhead; + + std::cout << "\n Problem: fp16 batch elementwise-bias\n"; + for(const auto& ri : registries) + { + auto plan = ri.disp->plan( + FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch)); + std::cout << " " << ri.name << ": " + << (plan.is_valid() ? plan.stages[0].kernel_id : "NO MATCH") << "\n"; + } + } + + // Step 4: GPU execution from AITER registry (basic no-bias kernel) + std::cout << "\nStep 4: GPU Execution (aiter registry)\n"; + + const int64_t q_elems = static_cast(batch) * nhead * seqlen * hdim; + + GpuBuffer q_dev(q_elems); + GpuBuffer k_dev(q_elems); + GpuBuffer v_dev(q_elems); + GpuBuffer o_dev(q_elems); + + std::mt19937 rng(42); + std::uniform_real_distribution dist(-0.5f, 0.5f); + + std::vector q_host(q_elems), k_host(q_elems), v_host(q_elems); + for(auto& x : q_host) + x = FmhaDataType(dist(rng)); + for(auto& x : k_host) + x = FmhaDataType(dist(rng)); + for(auto& x : v_host) + x = FmhaDataType(dist(rng)); + + q_dev.copy_from_host(q_host.data()); + k_dev.copy_from_host(k_host.data()); + v_dev.copy_from_host(v_host.data()); + o_dev.zero(); + + fmha_fwd_traits run_traits{}; + run_traits.hdim_q = hdim; + run_traits.hdim_v = hdim; + run_traits.data_type = "fp16"; + run_traits.is_group_mode = false; + run_traits.is_v_rowmajor = true; + run_traits.has_logits_soft_cap = false; + run_traits.mask_type = mask_enum::no_mask; + run_traits.bias_type = bias_enum::no_bias; + run_traits.has_lse = false; + run_traits.has_dropout = false; + run_traits.qscale_type = quant_scale_enum::no_scale; + + fmha_fwd_args run_args{}; + run_args.q_ptr = q_dev.get(); + run_args.k_ptr = k_dev.get(); + run_args.v_ptr = v_dev.get(); + run_args.o_ptr = o_dev.get(); + + run_args.bias_ptr = nullptr; + run_args.q_descale_ptr = nullptr; + run_args.k_descale_ptr = nullptr; + run_args.v_descale_ptr = nullptr; + run_args.rand_val_ptr = nullptr; + run_args.lse_ptr = nullptr; + run_args.sink_ptr = nullptr; + run_args.block_scale_seqstart_q_ptr = nullptr; + run_args.block_scale_seqstart_k_ptr = nullptr; + + run_args.seqlen_q = seqlen; + run_args.seqlen_k = seqlen; + run_args.batch = batch; + run_args.max_seqlen_q = seqlen; + run_args.hdim_q = hdim; + run_args.hdim_v = hdim; + run_args.nhead_q = nhead; + run_args.nhead_k = nhead; + run_args.scale_s = scale; + run_args.logits_soft_cap = 0.0f; + + run_args.stride_q = hdim; + run_args.stride_k = hdim; + run_args.stride_v = hdim; + run_args.stride_bias = 0; + run_args.stride_randval = 0; + run_args.stride_o = hdim; + + run_args.nhead_stride_q = seqlen * hdim; + run_args.nhead_stride_k = seqlen * hdim; + run_args.nhead_stride_v = seqlen * hdim; + run_args.nhead_stride_bias = 0; + run_args.nhead_stride_randval = 0; + run_args.nhead_stride_lse = 0; + run_args.nhead_stride_o = seqlen * hdim; + run_args.nhead_stride_q_descale = 0; + run_args.nhead_stride_k_descale = 0; + run_args.nhead_stride_v_descale = 0; + + run_args.batch_stride_q = nhead * seqlen * hdim; + run_args.batch_stride_k = nhead * seqlen * hdim; + run_args.batch_stride_v = nhead * seqlen * hdim; + run_args.batch_stride_bias = 0; + run_args.batch_stride_randval = 0; + run_args.batch_stride_lse = 0; + run_args.batch_stride_o = nhead * seqlen * hdim; + run_args.batch_stride_q_descale = 0; + run_args.batch_stride_k_descale = 0; + run_args.batch_stride_v_descale = 0; + + run_args.window_size_left = -1; + run_args.window_size_right = -1; + run_args.sink_size = 0; + run_args.mask_type = 0; + run_args.min_seqlen_q = 0; + run_args.p_drop = 0.0f; + run_args.s_randval = false; + run_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0)); + run_args.block_scale_size_q = 0; + run_args.block_scale_size_kv = 0; + + bool passed = false; + aiter_disp.set_benchmarking(true); + aiter_disp.set_timing(1, 3); + try + { + float time_ms = aiter_disp.run_fwd(run_traits, run_args, nullptr); + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + + std::vector o_host(q_elems); + o_dev.copy_to_host(o_host.data()); + + // Validate + std::vector q_f32(q_elems), k_f32(q_elems), v_f32(q_elems), o_ref(q_elems, 0.0f); + for(int64_t i = 0; i < q_elems; ++i) + q_f32[i] = static_cast(q_host[i]); + for(int64_t i = 0; i < q_elems; ++i) + k_f32[i] = static_cast(k_host[i]); + for(int64_t i = 0; i < q_elems; ++i) + v_f32[i] = static_cast(v_host[i]); + + cpu_attention_fwd( + q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale); + + double max_abs_err = 0.0; + int errors = 0; + const double rtol = 1e-2; + const double atol = 1e-2; + for(int64_t i = 0; i < q_elems; ++i) + { + float gpu_val = static_cast(o_host[i]); + float ref_val = o_ref[i]; + double abs_err = std::abs(gpu_val - ref_val); + max_abs_err = std::max(max_abs_err, abs_err); + if(abs_err > atol + rtol * std::abs(ref_val)) + ++errors; + } + std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n"; + std::cout << " Errors: " << errors << " / " << q_elems << "\n"; + passed = (errors == 0); + } + catch(const std::exception& e) + { + std::cerr << " ERROR: " << e.what() << "\n"; + } + + print_separator(); + std::cout << "Status: " << (passed ? "PASS" : "FAIL") << "\n"; + print_separator(); + + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/fmha/cpp/24_per_receipt_registries_fmha.cpp b/dispatcher/examples/fmha/cpp/24_per_receipt_registries_fmha.cpp new file mode 100644 index 0000000000..926c8e4601 --- /dev/null +++ b/dispatcher/examples/fmha/cpp/24_per_receipt_registries_fmha.cpp @@ -0,0 +1,549 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Example 24: Per-Receipt Registries +// +// Demonstrates: +// 1. Four DECL_FMHA_KERNEL_SET declarations, each named after a receipt +// 2. Each registered into a separate FmhaRegistry +// 3. Per-registry: kernel count, kernel names, plan a problem, selected kernel +// 4. GPU execution from the ck_default receipt (the basic working kernel) +// 5. Comparison table showing which features each receipt supports +// +// Receipt = a curated kernel set shipped to a specific downstream consumer. + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +// Receipt 1: CK default -- basic fp16, no mask, no bias +DECL_FMHA_KERNEL_SET(ck_default_kernels, + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950")); + +// Receipt 2: Flash forward -- fp16 with alibi bias +DECL_FMHA_KERNEL_SET(flash_fwd_kernels, + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("alibi") + .lse(false) + .dropout(false) + .qscale("no") + .profile("flash_fwd"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950")); + +// Receipt 3: PyTorch -- fp16 with elementwise bias +DECL_FMHA_KERNEL_SET(pytorch_kernels, + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("bias") + .lse(false) + .dropout(false) + .qscale("no") + .profile("pytorch"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950")); + +// Receipt 4: AITER batch -- fp16 batch mode with LSE +DECL_FMHA_KERNEL_SET(aiter_batch_kernels, + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(true) + .dropout(false) + .qscale("no") + .profile("aiter_batch"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950")); + +namespace { + +using FmhaDataType = ck_tile::fp16_t; + +struct ReceiptInfo +{ + std::string name; + std::string bias_desc; + bool has_lse; + FmhaRegistry registry; +}; + +void cpu_attention_fwd(const std::vector& Q, + const std::vector& K, + const std::vector& V, + std::vector& O, + int batch, + int nhead, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + float scale) +{ + for(int b = 0; b < batch; ++b) + { + for(int h = 0; h < nhead; ++h) + { + for(int sq = 0; sq < seqlen_q; ++sq) + { + std::vector scores(seqlen_k, 0.0f); + float max_score = -1e30f; + + for(int sk = 0; sk < seqlen_k; ++sk) + { + float dot = 0.0f; + for(int d = 0; d < hdim_q; ++d) + { + int q_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_q + d; + int k_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_q + d; + dot += Q[q_idx] * K[k_idx]; + } + scores[sk] = dot * scale; + max_score = std::max(max_score, scores[sk]); + } + + float sum_exp = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] = std::exp(scores[sk] - max_score); + sum_exp += scores[sk]; + } + for(int sk = 0; sk < seqlen_k; ++sk) + scores[sk] /= sum_exp; + + for(int dv = 0; dv < hdim_v; ++dv) + { + float acc = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + int v_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_v + dv; + acc += scores[sk] * V[v_idx]; + } + int o_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_v + dv; + O[o_idx] = acc; + } + } + } + } +} + +} // namespace + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 24: Per-Receipt Registries", + "Curated kernel sets per downstream consumer"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--batch", "2", "Batch size"); + args.add_option("--nhead", "4", "Number of heads"); + args.add_option("--seqlen", "64", "Sequence length"); + args.add_option("--hdim", "128", "Head dimension"); + if(!args.parse(argc, argv)) + return 0; + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int batch = args.get_int("--batch", 2); + const int nhead = args.get_int("--nhead", 4); + const int seqlen = args.get_int("--seqlen", 64); + const int hdim = args.get_int("--hdim", 128); + const float scale = 1.0f / std::sqrt(static_cast(hdim)); + + print_header("Example 24: Per-Receipt Registries"); + + // Step 1: Create per-receipt registries + std::cout << "\nStep 1: Create Per-Receipt Registries\n"; + std::cout << " Global kernel sets: " << FmhaKernelSetRegistry::instance().size() << "\n"; + + std::vector receipts; + + receipts.push_back({"ck_default", "none", false, FmhaRegistry()}); + receipts.back().registry.set_name("ck_default"); + REGISTER_GENERATED_KERNELS(receipts.back().registry, gfx_arch); + + receipts.push_back({"flash_fwd", "alibi", false, FmhaRegistry()}); + receipts.back().registry.set_name("flash_fwd"); + REGISTER_GENERATED_KERNELS(receipts.back().registry, gfx_arch); + + receipts.push_back({"pytorch", "elementwise", false, FmhaRegistry()}); + receipts.back().registry.set_name("pytorch"); + REGISTER_GENERATED_KERNELS(receipts.back().registry, gfx_arch); + + receipts.push_back({"aiter_batch", "none", true, FmhaRegistry()}); + receipts.back().registry.set_name("aiter_batch"); + REGISTER_GENERATED_KERNELS(receipts.back().registry, gfx_arch); + + // Step 2: Per-registry introspection + std::cout << "\nStep 2: Per-Receipt Introspection\n"; + for(auto& r : receipts) + { + std::cout << "\n Receipt: " << r.name << "\n"; + std::cout << " Kernel count: " << r.registry.size() << "\n"; + + auto all = r.registry.get_all(); + for(const auto& k : all) + { + std::cout << " Kernel: " << k->get_name() << "\n"; + } + } + + // Step 3: Plan a matching problem for each receipt + std::cout << "\nStep 3: Plan per Receipt\n"; + + struct PlanTest + { + std::string receipt_name; + bias_enum bias; + bool lse; + }; + std::vector plan_tests = { + {"ck_default", bias_enum::no_bias, false}, + {"flash_fwd", bias_enum::alibi, false}, + {"pytorch", bias_enum::elementwise_bias, false}, + {"aiter_batch", bias_enum::no_bias, true}, + }; + + for(std::size_t i = 0; i < plan_tests.size(); ++i) + { + const auto& pt = plan_tests[i]; + + fmha_fwd_traits traits{}; + traits.hdim_q = hdim; + traits.hdim_v = hdim; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.has_logits_soft_cap = false; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = pt.bias; + traits.has_lse = pt.lse; + traits.has_dropout = false; + traits.qscale_type = quant_scale_enum::no_scale; + + fmha_fwd_args fmha_args{}; + fmha_args.batch = batch; + fmha_args.seqlen_q = seqlen; + fmha_args.seqlen_k = seqlen; + fmha_args.max_seqlen_q = seqlen; + fmha_args.hdim_q = hdim; + fmha_args.hdim_v = hdim; + fmha_args.nhead_q = nhead; + fmha_args.nhead_k = nhead; + + FmhaDispatcher disp(&receipts[i].registry); + auto plan = disp.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch)); + + std::cout << " " << pt.receipt_name << ": " + << (plan.is_valid() ? plan.stages[0].kernel_id : "NO MATCH") << "\n"; + } + + // Step 4: Comparison table + std::cout << "\nStep 4: Receipt Feature Comparison\n\n"; + std::cout << " " << std::setw(14) << "Receipt" << " | " << std::setw(14) << "Bias" << " | " + << std::setw(5) << "LSE" << " | " << std::setw(8) << "Kernels" << "\n"; + std::cout << " " << std::string(50, '-') << "\n"; + + struct CompRow + { + std::string name; + std::string bias; + std::string lse; + std::size_t count; + }; + std::vector comp = { + {"ck_default", "none", "no", receipts[0].registry.size()}, + {"flash_fwd", "alibi", "no", receipts[1].registry.size()}, + {"pytorch", "elementwise", "no", receipts[2].registry.size()}, + {"aiter_batch", "none", "yes", receipts[3].registry.size()}, + }; + + for(const auto& c : comp) + { + std::cout << " " << std::setw(14) << c.name << " | " << std::setw(14) << c.bias << " | " + << std::setw(5) << c.lse << " | " << std::setw(8) << c.count << "\n"; + } + + // Step 5: GPU execution from ck_default + std::cout << "\nStep 5: GPU Execution (ck_default receipt)\n"; + + const int64_t q_elems = static_cast(batch) * nhead * seqlen * hdim; + + GpuBuffer q_dev(q_elems); + GpuBuffer k_dev(q_elems); + GpuBuffer v_dev(q_elems); + GpuBuffer o_dev(q_elems); + + std::mt19937 rng(42); + std::uniform_real_distribution dist(-0.5f, 0.5f); + + std::vector q_host(q_elems), k_host(q_elems), v_host(q_elems); + for(auto& x : q_host) + x = FmhaDataType(dist(rng)); + for(auto& x : k_host) + x = FmhaDataType(dist(rng)); + for(auto& x : v_host) + x = FmhaDataType(dist(rng)); + + q_dev.copy_from_host(q_host.data()); + k_dev.copy_from_host(k_host.data()); + v_dev.copy_from_host(v_host.data()); + o_dev.zero(); + + fmha_fwd_traits run_traits{}; + run_traits.hdim_q = hdim; + run_traits.hdim_v = hdim; + run_traits.data_type = "fp16"; + run_traits.is_group_mode = false; + run_traits.is_v_rowmajor = true; + run_traits.has_logits_soft_cap = false; + run_traits.mask_type = mask_enum::no_mask; + run_traits.bias_type = bias_enum::no_bias; + run_traits.has_lse = false; + run_traits.has_dropout = false; + run_traits.qscale_type = quant_scale_enum::no_scale; + + fmha_fwd_args run_args{}; + run_args.q_ptr = q_dev.get(); + run_args.k_ptr = k_dev.get(); + run_args.v_ptr = v_dev.get(); + run_args.o_ptr = o_dev.get(); + + run_args.bias_ptr = nullptr; + run_args.q_descale_ptr = nullptr; + run_args.k_descale_ptr = nullptr; + run_args.v_descale_ptr = nullptr; + run_args.rand_val_ptr = nullptr; + run_args.lse_ptr = nullptr; + run_args.sink_ptr = nullptr; + run_args.block_scale_seqstart_q_ptr = nullptr; + run_args.block_scale_seqstart_k_ptr = nullptr; + + run_args.seqlen_q = seqlen; + run_args.seqlen_k = seqlen; + run_args.batch = batch; + run_args.max_seqlen_q = seqlen; + run_args.hdim_q = hdim; + run_args.hdim_v = hdim; + run_args.nhead_q = nhead; + run_args.nhead_k = nhead; + run_args.scale_s = scale; + run_args.logits_soft_cap = 0.0f; + + run_args.stride_q = hdim; + run_args.stride_k = hdim; + run_args.stride_v = hdim; + run_args.stride_bias = 0; + run_args.stride_randval = 0; + run_args.stride_o = hdim; + + run_args.nhead_stride_q = seqlen * hdim; + run_args.nhead_stride_k = seqlen * hdim; + run_args.nhead_stride_v = seqlen * hdim; + run_args.nhead_stride_bias = 0; + run_args.nhead_stride_randval = 0; + run_args.nhead_stride_lse = 0; + run_args.nhead_stride_o = seqlen * hdim; + run_args.nhead_stride_q_descale = 0; + run_args.nhead_stride_k_descale = 0; + run_args.nhead_stride_v_descale = 0; + + run_args.batch_stride_q = nhead * seqlen * hdim; + run_args.batch_stride_k = nhead * seqlen * hdim; + run_args.batch_stride_v = nhead * seqlen * hdim; + run_args.batch_stride_bias = 0; + run_args.batch_stride_randval = 0; + run_args.batch_stride_lse = 0; + run_args.batch_stride_o = nhead * seqlen * hdim; + run_args.batch_stride_q_descale = 0; + run_args.batch_stride_k_descale = 0; + run_args.batch_stride_v_descale = 0; + + run_args.window_size_left = -1; + run_args.window_size_right = -1; + run_args.sink_size = 0; + run_args.mask_type = 0; + run_args.min_seqlen_q = 0; + run_args.p_drop = 0.0f; + run_args.s_randval = false; + run_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0)); + run_args.block_scale_size_q = 0; + run_args.block_scale_size_kv = 0; + + FmhaDispatcher ck_disp(&receipts[0].registry); + ck_disp.set_benchmarking(true); + ck_disp.set_timing(1, 3); + + bool passed = false; + try + { + float time_ms = ck_disp.run_fwd(run_traits, run_args, nullptr); + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + + std::vector o_host(q_elems); + o_dev.copy_to_host(o_host.data()); + + std::vector q_f32(q_elems), k_f32(q_elems), v_f32(q_elems), o_ref(q_elems, 0.0f); + for(int64_t i = 0; i < q_elems; ++i) + q_f32[i] = static_cast(q_host[i]); + for(int64_t i = 0; i < q_elems; ++i) + k_f32[i] = static_cast(k_host[i]); + for(int64_t i = 0; i < q_elems; ++i) + v_f32[i] = static_cast(v_host[i]); + + cpu_attention_fwd( + q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale); + + double max_abs_err = 0.0; + int errors = 0; + const double rtol = 1e-2; + const double atol = 1e-2; + for(int64_t i = 0; i < q_elems; ++i) + { + float gpu_val = static_cast(o_host[i]); + float ref_val = o_ref[i]; + double abs_err = std::abs(gpu_val - ref_val); + max_abs_err = std::max(max_abs_err, abs_err); + if(abs_err > atol + rtol * std::abs(ref_val)) + ++errors; + } + std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n"; + std::cout << " Errors: " << errors << " / " << q_elems << "\n"; + passed = (errors == 0); + } + catch(const std::exception& e) + { + std::cerr << " ERROR: " << e.what() << "\n"; + } + + print_separator(); + std::cout << "Status: " << (passed ? "PASS" : "FAIL") << "\n"; + print_separator(); + + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/fmha/cpp/25_gpu_appendkv_batchprefill_fmha.cpp b/dispatcher/examples/fmha/cpp/25_gpu_appendkv_batchprefill_fmha.cpp new file mode 100644 index 0000000000..db47698b80 --- /dev/null +++ b/dispatcher/examples/fmha/cpp/25_gpu_appendkv_batchprefill_fmha.cpp @@ -0,0 +1,530 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Example 25: AppendKV + BatchPrefill Planning with GPU Execution +// +// Demonstrates: +// 1. Declare appendkv, batch_prefill, and basic fwd kernels +// 2. Plan appendkv with fmha_fwd_appendkv_traits / fmha_fwd_appendkv_args +// 3. Plan batch_prefill with fmha_batch_prefill_traits / fmha_batch_prefill_args +// 4. Run basic fwd kernel on GPU as sanity check +// 5. Show cache_batch_idx usage pattern for non-contiguous batches +// +// Mirrors 01_basic_fmha.cpp for FMHA. + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +DECL_FMHA_KERNEL_SET(appendkv_batchprefill_kernels, + + // AppendKV kernel + .add(FmhaSignature() + .family("fwd_appendkv") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .rope("inter") + .paged_kv(true) + .kv_cache("vectorized", "sglang", 16), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(64) + .tile_n0(64) + .tile_k0(32) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .pipeline("appendkv") + .padding(true, true, true, true) + .selection_rank(0), + "gfx950") + + // BatchPrefill kernel (group mode, paged KV, page_size=64) + .add(FmhaSignature() + .family("batch_prefill") + .dtype("fp16") + .mode("group") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(true) + .dropout(false) + .qscale("no") + .paged_kv(true) + .kv_cache("vectorized", "sglang", 64), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .selection_rank(0), + "gfx950") + + // Basic fwd kernel for GPU execution sanity check + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950")); + +namespace { + +using FmhaDataType = ck_tile::fp16_t; + +void cpu_attention_fwd(const std::vector& Q, + const std::vector& K, + const std::vector& V, + std::vector& O, + int batch, + int nhead, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + float scale) +{ + for(int b = 0; b < batch; ++b) + { + for(int h = 0; h < nhead; ++h) + { + for(int sq = 0; sq < seqlen_q; ++sq) + { + std::vector scores(seqlen_k, 0.0f); + float max_score = -1e30f; + + for(int sk = 0; sk < seqlen_k; ++sk) + { + float dot = 0.0f; + for(int d = 0; d < hdim_q; ++d) + { + int q_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_q + d; + int k_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_q + d; + dot += Q[q_idx] * K[k_idx]; + } + scores[sk] = dot * scale; + max_score = std::max(max_score, scores[sk]); + } + + float sum_exp = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] = std::exp(scores[sk] - max_score); + sum_exp += scores[sk]; + } + + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] /= sum_exp; + } + + for(int dv = 0; dv < hdim_v; ++dv) + { + float acc = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + int v_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_v + dv; + acc += scores[sk] * V[v_idx]; + } + int o_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_v + dv; + O[o_idx] = acc; + } + } + } + } +} + +} // namespace + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 25: AppendKV + BatchPrefill + GPU", + "FMHA AppendKV/BatchPrefill planning with GPU sanity check"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--batch", "2", "Batch size"); + args.add_option("--nhead", "4", "Number of heads"); + args.add_option("--seqlen", "64", "Sequence length (Q and K)"); + args.add_option("--hdim", "128", "Head dimension"); + args.add_flag("--validate", "Validate against CPU reference"); + if(!args.parse(argc, argv)) + return 0; + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int batch = args.get_int("--batch", 2); + const int nhead = args.get_int("--nhead", 4); + const int seqlen = args.get_int("--seqlen", 64); + const int hdim = args.get_int("--hdim", 128); + + print_header("Example 25: AppendKV + BatchPrefill + GPU Execution"); + + // Step 1: Register kernels + std::cout << "\nStep 1: Register Kernels\n"; + FmhaKernelSetRegistry::instance().print(); + + FmhaRegistry registry; + registry.set_name("appendkv_batchprefill"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + FmhaDispatcher dispatcher(®istry); + dispatcher.set_benchmarking(true); + dispatcher.set_timing(1, 3); + + // ========================================================================= + // Step 2: Plan AppendKV + // traits: fmha_fwd_appendkv_traits (hdim_q, hdim_v, data_type, + // is_v_rowmajor, rope_type) + // args: fmha_fwd_appendkv_args (q_ptr, k_ptr, knew_ptr, v_ptr, + // vnew_ptr, seqlen_q, seqlen_knew, ...) + // ========================================================================= + std::cout << "\nStep 2: Plan AppendKV\n"; + + fmha_fwd_appendkv_traits append_traits{}; + append_traits.hdim_q = hdim; + append_traits.hdim_v = hdim; + append_traits.data_type = "fp16"; + append_traits.is_v_rowmajor = true; + append_traits.rope_type = rope_enum::interleaved; + + fmha_fwd_appendkv_args append_args{}; + append_args.q_ptr = reinterpret_cast(0x1); + append_args.k_ptr = reinterpret_cast(0x1); + append_args.knew_ptr = reinterpret_cast(0x1); + append_args.v_ptr = reinterpret_cast(0x1); + append_args.vnew_ptr = reinterpret_cast(0x1); + append_args.seqlen_q = 1; + append_args.seqlen_knew = 1; + append_args.batch = batch; + append_args.hdim_q = hdim; + append_args.hdim_v = hdim; + append_args.nhead_q = nhead; + append_args.nhead_k = nhead; + append_args.rotary_dim = hdim; + append_args.rotary_cos_ptr = reinterpret_cast(0x1); + append_args.rotary_sin_ptr = reinterpret_cast(0x1); + append_args.block_table_ptr = reinterpret_cast(0x1); + append_args.page_block_size = 16; + + // cache_batch_idx: maps request index -> cache slot for non-contiguous batches. + // When serving multiple requests that don't occupy contiguous cache slots, + // this indirection array tells the kernel which cache row each request maps to. + append_args.cache_batch_idx_ptr = reinterpret_cast(0x1); + + auto append_plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(append_traits, append_args), gfx_arch)); + + std::cout << " AppendKV plan valid: " << (append_plan.is_valid() ? "yes" : "no") << "\n"; + if(append_plan.is_valid()) + { + for(const auto& stage : append_plan.stages) + std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n"; + } + + // ========================================================================= + // Step 3: Plan BatchPrefill + // traits: fmha_batch_prefill_traits (extends fmha_fwd_traits with + // kv_memory_layout, kv_lookup_table, page_size) + // args: fmha_batch_prefill_args (kv_indptr, kv_page_indices, + // kv_last_page_lens, seqstart_q_ptr, ...) + // ========================================================================= + std::cout << "\nStep 3: Plan BatchPrefill\n"; + + fmha_batch_prefill_traits prefill_traits{}; + prefill_traits.hdim_q = hdim; + prefill_traits.hdim_v = hdim; + prefill_traits.data_type = "fp16"; + prefill_traits.is_group_mode = true; + prefill_traits.is_v_rowmajor = true; + prefill_traits.mask_type = mask_enum::no_mask; + prefill_traits.bias_type = bias_enum::no_bias; + prefill_traits.has_lse = true; + prefill_traits.kv_memory_layout = + ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT; + prefill_traits.kv_lookup_table = + ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D; + prefill_traits.page_size = 64; + + fmha_batch_prefill_args prefill_args{}; + prefill_args.batch = batch; + prefill_args.seqlen_q = seqlen; + prefill_args.seqlen_k = 1024; + prefill_args.max_seqlen_q = seqlen; + prefill_args.hdim_q = hdim; + prefill_args.hdim_v = hdim; + prefill_args.nhead_q = nhead; + prefill_args.nhead_k = nhead; + prefill_args.num_total_pages = 128; + prefill_args.page_block_size = 64; + prefill_args.kv_memory_layout = + ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT; + prefill_args.kv_lookup_table = + ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D; + prefill_args.kv_indptr = reinterpret_cast(0x1); + prefill_args.kv_page_indices = reinterpret_cast(0x1); + prefill_args.kv_last_page_lens = reinterpret_cast(0x1); + prefill_args.seqstart_q_ptr = reinterpret_cast(0x1); + + auto prefill_plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(prefill_traits, prefill_args), gfx_arch)); + + std::cout << " BatchPrefill plan valid: " << (prefill_plan.is_valid() ? "yes" : "no") << "\n"; + if(prefill_plan.is_valid()) + { + for(const auto& stage : prefill_plan.stages) + std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n"; + } + + // ========================================================================= + // Step 4: GPU Execution with basic fwd kernel (sanity check) + // ========================================================================= + std::cout << "\nStep 4: Allocate GPU Buffers\n"; + + const float scale = 1.0f / std::sqrt(static_cast(hdim)); + + fmha_fwd_traits fwd_traits{}; + fwd_traits.hdim_q = hdim; + fwd_traits.hdim_v = hdim; + fwd_traits.data_type = "fp16"; + fwd_traits.is_group_mode = false; + fwd_traits.is_v_rowmajor = true; + fwd_traits.has_logits_soft_cap = false; + fwd_traits.mask_type = mask_enum::no_mask; + fwd_traits.bias_type = bias_enum::no_bias; + fwd_traits.has_lse = false; + fwd_traits.has_dropout = false; + fwd_traits.qscale_type = quant_scale_enum::no_scale; + + const int64_t q_elems = static_cast(batch) * nhead * seqlen * hdim; + const int64_t k_elems = q_elems; + const int64_t v_elems = q_elems; + const int64_t o_elems = q_elems; + + std::cout << " Q/K/V/O: [" << batch << ", " << nhead << ", " << seqlen << ", " << hdim + << "]\n"; + + GpuBuffer q_dev(q_elems); + GpuBuffer k_dev(k_elems); + GpuBuffer v_dev(v_elems); + GpuBuffer o_dev(o_elems); + + std::mt19937 rng(42); + std::uniform_real_distribution dist(-0.5f, 0.5f); + + std::vector q_host(q_elems); + std::vector k_host(k_elems); + std::vector v_host(v_elems); + for(auto& x : q_host) + x = FmhaDataType(dist(rng)); + for(auto& x : k_host) + x = FmhaDataType(dist(rng)); + for(auto& x : v_host) + x = FmhaDataType(dist(rng)); + + q_dev.copy_from_host(q_host.data()); + k_dev.copy_from_host(k_host.data()); + v_dev.copy_from_host(v_host.data()); + o_dev.zero(); + + fmha_fwd_args fwd_args{}; + fwd_args.q_ptr = q_dev.get(); + fwd_args.k_ptr = k_dev.get(); + fwd_args.v_ptr = v_dev.get(); + fwd_args.o_ptr = o_dev.get(); + + fwd_args.bias_ptr = nullptr; + fwd_args.q_descale_ptr = nullptr; + fwd_args.k_descale_ptr = nullptr; + fwd_args.v_descale_ptr = nullptr; + fwd_args.rand_val_ptr = nullptr; + fwd_args.lse_ptr = nullptr; + fwd_args.sink_ptr = nullptr; + fwd_args.block_scale_seqstart_q_ptr = nullptr; + fwd_args.block_scale_seqstart_k_ptr = nullptr; + + fwd_args.seqlen_q = seqlen; + fwd_args.seqlen_k = seqlen; + fwd_args.batch = batch; + fwd_args.max_seqlen_q = seqlen; + fwd_args.hdim_q = hdim; + fwd_args.hdim_v = hdim; + fwd_args.nhead_q = nhead; + fwd_args.nhead_k = nhead; + fwd_args.scale_s = scale; + fwd_args.logits_soft_cap = 0.0f; + + fwd_args.stride_q = hdim; + fwd_args.stride_k = hdim; + fwd_args.stride_v = hdim; + fwd_args.stride_bias = 0; + fwd_args.stride_randval = 0; + fwd_args.stride_o = hdim; + + fwd_args.nhead_stride_q = seqlen * hdim; + fwd_args.nhead_stride_k = seqlen * hdim; + fwd_args.nhead_stride_v = seqlen * hdim; + fwd_args.nhead_stride_bias = 0; + fwd_args.nhead_stride_randval = 0; + fwd_args.nhead_stride_lse = 0; + fwd_args.nhead_stride_o = seqlen * hdim; + fwd_args.nhead_stride_q_descale = 0; + fwd_args.nhead_stride_k_descale = 0; + fwd_args.nhead_stride_v_descale = 0; + + fwd_args.batch_stride_q = nhead * seqlen * hdim; + fwd_args.batch_stride_k = nhead * seqlen * hdim; + fwd_args.batch_stride_v = nhead * seqlen * hdim; + fwd_args.batch_stride_bias = 0; + fwd_args.batch_stride_randval = 0; + fwd_args.batch_stride_lse = 0; + fwd_args.batch_stride_o = nhead * seqlen * hdim; + fwd_args.batch_stride_q_descale = 0; + fwd_args.batch_stride_k_descale = 0; + fwd_args.batch_stride_v_descale = 0; + + fwd_args.window_size_left = -1; + fwd_args.window_size_right = -1; + fwd_args.sink_size = 0; + fwd_args.mask_type = 0; + fwd_args.min_seqlen_q = 0; + fwd_args.p_drop = 0.0f; + fwd_args.s_randval = false; + fwd_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0)); + fwd_args.block_scale_size_q = 0; + fwd_args.block_scale_size_kv = 0; + + // Step 5: Run on GPU + std::cout << "\nStep 5: Run FMHA Forward on GPU\n"; + float time_ms = 0.0f; + try + { + time_ms = dispatcher.run_fwd(fwd_traits, fwd_args, nullptr); + } + catch(const std::exception& e) + { + std::cerr << " ERROR: " << e.what() << "\n"; + return 1; + } + + auto problem = + FmhaProblem::from_invocation(FmhaInvocation::make(fwd_traits, fwd_args), gfx_arch); + double tflops = static_cast(problem.num_ops()) / (time_ms * 1e-3) / 1e12; + + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + + // Step 6: Validate + std::cout << "\nStep 6: Validate\n"; + std::vector o_host(o_elems); + o_dev.copy_to_host(o_host.data()); + + int nonzero = 0; + for(int64_t i = 0; i < o_elems; ++i) + { + if(static_cast(o_host[i]) != 0.0f) + ++nonzero; + } + std::cout << " Non-zero outputs: " << nonzero << " / " << o_elems << "\n"; + + bool passed = (nonzero > 0); + + if(args.has("--validate")) + { + std::vector q_f32(q_elems), k_f32(k_elems), v_f32(v_elems), o_ref(o_elems, 0.0f); + for(int64_t i = 0; i < q_elems; ++i) + q_f32[i] = static_cast(q_host[i]); + for(int64_t i = 0; i < k_elems; ++i) + k_f32[i] = static_cast(k_host[i]); + for(int64_t i = 0; i < v_elems; ++i) + v_f32[i] = static_cast(v_host[i]); + + cpu_attention_fwd( + q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale); + + double max_abs_err = 0.0; + double max_rel_err = 0.0; + int errors = 0; + const double rtol = 1e-2; + const double atol = 1e-2; + + for(int64_t i = 0; i < o_elems; ++i) + { + float gpu_val = static_cast(o_host[i]); + float ref_val = o_ref[i]; + double abs_err = std::abs(gpu_val - ref_val); + double rel_err = abs_err / (std::abs(ref_val) + 1e-6); + max_abs_err = std::max(max_abs_err, abs_err); + max_rel_err = std::max(max_rel_err, rel_err); + if(abs_err > atol + rtol * std::abs(ref_val)) + ++errors; + } + + std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n"; + std::cout << " Max rel error: " << max_rel_err << "\n"; + std::cout << " Errors: " << errors << " / " << o_elems << "\n"; + passed = (errors == 0); + } + + print_separator(); + std::cout << "Status: " << (passed ? "PASS" : "FAIL") << "\n"; + print_separator(); + + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/fmha/cpp/26_dtypes_hdims_fmha.cpp b/dispatcher/examples/fmha/cpp/26_dtypes_hdims_fmha.cpp new file mode 100644 index 0000000000..ff77dcbb25 --- /dev/null +++ b/dispatcher/examples/fmha/cpp/26_dtypes_hdims_fmha.cpp @@ -0,0 +1,526 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Example 26: Multiple Data Types and Head Dimensions with GPU Execution +// +// Demonstrates: +// 1. Declare bf16 hdim=128, fp16 hdim=64, and fp16 hdim=128 kernels +// 2. Run each variant on GPU with appropriate buffer types +// 3. Validate with different tolerances: fp16 (rtol=1e-3), bf16 (rtol=1e-2) +// 4. Mention fp32, fp8bf16, fp8fp32, hdim 256, asymmetric hdim as planning +// +// Mirrors 01_basic_fmha.cpp for FMHA. + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +DECL_FMHA_KERNEL_SET(dtypes_hdims_kernels, + + // bf16 hdim=128 + .add(FmhaSignature() + .family("fwd") + .dtype("bf16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950") + + // fp16 hdim=64 + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(64) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(64) + .tile_k0(32) + .tile_n1(64) + .tile_k1(32) + .tile_k0max(64) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(16) + .warp_n1(16) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(64, 64) + .selection_rank(0), + "gfx950") + + // fp16 hdim=128 (reference baseline) + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950")); + +namespace { + +using Fp16Type = ck_tile::fp16_t; +using Bf16Type = ck_tile::bf16_t; + +void cpu_attention_fwd(const std::vector& Q, + const std::vector& K, + const std::vector& V, + std::vector& O, + int batch, + int nhead, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + float scale) +{ + for(int b = 0; b < batch; ++b) + { + for(int h = 0; h < nhead; ++h) + { + for(int sq = 0; sq < seqlen_q; ++sq) + { + std::vector scores(seqlen_k, 0.0f); + float max_score = -1e30f; + + for(int sk = 0; sk < seqlen_k; ++sk) + { + float dot = 0.0f; + for(int d = 0; d < hdim_q; ++d) + { + int q_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_q + d; + int k_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_q + d; + dot += Q[q_idx] * K[k_idx]; + } + scores[sk] = dot * scale; + max_score = std::max(max_score, scores[sk]); + } + + float sum_exp = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] = std::exp(scores[sk] - max_score); + sum_exp += scores[sk]; + } + + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] /= sum_exp; + } + + for(int dv = 0; dv < hdim_v; ++dv) + { + float acc = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + int v_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_v + dv; + acc += scores[sk] * V[v_idx]; + } + int o_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_v + dv; + O[o_idx] = acc; + } + } + } + } +} + +struct VariantResult +{ + std::string label; + float time_ms; + double tflops; + double max_abs_err; + double max_rel_err; + int errors; + bool passed; +}; + +template +fmha_fwd_args make_fwd_args(GpuBuffer& q_dev, + GpuBuffer& k_dev, + GpuBuffer& v_dev, + GpuBuffer& o_dev, + int batch, + int nhead, + int seqlen, + int hdim, + float scale) +{ + fmha_fwd_args a{}; + a.q_ptr = q_dev.get(); + a.k_ptr = k_dev.get(); + a.v_ptr = v_dev.get(); + a.o_ptr = o_dev.get(); + + a.bias_ptr = nullptr; + a.q_descale_ptr = nullptr; + a.k_descale_ptr = nullptr; + a.v_descale_ptr = nullptr; + a.rand_val_ptr = nullptr; + a.lse_ptr = nullptr; + a.sink_ptr = nullptr; + a.block_scale_seqstart_q_ptr = nullptr; + a.block_scale_seqstart_k_ptr = nullptr; + + a.seqlen_q = seqlen; + a.seqlen_k = seqlen; + a.batch = batch; + a.max_seqlen_q = seqlen; + a.hdim_q = hdim; + a.hdim_v = hdim; + a.nhead_q = nhead; + a.nhead_k = nhead; + a.scale_s = scale; + a.logits_soft_cap = 0.0f; + + a.stride_q = hdim; + a.stride_k = hdim; + a.stride_v = hdim; + a.stride_bias = 0; + a.stride_randval = 0; + a.stride_o = hdim; + + a.nhead_stride_q = seqlen * hdim; + a.nhead_stride_k = seqlen * hdim; + a.nhead_stride_v = seqlen * hdim; + a.nhead_stride_bias = 0; + a.nhead_stride_randval = 0; + a.nhead_stride_lse = 0; + a.nhead_stride_o = seqlen * hdim; + a.nhead_stride_q_descale = 0; + a.nhead_stride_k_descale = 0; + a.nhead_stride_v_descale = 0; + + a.batch_stride_q = nhead * seqlen * hdim; + a.batch_stride_k = nhead * seqlen * hdim; + a.batch_stride_v = nhead * seqlen * hdim; + a.batch_stride_bias = 0; + a.batch_stride_randval = 0; + a.batch_stride_lse = 0; + a.batch_stride_o = nhead * seqlen * hdim; + a.batch_stride_q_descale = 0; + a.batch_stride_k_descale = 0; + a.batch_stride_v_descale = 0; + + a.window_size_left = -1; + a.window_size_right = -1; + a.sink_size = 0; + a.mask_type = 0; + a.min_seqlen_q = 0; + a.p_drop = 0.0f; + a.s_randval = false; + a.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0)); + a.block_scale_size_q = 0; + a.block_scale_size_kv = 0; + + return a; +} + +template +VariantResult run_variant(FmhaDispatcher& dispatcher, + const std::string& label, + const std::string& dtype_str, + int batch, + int nhead, + int seqlen, + int hdim, + double rtol, + double atol, + const std::string& gfx_arch) +{ + VariantResult result{}; + result.label = label; + + const float scale = 1.0f / std::sqrt(static_cast(hdim)); + const int64_t elems = static_cast(batch) * nhead * seqlen * hdim; + + fmha_fwd_traits traits{}; + traits.hdim_q = hdim; + traits.hdim_v = hdim; + traits.data_type = dtype_str; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.has_logits_soft_cap = false; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + traits.has_lse = false; + traits.has_dropout = false; + traits.qscale_type = quant_scale_enum::no_scale; + + GpuBuffer q_dev(elems); + GpuBuffer k_dev(elems); + GpuBuffer v_dev(elems); + GpuBuffer o_dev(elems); + + std::mt19937 rng(42); + std::uniform_real_distribution dist(-0.5f, 0.5f); + + std::vector q_host(elems); + std::vector k_host(elems); + std::vector v_host(elems); + for(auto& x : q_host) + x = DataType(dist(rng)); + for(auto& x : k_host) + x = DataType(dist(rng)); + for(auto& x : v_host) + x = DataType(dist(rng)); + + q_dev.copy_from_host(q_host.data()); + k_dev.copy_from_host(k_host.data()); + v_dev.copy_from_host(v_host.data()); + o_dev.zero(); + + auto fwd_args = make_fwd_args(q_dev, k_dev, v_dev, o_dev, batch, nhead, seqlen, hdim, scale); + + try + { + result.time_ms = dispatcher.run_fwd(traits, fwd_args, nullptr); + } + catch(const std::exception& e) + { + std::cerr << " ERROR [" << label << "]: " << e.what() << "\n"; + result.passed = false; + return result; + } + + auto problem = FmhaProblem::from_invocation(FmhaInvocation::make(traits, fwd_args), gfx_arch); + result.tflops = static_cast(problem.num_ops()) / (result.time_ms * 1e-3) / 1e12; + + std::vector o_host(elems); + o_dev.copy_to_host(o_host.data()); + + std::vector q_f32(elems), k_f32(elems), v_f32(elems), o_ref(elems, 0.0f); + for(int64_t i = 0; i < elems; ++i) + q_f32[i] = static_cast(q_host[i]); + for(int64_t i = 0; i < elems; ++i) + k_f32[i] = static_cast(k_host[i]); + for(int64_t i = 0; i < elems; ++i) + v_f32[i] = static_cast(v_host[i]); + + cpu_attention_fwd(q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale); + + result.max_abs_err = 0.0; + result.max_rel_err = 0.0; + result.errors = 0; + + for(int64_t i = 0; i < elems; ++i) + { + float gpu_val = static_cast(o_host[i]); + float ref_val = o_ref[i]; + double abs_err = std::abs(gpu_val - ref_val); + double rel_err = abs_err / (std::abs(ref_val) + 1e-6); + result.max_abs_err = std::max(result.max_abs_err, abs_err); + result.max_rel_err = std::max(result.max_rel_err, rel_err); + if(abs_err > atol + rtol * std::abs(ref_val)) + ++result.errors; + } + + result.passed = (result.errors == 0); + return result; +} + +} // namespace + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 26: Dtypes & Hdims FMHA", + "FMHA with multiple data types and head dimensions"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--batch", "2", "Batch size"); + args.add_option("--nhead", "4", "Number of heads"); + args.add_option("--seqlen", "64", "Sequence length (Q and K)"); + args.add_flag("--validate", "Validate against CPU reference"); + if(!args.parse(argc, argv)) + return 0; + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int batch = args.get_int("--batch", 2); + const int nhead = args.get_int("--nhead", 4); + const int seqlen = args.get_int("--seqlen", 64); + + print_header("Example 26: Multiple Data Types & Head Dimensions"); + + // Step 1: Register kernels + std::cout << "\nStep 1: Register Kernels\n"; + FmhaKernelSetRegistry::instance().print(); + + FmhaRegistry registry; + registry.set_name("dtypes_hdims"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + FmhaDispatcher dispatcher(®istry); + dispatcher.set_benchmarking(true); + dispatcher.set_timing(1, 3); + + // ========================================================================= + // Step 2: Run variants on GPU + // ========================================================================= + std::cout << "\nStep 2: Run Variants\n"; + + // fp16 hdim=128 (reference baseline) + std::cout << "\n --- fp16 hdim=128 (reference) ---\n"; + auto r_fp16_h128 = run_variant(dispatcher, + "fp16_h128", + "fp16", + batch, + nhead, + seqlen, + 128, + /*rtol=*/1e-3, + /*atol=*/1e-3, + gfx_arch); + + // bf16 hdim=128 (wider tolerance due to reduced precision) + std::cout << "\n --- bf16 hdim=128 ---\n"; + auto r_bf16_h128 = run_variant(dispatcher, + "bf16_h128", + "bf16", + batch, + nhead, + seqlen, + 128, + /*rtol=*/1e-2, + /*atol=*/1e-2, + gfx_arch); + + // fp16 hdim=64 (smaller buffers) + std::cout << "\n --- fp16 hdim=64 ---\n"; + auto r_fp16_h64 = run_variant(dispatcher, + "fp16_h64", + "fp16", + batch, + nhead, + seqlen, + 64, + /*rtol=*/1e-3, + /*atol=*/1e-3, + gfx_arch); + + // ========================================================================= + // Step 3: Results Summary + // ========================================================================= + std::cout << "\nStep 3: Results Summary\n\n"; + + std::cout << " " << std::setw(14) << "Variant" << " | " << std::setw(10) << "Time(ms)" << " | " + << std::setw(10) << "TFLOPS" << " | " << std::setw(10) << "MaxAbsErr" << " | " + << std::setw(10) << "MaxRelErr" << " | " << std::setw(8) << "Errors" << " | " + << std::setw(6) << "Status" << "\n"; + std::cout << " " << std::string(82, '-') << "\n"; + + auto print_row = [](const VariantResult& r) { + std::cout << std::fixed; + std::cout << " " << std::setw(14) << r.label << " | " << std::setprecision(4) + << std::setw(10) << r.time_ms << " | " << std::setprecision(2) << std::setw(10) + << r.tflops << " | " << std::scientific << std::setw(10) << r.max_abs_err << " | " + << std::setw(10) << r.max_rel_err << " | " << std::fixed << std::setw(8) + << r.errors << " | " << std::setw(6) << (r.passed ? "PASS" : "FAIL") << "\n"; + }; + + print_row(r_fp16_h128); + print_row(r_bf16_h128); + print_row(r_fp16_h64); + + // ========================================================================= + // Step 4: Tolerance Notes + // ========================================================================= + std::cout << "\nStep 4: Tolerance Notes\n"; + std::cout << " fp16 validation: rtol=1e-3, atol=1e-3 (higher precision)\n"; + std::cout << " bf16 validation: rtol=1e-2, atol=1e-2 (wider tolerance for bfloat16)\n"; + std::cout << "\n Additional dtype/hdim combinations (planning-level declarations):\n"; + std::cout << " fp32: .dtype(\"fp32\") - full single precision\n"; + std::cout << " fp8bf16: .dtype(\"fp8bf16\") - fp8 compute, bf16 output\n"; + std::cout << " fp8fp32: .dtype(\"fp8fp32\") - fp8 compute, fp32 output\n"; + std::cout << " hdim 256: .hdim(256), tile(128,128,32,256,32,256)\n"; + std::cout << " asymmetric: .hdim_q(128), .hdim_v(64) - different Q/V dims\n"; + + bool all_passed = r_fp16_h128.passed && r_bf16_h128.passed && r_fp16_h64.passed; + + print_separator(); + std::cout << "Status: " << (all_passed ? "PASS" : "FAIL") << "\n"; + print_separator(); + + return all_passed ? 0 : 1; +} diff --git a/dispatcher/examples/fmha/cpp/27_padding_permutation_fmha.cpp b/dispatcher/examples/fmha/cpp/27_padding_permutation_fmha.cpp new file mode 100644 index 0000000000..5902bc7ea3 --- /dev/null +++ b/dispatcher/examples/fmha/cpp/27_padding_permutation_fmha.cpp @@ -0,0 +1,635 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Example 27: Padding, Group Mode, V Col-Major, Permutation Patterns +// +// Demonstrates: +// 1. Batch padding with cu_seqlen arrays for per-batch variable lengths +// 2. Group mode with seqstart_q / seqstart_k buffers +// 3. V col-major layout declaration: .vlayout("c") +// 4. Permutation patterns: bhsd (iperm=1) vs bshd (iperm=0) strides +// 5. GPU execution with basic kernel + batch padding +// +// Mirrors 01_basic_fmha.cpp for FMHA. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +DECL_FMHA_KERNEL_SET(padding_permutation_kernels, + + // Basic fwd kernel (batch mode, for GPU execution) + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950") + + // Group mode kernel (variable-length sequences) + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("group") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950") + + // V col-major layout declaration + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("c") + .hdim(128) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950")); + +namespace { + +using FmhaDataType = ck_tile::fp16_t; + +void cpu_attention_fwd(const std::vector& Q, + const std::vector& K, + const std::vector& V, + std::vector& O, + int batch, + int nhead, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + float scale) +{ + for(int b = 0; b < batch; ++b) + { + for(int h = 0; h < nhead; ++h) + { + for(int sq = 0; sq < seqlen_q; ++sq) + { + std::vector scores(seqlen_k, 0.0f); + float max_score = -1e30f; + + for(int sk = 0; sk < seqlen_k; ++sk) + { + float dot = 0.0f; + for(int d = 0; d < hdim_q; ++d) + { + int q_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_q + d; + int k_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_q + d; + dot += Q[q_idx] * K[k_idx]; + } + scores[sk] = dot * scale; + max_score = std::max(max_score, scores[sk]); + } + + float sum_exp = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] = std::exp(scores[sk] - max_score); + sum_exp += scores[sk]; + } + + for(int sk = 0; sk < seqlen_k; ++sk) + { + scores[sk] /= sum_exp; + } + + for(int dv = 0; dv < hdim_v; ++dv) + { + float acc = 0.0f; + for(int sk = 0; sk < seqlen_k; ++sk) + { + int v_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_v + dv; + acc += scores[sk] * V[v_idx]; + } + int o_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_v + dv; + O[o_idx] = acc; + } + } + } + } +} + +} // namespace + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 27: Padding & Permutation FMHA", + "FMHA padding, group mode, V col-major, and permutation patterns"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--batch", "2", "Batch size"); + args.add_option("--nhead", "4", "Number of heads"); + args.add_option("--seqlen", "64", "Sequence length (Q and K)"); + args.add_option("--hdim", "128", "Head dimension"); + args.add_flag("--validate", "Validate against CPU reference"); + if(!args.parse(argc, argv)) + return 0; + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int batch = args.get_int("--batch", 2); + const int nhead = args.get_int("--nhead", 4); + const int seqlen = args.get_int("--seqlen", 64); + const int hdim = args.get_int("--hdim", 128); + + print_header("Example 27: Padding, Group Mode, V Col-Major, Permutation"); + + // Step 1: Register kernels + std::cout << "\nStep 1: Register Kernels\n"; + FmhaKernelSetRegistry::instance().print(); + + FmhaRegistry registry; + registry.set_name("padding_permutation"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + FmhaDispatcher dispatcher(®istry); + dispatcher.set_benchmarking(true); + dispatcher.set_timing(1, 3); + + const float scale = 1.0f / std::sqrt(static_cast(hdim)); + + // ========================================================================= + // Step 2: Batch Padding Pattern + // Allocate cu_seqlen_q / cu_seqlen_k buffers with cumulative sums. + // In CK's dispatcher, this maps to seqstart_q_ptr / seqstart_k_ptr + // and requires group mode to enable per-batch variable sequence lengths. + // ========================================================================= + std::cout << "\nStep 2: Batch Padding Pattern (cu_seqlen)\n"; + { + // Per-batch sequence lengths: batch 0 has seqlen=32, batch 1 has seqlen=48 + const std::vector seqlens_q = {32, 48}; + const std::vector seqlens_k = {32, 48}; + const int num_batches = static_cast(seqlens_q.size()); + + // Build cumulative sum arrays: [0, 32, 80] + std::vector cu_seqlen_q(num_batches + 1, 0); + std::vector cu_seqlen_k(num_batches + 1, 0); + for(int i = 0; i < num_batches; ++i) + { + cu_seqlen_q[i + 1] = cu_seqlen_q[i] + seqlens_q[i]; + cu_seqlen_k[i + 1] = cu_seqlen_k[i] + seqlens_k[i]; + } + + const int total_q = cu_seqlen_q.back(); + const int total_k = cu_seqlen_k.back(); + const int max_sq = *std::max_element(seqlens_q.begin(), seqlens_q.end()); + + std::cout << " Batch seqlens_q: ["; + for(int i = 0; i < num_batches; ++i) + std::cout << (i ? ", " : "") << seqlens_q[i]; + std::cout << "]\n"; + std::cout << " cu_seqlen_q: ["; + for(size_t i = 0; i < cu_seqlen_q.size(); ++i) + std::cout << (i ? ", " : "") << cu_seqlen_q[i]; + std::cout << "]\n"; + + GpuBuffer cu_sq_dev(num_batches + 1); + GpuBuffer cu_sk_dev(num_batches + 1); + cu_sq_dev.copy_from_host(cu_seqlen_q.data()); + cu_sk_dev.copy_from_host(cu_seqlen_k.data()); + + // Group mode traits for variable-length sequences + fmha_fwd_traits pad_traits{}; + pad_traits.hdim_q = hdim; + pad_traits.hdim_v = hdim; + pad_traits.data_type = "fp16"; + pad_traits.is_group_mode = true; + pad_traits.is_v_rowmajor = true; + pad_traits.has_logits_soft_cap = false; + pad_traits.mask_type = mask_enum::no_mask; + pad_traits.bias_type = bias_enum::no_bias; + pad_traits.has_lse = false; + pad_traits.has_dropout = false; + pad_traits.qscale_type = quant_scale_enum::no_scale; + + fmha_fwd_args pad_args{}; + pad_args.seqlen_q = total_q; + pad_args.seqlen_k = total_k; + pad_args.batch = num_batches; + pad_args.max_seqlen_q = max_sq; + pad_args.hdim_q = hdim; + pad_args.hdim_v = hdim; + pad_args.nhead_q = nhead; + pad_args.nhead_k = nhead; + pad_args.scale_s = scale; + + // cu_seqlen_q_ptr / cu_seqlen_k_ptr (seqstart_q / seqstart_k in CK) + pad_args.seqstart_q_ptr = cu_sq_dev.get(); + pad_args.seqstart_k_ptr = cu_sk_dev.get(); + + auto pad_plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(pad_traits, pad_args), gfx_arch)); + std::cout << " Batch padding plan valid: " << (pad_plan.is_valid() ? "yes" : "no") << "\n"; + } + + // ========================================================================= + // Step 3: Group Mode Pattern + // Group mode uses seqstart_q / seqstart_k arrays to define variable + // sequence boundaries. Each batch element can have a different length. + // traits.is_group_mode = true + // ========================================================================= + std::cout << "\nStep 3: Group Mode Pattern (seqstart)\n"; + { + fmha_fwd_traits group_traits{}; + group_traits.hdim_q = hdim; + group_traits.hdim_v = hdim; + group_traits.data_type = "fp16"; + group_traits.is_group_mode = true; + group_traits.is_v_rowmajor = true; + group_traits.has_logits_soft_cap = false; + group_traits.mask_type = mask_enum::no_mask; + group_traits.bias_type = bias_enum::no_bias; + group_traits.has_lse = false; + group_traits.has_dropout = false; + group_traits.qscale_type = quant_scale_enum::no_scale; + + const std::vector seqstart_q = {0, 64, 192}; + const std::vector seqstart_k = {0, 128, 256}; + const int num_batches = static_cast(seqstart_q.size()) - 1; + const int total_q = seqstart_q.back(); + const int max_sq = 128; + + GpuBuffer ss_q_dev(seqstart_q.size()); + GpuBuffer ss_k_dev(seqstart_k.size()); + ss_q_dev.copy_from_host(seqstart_q.data()); + ss_k_dev.copy_from_host(seqstart_k.data()); + + fmha_fwd_args group_args{}; + group_args.seqlen_q = total_q; + group_args.seqlen_k = seqstart_k.back(); + group_args.batch = num_batches; + group_args.max_seqlen_q = max_sq; + group_args.hdim_q = hdim; + group_args.hdim_v = hdim; + group_args.nhead_q = nhead; + group_args.nhead_k = nhead; + group_args.scale_s = scale; + group_args.seqstart_q_ptr = ss_q_dev.get(); + group_args.seqstart_k_ptr = ss_k_dev.get(); + + std::cout << " seqstart_q: [0, 64, 192] -> batches of length 64 and 128\n"; + std::cout << " seqstart_k: [0, 128, 256] -> KV of length 128 and 128\n"; + + auto group_plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(group_traits, group_args), gfx_arch)); + std::cout << " Group mode plan valid: " << (group_plan.is_valid() ? "yes" : "no") << "\n"; + } + + // ========================================================================= + // Step 4: V Col-Major Declaration + // .vlayout("c") declares V in column-major layout (seqlen_k x hdim_v + // stored column-first). This affects how the kernel reads V. + // ========================================================================= + std::cout << "\nStep 4: V Col-Major Layout\n"; + { + fmha_fwd_traits vcol_traits{}; + vcol_traits.hdim_q = hdim; + vcol_traits.hdim_v = hdim; + vcol_traits.data_type = "fp16"; + vcol_traits.is_group_mode = false; + vcol_traits.is_v_rowmajor = false; + vcol_traits.has_logits_soft_cap = false; + vcol_traits.mask_type = mask_enum::no_mask; + vcol_traits.bias_type = bias_enum::no_bias; + vcol_traits.has_lse = false; + vcol_traits.has_dropout = false; + vcol_traits.qscale_type = quant_scale_enum::no_scale; + + fmha_fwd_args vcol_args{}; + vcol_args.batch = batch; + vcol_args.seqlen_q = seqlen; + vcol_args.seqlen_k = seqlen; + vcol_args.max_seqlen_q = seqlen; + vcol_args.hdim_q = hdim; + vcol_args.hdim_v = hdim; + vcol_args.nhead_q = nhead; + vcol_args.nhead_k = nhead; + vcol_args.scale_s = scale; + + std::cout << " V row-major (.vlayout(\"r\")): stride_v = hdim, " + "contiguous along head dimension\n"; + std::cout << " V col-major (.vlayout(\"c\")): stride_v = seqlen_k, " + "contiguous along sequence dimension\n"; + std::cout << " traits.is_v_rowmajor = false\n"; + + auto vcol_plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(vcol_traits, vcol_args), gfx_arch)); + std::cout << " V col-major plan valid: " << (vcol_plan.is_valid() ? "yes" : "no") << "\n"; + } + + // ========================================================================= + // Step 5: Permutation Patterns (bhsd vs bshd) + // bhsd layout (iperm=1): stride_q = hdim, nhead_stride_q = seqlen*hdim + // memory: [batch, head, seq, dim] + // bshd layout (iperm=0): stride_q = nhead*hdim, nhead_stride_q = hdim + // memory: [batch, seq, head, dim] + // ========================================================================= + std::cout << "\nStep 5: Permutation Patterns\n"; + { + std::cout << " bhsd layout (iperm=1):\n"; + std::cout << " stride_q = hdim = " << hdim << "\n"; + std::cout << " nhead_stride_q = seqlen * hdim = " << seqlen * hdim << "\n"; + std::cout << " batch_stride_q = nhead * seqlen * hdim = " << nhead * seqlen * hdim + << "\n"; + std::cout << " memory order: [batch, head, seq, dim]\n"; + + std::cout << "\n bshd layout (iperm=0):\n"; + std::cout << " stride_q = nhead * hdim = " << nhead * hdim << "\n"; + std::cout << " nhead_stride_q = hdim = " << hdim << "\n"; + std::cout << " batch_stride_q = seqlen * nhead * hdim = " << seqlen * nhead * hdim + << "\n"; + std::cout << " memory order: [batch, seq, head, dim]\n"; + } + + // ========================================================================= + // Step 6: GPU Execution with basic kernel + batch padding + // Run the batch-mode kernel with a non-tile-aligned seqlen to exercise + // the .padding(true, true, true, true) capability. + // ========================================================================= + std::cout << "\nStep 6: GPU Execution (batch mode, seqlen=" << seqlen << ")\n"; + + fmha_fwd_traits fwd_traits{}; + fwd_traits.hdim_q = hdim; + fwd_traits.hdim_v = hdim; + fwd_traits.data_type = "fp16"; + fwd_traits.is_group_mode = false; + fwd_traits.is_v_rowmajor = true; + fwd_traits.has_logits_soft_cap = false; + fwd_traits.mask_type = mask_enum::no_mask; + fwd_traits.bias_type = bias_enum::no_bias; + fwd_traits.has_lse = false; + fwd_traits.has_dropout = false; + fwd_traits.qscale_type = quant_scale_enum::no_scale; + + const int64_t q_elems = static_cast(batch) * nhead * seqlen * hdim; + const int64_t k_elems = q_elems; + const int64_t v_elems = q_elems; + const int64_t o_elems = q_elems; + + std::cout << " Q/K/V/O: [" << batch << ", " << nhead << ", " << seqlen << ", " << hdim + << "]\n"; + + GpuBuffer q_dev(q_elems); + GpuBuffer k_dev(k_elems); + GpuBuffer v_dev(v_elems); + GpuBuffer o_dev(o_elems); + + std::mt19937 rng(42); + std::uniform_real_distribution dist(-0.5f, 0.5f); + + std::vector q_host(q_elems); + std::vector k_host(k_elems); + std::vector v_host(v_elems); + for(auto& x : q_host) + x = FmhaDataType(dist(rng)); + for(auto& x : k_host) + x = FmhaDataType(dist(rng)); + for(auto& x : v_host) + x = FmhaDataType(dist(rng)); + + q_dev.copy_from_host(q_host.data()); + k_dev.copy_from_host(k_host.data()); + v_dev.copy_from_host(v_host.data()); + o_dev.zero(); + + fmha_fwd_args fwd_args{}; + fwd_args.q_ptr = q_dev.get(); + fwd_args.k_ptr = k_dev.get(); + fwd_args.v_ptr = v_dev.get(); + fwd_args.o_ptr = o_dev.get(); + + fwd_args.bias_ptr = nullptr; + fwd_args.q_descale_ptr = nullptr; + fwd_args.k_descale_ptr = nullptr; + fwd_args.v_descale_ptr = nullptr; + fwd_args.rand_val_ptr = nullptr; + fwd_args.lse_ptr = nullptr; + fwd_args.sink_ptr = nullptr; + fwd_args.block_scale_seqstart_q_ptr = nullptr; + fwd_args.block_scale_seqstart_k_ptr = nullptr; + + fwd_args.seqlen_q = seqlen; + fwd_args.seqlen_k = seqlen; + fwd_args.batch = batch; + fwd_args.max_seqlen_q = seqlen; + fwd_args.hdim_q = hdim; + fwd_args.hdim_v = hdim; + fwd_args.nhead_q = nhead; + fwd_args.nhead_k = nhead; + fwd_args.scale_s = scale; + fwd_args.logits_soft_cap = 0.0f; + + // bhsd layout strides (iperm=1) + fwd_args.stride_q = hdim; + fwd_args.stride_k = hdim; + fwd_args.stride_v = hdim; + fwd_args.stride_bias = 0; + fwd_args.stride_randval = 0; + fwd_args.stride_o = hdim; + + fwd_args.nhead_stride_q = seqlen * hdim; + fwd_args.nhead_stride_k = seqlen * hdim; + fwd_args.nhead_stride_v = seqlen * hdim; + fwd_args.nhead_stride_bias = 0; + fwd_args.nhead_stride_randval = 0; + fwd_args.nhead_stride_lse = 0; + fwd_args.nhead_stride_o = seqlen * hdim; + fwd_args.nhead_stride_q_descale = 0; + fwd_args.nhead_stride_k_descale = 0; + fwd_args.nhead_stride_v_descale = 0; + + fwd_args.batch_stride_q = nhead * seqlen * hdim; + fwd_args.batch_stride_k = nhead * seqlen * hdim; + fwd_args.batch_stride_v = nhead * seqlen * hdim; + fwd_args.batch_stride_bias = 0; + fwd_args.batch_stride_randval = 0; + fwd_args.batch_stride_lse = 0; + fwd_args.batch_stride_o = nhead * seqlen * hdim; + fwd_args.batch_stride_q_descale = 0; + fwd_args.batch_stride_k_descale = 0; + fwd_args.batch_stride_v_descale = 0; + + fwd_args.window_size_left = -1; + fwd_args.window_size_right = -1; + fwd_args.sink_size = 0; + fwd_args.mask_type = 0; + fwd_args.min_seqlen_q = 0; + fwd_args.p_drop = 0.0f; + fwd_args.s_randval = false; + fwd_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0)); + fwd_args.block_scale_size_q = 0; + fwd_args.block_scale_size_kv = 0; + + float time_ms = 0.0f; + try + { + time_ms = dispatcher.run_fwd(fwd_traits, fwd_args, nullptr); + } + catch(const std::exception& e) + { + std::cerr << " ERROR: " << e.what() << "\n"; + return 1; + } + + auto problem = + FmhaProblem::from_invocation(FmhaInvocation::make(fwd_traits, fwd_args), gfx_arch); + double tflops = static_cast(problem.num_ops()) / (time_ms * 1e-3) / 1e12; + + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + + // Step 7: Validate + std::cout << "\nStep 7: Validate\n"; + std::vector o_host(o_elems); + o_dev.copy_to_host(o_host.data()); + + int nonzero = 0; + for(int64_t i = 0; i < o_elems; ++i) + { + if(static_cast(o_host[i]) != 0.0f) + ++nonzero; + } + std::cout << " Non-zero outputs: " << nonzero << " / " << o_elems << "\n"; + + bool passed = (nonzero > 0); + + if(args.has("--validate")) + { + std::vector q_f32(q_elems), k_f32(k_elems), v_f32(v_elems), o_ref(o_elems, 0.0f); + for(int64_t i = 0; i < q_elems; ++i) + q_f32[i] = static_cast(q_host[i]); + for(int64_t i = 0; i < k_elems; ++i) + k_f32[i] = static_cast(k_host[i]); + for(int64_t i = 0; i < v_elems; ++i) + v_f32[i] = static_cast(v_host[i]); + + cpu_attention_fwd( + q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale); + + double max_abs_err = 0.0; + double max_rel_err = 0.0; + int errors = 0; + const double rtol = 1e-2; + const double atol = 1e-2; + + for(int64_t i = 0; i < o_elems; ++i) + { + float gpu_val = static_cast(o_host[i]); + float ref_val = o_ref[i]; + double abs_err = std::abs(gpu_val - ref_val); + double rel_err = abs_err / (std::abs(ref_val) + 1e-6); + max_abs_err = std::max(max_abs_err, abs_err); + max_rel_err = std::max(max_rel_err, rel_err); + if(abs_err > atol + rtol * std::abs(ref_val)) + ++errors; + } + + std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n"; + std::cout << " Max rel error: " << max_rel_err << "\n"; + std::cout << " Errors: " << errors << " / " << o_elems << "\n"; + passed = (errors == 0); + } + + print_separator(); + std::cout << "Status: " << (passed ? "PASS" : "FAIL") << "\n"; + print_separator(); + + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/fmha/cpp/28_bwd_masks_fmha.cpp b/dispatcher/examples/fmha/cpp/28_bwd_masks_fmha.cpp new file mode 100644 index 0000000000..f9925738e3 --- /dev/null +++ b/dispatcher/examples/fmha/cpp/28_bwd_masks_fmha.cpp @@ -0,0 +1,489 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Example 28: FMHA Backward with Causal Mask +// +// Demonstrates: +// 1. Forward kernel with top_left causal mask + LSE +// 2. Backward kernel families (bwd_dot_do_o, bwd_dq_dk_dv, bwd_convert_dq) with causal mask +// 3. GPU forward execution with causal mask validation +// 4. Backward 3-stage plan display +// +// Backward kernels use planning only -- actual backward GPU execution requires +// all 3 stages to compile, and bwd_dq_dk_dv has tile structure issues on gfx950. + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +DECL_FMHA_KERNEL_SET(bwd_masks_fmha_kernels, + // Forward: causal mask (top_left) with LSE for backward + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("top_left") + .bias("no") + .lse(true) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950") + + // Backward stage 1: dot(dO, O) with causal mask + .add(FmhaSignature() + .family("bwd_dot_do_o") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("top_left") + .bias("no") + .dropout(false) + .dbias(false) + .store_randval(false) + .deterministic(false), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(32) + .tile_n1(0) + .tile_k1(0) + .tile_k0max(0) + .padding(true, true, true, true) + .selection_rank(0), + "gfx950") + + // Backward stage 2: compute dQ, dK, dV with causal mask + .add(FmhaSignature() + .family("bwd_dq_dk_dv") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("top_left") + .bias("no") + .dropout(false) + .dbias(false) + .store_randval(false) + .deterministic(false), + FmhaAlgorithm() + .tile_m0(16) + .tile_n0(128) + .tile_k0(128) + .tile_n1(16) + .tile_k1(128) + .tile_k0max(32) + .wave(1, 4, 1, 4, 1, 1, 1, 4, 1) + .warp(16, 16, 32, 16, 16, 16, 16, 16, 16) + .padding(true, true, true, true) + .max_seq_len_q(0) + .selection_rank(0), + "gfx950") + + // Backward stage 3: convert accumulated dQ from fp32 to fp16 + .add(FmhaSignature() + .family("bwd_convert_dq") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("top_left") + .bias("no") + .dropout(false) + .dbias(false) + .store_randval(false) + .deterministic(false), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(0) + .tile_n1(0) + .tile_k1(0) + .tile_k0max(0) + .padding(true, true, true, true) + .selection_rank(0), + "gfx950")); + +namespace { + +using FmhaDataType = ck_tile::fp16_t; + +void cpu_attention_fwd_causal(const std::vector& Q, + const std::vector& K, + const std::vector& V, + std::vector& O, + std::vector& LSE, + int batch, + int nhead, + int seqlen, + int hdim, + float scale) +{ + for(int b = 0; b < batch; ++b) + { + for(int h = 0; h < nhead; ++h) + { + for(int sq = 0; sq < seqlen; ++sq) + { + std::vector scores(seqlen, 0.0f); + float max_score = -1e30f; + + for(int sk = 0; sk < seqlen; ++sk) + { + float dot = 0.0f; + for(int d = 0; d < hdim; ++d) + { + int q_idx = ((b * nhead + h) * seqlen + sq) * hdim + d; + int k_idx = ((b * nhead + h) * seqlen + sk) * hdim + d; + dot += Q[q_idx] * K[k_idx]; + } + float s = dot * scale; + + // top_left causal: mask if sk > sq + if(sk > sq) + s = -1e30f; + + scores[sk] = s; + max_score = std::max(max_score, scores[sk]); + } + + float sum_exp = 0.0f; + for(int sk = 0; sk < seqlen; ++sk) + { + scores[sk] = std::exp(scores[sk] - max_score); + sum_exp += scores[sk]; + } + + int lse_idx = (b * nhead + h) * seqlen + sq; + LSE[lse_idx] = max_score + std::log(sum_exp); + + for(int sk = 0; sk < seqlen; ++sk) + scores[sk] /= sum_exp; + + for(int dv = 0; dv < hdim; ++dv) + { + float acc = 0.0f; + for(int sk = 0; sk < seqlen; ++sk) + { + int v_idx = ((b * nhead + h) * seqlen + sk) * hdim + dv; + acc += scores[sk] * V[v_idx]; + } + int o_idx = ((b * nhead + h) * seqlen + sq) * hdim + dv; + O[o_idx] = acc; + } + } + } + } +} + +} // namespace + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 28: FMHA Backward with Masks", + "Causal mask forward (GPU) + backward plan"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--batch", "2", "Batch size"); + args.add_option("--nhead", "4", "Number of heads"); + args.add_option("--seqlen", "64", "Sequence length"); + args.add_option("--hdim", "128", "Head dimension"); + if(!args.parse(argc, argv)) + return 0; + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int batch = args.get_int("--batch", 2); + const int nhead = args.get_int("--nhead", 4); + const int seqlen = args.get_int("--seqlen", 64); + const int hdim = args.get_int("--hdim", 128); + const float scale = 1.0f / std::sqrt(static_cast(hdim)); + + print_header("Example 28: FMHA Backward with Causal Mask"); + + // Step 1: Register kernels + std::cout << "\nStep 1: Register Kernels\n"; + FmhaKernelSetRegistry::instance().print(); + + FmhaRegistry registry; + registry.set_name("bwd_masks_fmha"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + FmhaDispatcher dispatcher(®istry); + dispatcher.set_benchmarking(true); + dispatcher.set_timing(1, 3); + + // Step 2: Plan backward (3-stage) with causal mask + std::cout << "\nStep 2: Plan Backward (causal mask)\n"; + + fmha_bwd_traits bwd_traits{}; + bwd_traits.hdim_q = hdim; + bwd_traits.hdim_v = hdim; + bwd_traits.data_type = "fp16"; + bwd_traits.is_group_mode = false; + bwd_traits.mask_type = mask_enum::mask_top_left; + bwd_traits.bias_type = bias_enum::no_bias; + bwd_traits.has_dbias = false; + bwd_traits.has_dropout = false; + bwd_traits.is_store_randval = false; + bwd_traits.is_deterministic = false; + + fmha_bwd_args bwd_args{}; + bwd_args.batch = batch; + bwd_args.seqlen_q = seqlen; + bwd_args.seqlen_k = seqlen; + bwd_args.max_seqlen_q = seqlen; + bwd_args.max_seqlen_k = seqlen; + bwd_args.hdim_q = hdim; + bwd_args.hdim_v = hdim; + bwd_args.nhead_q = nhead; + bwd_args.nhead_k = nhead; + + auto bwd_plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(bwd_traits, bwd_args), gfx_arch)); + + if(bwd_plan.is_valid() && bwd_plan.stages.size() >= 2) + { + std::cout << " Backward plan stages (" << bwd_plan.stages.size() << "):\n"; + for(const auto& stage : bwd_plan.stages) + { + std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n"; + } + } + else + { + std::cout << " Backward plan: INVALID or single-stage (expected 3 stages)\n"; + std::cout << " This is expected -- backward planning shows the pattern\n"; + } + + // Step 3: Run forward on GPU with causal mask + std::cout << "\nStep 3: Run Forward (causal mask, GPU)\n"; + + const int64_t qkv_elems = static_cast(batch) * nhead * seqlen * hdim; + const int64_t lse_elems = static_cast(batch) * nhead * seqlen; + + GpuBuffer q_dev(qkv_elems); + GpuBuffer k_dev(qkv_elems); + GpuBuffer v_dev(qkv_elems); + GpuBuffer o_dev(qkv_elems); + GpuBuffer lse_dev(lse_elems); + + std::mt19937 rng(42); + std::uniform_real_distribution dist(-0.5f, 0.5f); + + std::vector q_host(qkv_elems), k_host(qkv_elems), v_host(qkv_elems); + for(auto& x : q_host) + x = FmhaDataType(dist(rng)); + for(auto& x : k_host) + x = FmhaDataType(dist(rng)); + for(auto& x : v_host) + x = FmhaDataType(dist(rng)); + + q_dev.copy_from_host(q_host.data()); + k_dev.copy_from_host(k_host.data()); + v_dev.copy_from_host(v_host.data()); + o_dev.zero(); + lse_dev.zero(); + + fmha_fwd_traits fwd_traits{}; + fwd_traits.hdim_q = hdim; + fwd_traits.hdim_v = hdim; + fwd_traits.data_type = "fp16"; + fwd_traits.is_group_mode = false; + fwd_traits.is_v_rowmajor = true; + fwd_traits.has_logits_soft_cap = false; + fwd_traits.mask_type = mask_enum::mask_top_left; + fwd_traits.bias_type = bias_enum::no_bias; + fwd_traits.has_lse = true; + fwd_traits.has_dropout = false; + fwd_traits.qscale_type = quant_scale_enum::no_scale; + + fmha_fwd_args fwd_args{}; + fwd_args.q_ptr = q_dev.get(); + fwd_args.k_ptr = k_dev.get(); + fwd_args.v_ptr = v_dev.get(); + fwd_args.o_ptr = o_dev.get(); + fwd_args.lse_ptr = lse_dev.get(); + + fwd_args.bias_ptr = nullptr; + fwd_args.q_descale_ptr = nullptr; + fwd_args.k_descale_ptr = nullptr; + fwd_args.v_descale_ptr = nullptr; + fwd_args.rand_val_ptr = nullptr; + fwd_args.sink_ptr = nullptr; + fwd_args.block_scale_seqstart_q_ptr = nullptr; + fwd_args.block_scale_seqstart_k_ptr = nullptr; + + fwd_args.seqlen_q = seqlen; + fwd_args.seqlen_k = seqlen; + fwd_args.batch = batch; + fwd_args.max_seqlen_q = seqlen; + fwd_args.hdim_q = hdim; + fwd_args.hdim_v = hdim; + fwd_args.nhead_q = nhead; + fwd_args.nhead_k = nhead; + fwd_args.scale_s = scale; + fwd_args.logits_soft_cap = 0.0f; + + fwd_args.stride_q = hdim; + fwd_args.stride_k = hdim; + fwd_args.stride_v = hdim; + fwd_args.stride_bias = 0; + fwd_args.stride_randval = 0; + fwd_args.stride_o = hdim; + + fwd_args.nhead_stride_q = seqlen * hdim; + fwd_args.nhead_stride_k = seqlen * hdim; + fwd_args.nhead_stride_v = seqlen * hdim; + fwd_args.nhead_stride_bias = 0; + fwd_args.nhead_stride_randval = 0; + fwd_args.nhead_stride_lse = seqlen; + fwd_args.nhead_stride_o = seqlen * hdim; + fwd_args.nhead_stride_q_descale = 0; + fwd_args.nhead_stride_k_descale = 0; + fwd_args.nhead_stride_v_descale = 0; + + fwd_args.batch_stride_q = nhead * seqlen * hdim; + fwd_args.batch_stride_k = nhead * seqlen * hdim; + fwd_args.batch_stride_v = nhead * seqlen * hdim; + fwd_args.batch_stride_bias = 0; + fwd_args.batch_stride_randval = 0; + fwd_args.batch_stride_lse = nhead * seqlen; + fwd_args.batch_stride_o = nhead * seqlen * hdim; + fwd_args.batch_stride_q_descale = 0; + fwd_args.batch_stride_k_descale = 0; + fwd_args.batch_stride_v_descale = 0; + + fwd_args.window_size_left = -1; + fwd_args.window_size_right = 0; + fwd_args.sink_size = 0; + fwd_args.mask_type = 1; // top_left + fwd_args.min_seqlen_q = 0; + fwd_args.p_drop = 0.0f; + fwd_args.s_randval = false; + fwd_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0)); + fwd_args.block_scale_size_q = 0; + fwd_args.block_scale_size_kv = 0; + + bool fwd_passed = false; + try + { + float fwd_time = dispatcher.run_fwd(fwd_traits, fwd_args, nullptr); + std::cout << " Forward time: " << std::fixed << std::setprecision(4) << fwd_time + << " ms\n"; + + auto problem = + FmhaProblem::from_invocation(FmhaInvocation::make(fwd_traits, fwd_args), gfx_arch); + double tflops = static_cast(problem.num_ops()) / (fwd_time * 1e-3) / 1e12; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + fwd_passed = true; + } + catch(const std::exception& e) + { + std::cerr << " Forward ERROR: " << e.what() << "\n"; + } + + // Step 4: Validate forward output + std::cout << "\nStep 4: Validate Forward Output\n"; + + if(fwd_passed) + { + std::vector o_host(qkv_elems); + o_dev.copy_to_host(o_host.data()); + + std::vector lse_host(lse_elems); + lse_dev.copy_to_host(lse_host.data()); + + std::vector q_f32(qkv_elems), k_f32(qkv_elems), v_f32(qkv_elems); + for(int64_t i = 0; i < qkv_elems; ++i) + q_f32[i] = static_cast(q_host[i]); + for(int64_t i = 0; i < qkv_elems; ++i) + k_f32[i] = static_cast(k_host[i]); + for(int64_t i = 0; i < qkv_elems; ++i) + v_f32[i] = static_cast(v_host[i]); + + std::vector o_ref(qkv_elems, 0.0f); + std::vector lse_ref(lse_elems, 0.0f); + cpu_attention_fwd_causal( + q_f32, k_f32, v_f32, o_ref, lse_ref, batch, nhead, seqlen, hdim, scale); + + double max_o_err = 0.0; + int o_errors = 0; + const double rtol = 1e-2; + const double atol = 1e-2; + + for(int64_t i = 0; i < qkv_elems; ++i) + { + float gpu_val = static_cast(o_host[i]); + float ref_val = o_ref[i]; + double abs_err = std::abs(gpu_val - ref_val); + max_o_err = std::max(max_o_err, abs_err); + if(abs_err > atol + rtol * std::abs(ref_val)) + ++o_errors; + } + + double max_lse_err = 0.0; + int lse_reasonable = 0; + for(int64_t i = 0; i < lse_elems; ++i) + { + if(std::isfinite(lse_host[i]) && std::abs(lse_host[i]) < 100.0f) + ++lse_reasonable; + max_lse_err = + std::max(max_lse_err, static_cast(std::abs(lse_host[i] - lse_ref[i]))); + } + + std::cout << " Output max abs error: " << std::scientific << max_o_err << "\n"; + std::cout << " Output errors: " << o_errors << " / " << qkv_elems << "\n"; + std::cout << " LSE reasonable: " << lse_reasonable << " / " << lse_elems << "\n"; + std::cout << " LSE max error: " << std::scientific << max_lse_err << "\n"; + + fwd_passed = (o_errors == 0) && (lse_reasonable == lse_elems); + } + + // Step 5: Show backward API pattern + std::cout << "\nStep 5: Backward API Pattern (traits + args)\n"; + std::cout << " bwd_traits.mask_type = mask_top_left\n"; + std::cout << " bwd_traits.bias_type = no_bias\n"; + std::cout << " bwd_traits.has_dropout = false\n"; + std::cout << " bwd_traits.is_deterministic = false\n"; + std::cout << " bwd_args.window_size_left = -1\n"; + std::cout << " bwd_args.window_size_right = 0 (causal)\n"; + std::cout << " bwd_args.mask_type = 1 (top_left)\n"; + std::cout << " Backward plan resolves to " << bwd_plan.stages.size() << " stage(s)\n"; + + print_separator(); + std::cout << "Status: " << (fwd_passed ? "PASS" : "FAIL") << "\n"; + print_separator(); + + return fwd_passed ? 0 : 1; +} diff --git a/dispatcher/examples/fmha/cpp/29_bwd_bias_dropout_fmha.cpp b/dispatcher/examples/fmha/cpp/29_bwd_bias_dropout_fmha.cpp new file mode 100644 index 0000000000..856fe553d8 --- /dev/null +++ b/dispatcher/examples/fmha/cpp/29_bwd_bias_dropout_fmha.cpp @@ -0,0 +1,615 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Example 29: FMHA Backward with ALiBi Bias + Dropout +// +// Demonstrates: +// 1. Forward kernel with alibi bias + dropout + LSE +// 2. Backward kernel families with alibi + dropout +// 3. GPU forward execution with alibi bias, validates output +// 4. Backward plan with all features enabled +// 5. How deterministic mode affects the backward plan +// +// Backward kernels use planning only -- actual backward GPU execution requires +// all 3 stages to compile, and bwd_dq_dk_dv has tile structure issues on gfx950. + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +DECL_FMHA_KERNEL_SET(bwd_bias_dropout_fmha_kernels, + // Forward: alibi bias + dropout + LSE + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("alibi") + .lse(true) + .dropout(true) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950") + + // Backward stage 1: dot(dO, O) with alibi + dropout (non-deterministic) + .add(FmhaSignature() + .family("bwd_dot_do_o") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("no") + .bias("alibi") + .dropout(true) + .dbias(false) + .store_randval(false) + .deterministic(false), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(32) + .tile_n1(0) + .tile_k1(0) + .tile_k0max(0) + .padding(true, true, true, true) + .selection_rank(0), + "gfx950") + + // Backward stage 2: dQ, dK, dV with alibi + dropout (non-deterministic) + .add(FmhaSignature() + .family("bwd_dq_dk_dv") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("no") + .bias("alibi") + .dropout(true) + .dbias(false) + .store_randval(false) + .deterministic(false), + FmhaAlgorithm() + .tile_m0(16) + .tile_n0(128) + .tile_k0(128) + .tile_n1(16) + .tile_k1(128) + .tile_k0max(32) + .wave(1, 4, 1, 4, 1, 1, 1, 4, 1) + .warp(16, 16, 32, 16, 16, 16, 16, 16, 16) + .padding(true, true, true, true) + .max_seq_len_q(0) + .selection_rank(0), + "gfx950") + + // Backward stage 3: convert dQ with alibi + dropout (non-deterministic) + .add(FmhaSignature() + .family("bwd_convert_dq") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("no") + .bias("alibi") + .dropout(true) + .dbias(false) + .store_randval(false) + .deterministic(false), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(0) + .tile_n1(0) + .tile_k1(0) + .tile_k0max(0) + .padding(true, true, true, true) + .selection_rank(0), + "gfx950") + + // Deterministic variants for comparison + .add(FmhaSignature() + .family("bwd_dot_do_o") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("no") + .bias("alibi") + .dropout(true) + .dbias(false) + .store_randval(false) + .deterministic(true), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(32) + .tile_n1(0) + .tile_k1(0) + .tile_k0max(0) + .padding(true, true, true, true) + .selection_rank(0), + "gfx950") + + .add(FmhaSignature() + .family("bwd_dq_dk_dv") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("no") + .bias("alibi") + .dropout(true) + .dbias(false) + .store_randval(false) + .deterministic(true), + FmhaAlgorithm() + .tile_m0(16) + .tile_n0(128) + .tile_k0(128) + .tile_n1(16) + .tile_k1(128) + .tile_k0max(32) + .wave(1, 4, 1, 4, 1, 1, 1, 4, 1) + .warp(16, 16, 32, 16, 16, 16, 16, 16, 16) + .padding(true, true, true, true) + .max_seq_len_q(0) + .selection_rank(0), + "gfx950") + + .add(FmhaSignature() + .family("bwd_convert_dq") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("no") + .bias("alibi") + .dropout(true) + .dbias(false) + .store_randval(false) + .deterministic(true), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(0) + .tile_n1(0) + .tile_k1(0) + .tile_k0max(0) + .padding(true, true, true, true) + .selection_rank(0), + "gfx950")); + +namespace { + +using FmhaDataType = ck_tile::fp16_t; + +void cpu_attention_fwd_alibi(const std::vector& Q, + const std::vector& K, + const std::vector& V, + std::vector& O, + std::vector& LSE, + int batch, + int nhead, + int seqlen, + int hdim, + float scale, + const std::vector& alibi_slopes) +{ + for(int b = 0; b < batch; ++b) + { + for(int h = 0; h < nhead; ++h) + { + const float slope = alibi_slopes[h]; + + for(int sq = 0; sq < seqlen; ++sq) + { + std::vector scores(seqlen, 0.0f); + float max_score = -1e30f; + + for(int sk = 0; sk < seqlen; ++sk) + { + float dot = 0.0f; + for(int d = 0; d < hdim; ++d) + { + int q_idx = ((b * nhead + h) * seqlen + sq) * hdim + d; + int k_idx = ((b * nhead + h) * seqlen + sk) * hdim + d; + dot += Q[q_idx] * K[k_idx]; + } + scores[sk] = dot * scale + slope * static_cast(sk - sq); + max_score = std::max(max_score, scores[sk]); + } + + float sum_exp = 0.0f; + for(int sk = 0; sk < seqlen; ++sk) + { + scores[sk] = std::exp(scores[sk] - max_score); + sum_exp += scores[sk]; + } + + int lse_idx = (b * nhead + h) * seqlen + sq; + LSE[lse_idx] = max_score + std::log(sum_exp); + + for(int sk = 0; sk < seqlen; ++sk) + scores[sk] /= sum_exp; + + for(int dv = 0; dv < hdim; ++dv) + { + float acc = 0.0f; + for(int sk = 0; sk < seqlen; ++sk) + { + int v_idx = ((b * nhead + h) * seqlen + sk) * hdim + dv; + acc += scores[sk] * V[v_idx]; + } + int o_idx = ((b * nhead + h) * seqlen + sq) * hdim + dv; + O[o_idx] = acc; + } + } + } + } +} + +} // namespace + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 29: FMHA Backward with Bias + Dropout", + "ALiBi bias + dropout forward (GPU) + backward plan with deterministic mode"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--batch", "2", "Batch size"); + args.add_option("--nhead", "4", "Number of heads"); + args.add_option("--seqlen", "64", "Sequence length"); + args.add_option("--hdim", "128", "Head dimension"); + if(!args.parse(argc, argv)) + return 0; + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int batch = args.get_int("--batch", 2); + const int nhead = args.get_int("--nhead", 4); + const int seqlen = args.get_int("--seqlen", 64); + const int hdim = args.get_int("--hdim", 128); + const float scale = 1.0f / std::sqrt(static_cast(hdim)); + + print_header("Example 29: FMHA Backward with ALiBi Bias + Dropout"); + + // Step 1: Register kernels + std::cout << "\nStep 1: Register Kernels\n"; + FmhaKernelSetRegistry::instance().print(); + + FmhaRegistry registry; + registry.set_name("bwd_bias_dropout_fmha"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + FmhaDispatcher dispatcher(®istry); + dispatcher.set_benchmarking(true); + dispatcher.set_timing(1, 3); + + // Step 2: Plan backward (non-deterministic) with alibi + dropout + std::cout << "\nStep 2: Plan Backward (non-deterministic, alibi + dropout)\n"; + + fmha_bwd_traits bwd_traits{}; + bwd_traits.hdim_q = hdim; + bwd_traits.hdim_v = hdim; + bwd_traits.data_type = "fp16"; + bwd_traits.is_group_mode = false; + bwd_traits.mask_type = mask_enum::no_mask; + bwd_traits.bias_type = bias_enum::alibi; + bwd_traits.has_dbias = false; + bwd_traits.has_dropout = true; + bwd_traits.is_store_randval = false; + bwd_traits.is_deterministic = false; + + fmha_bwd_args bwd_args{}; + bwd_args.batch = batch; + bwd_args.seqlen_q = seqlen; + bwd_args.seqlen_k = seqlen; + bwd_args.max_seqlen_q = seqlen; + bwd_args.max_seqlen_k = seqlen; + bwd_args.hdim_q = hdim; + bwd_args.hdim_v = hdim; + bwd_args.nhead_q = nhead; + bwd_args.nhead_k = nhead; + + auto nondet_plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(bwd_traits, bwd_args), gfx_arch)); + + if(nondet_plan.is_valid() && nondet_plan.stages.size() >= 2) + { + std::cout << " Non-deterministic plan stages (" << nondet_plan.stages.size() << "):\n"; + for(const auto& stage : nondet_plan.stages) + { + std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n"; + } + } + else + { + std::cout << " Non-deterministic plan: INVALID or single-stage\n"; + } + + // Step 2b: Plan backward (deterministic) with alibi + dropout + std::cout << "\nStep 2b: Plan Backward (deterministic, alibi + dropout)\n"; + + fmha_bwd_traits det_traits = bwd_traits; + det_traits.is_deterministic = true; + + auto det_plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(det_traits, bwd_args), gfx_arch)); + + if(det_plan.is_valid() && det_plan.stages.size() >= 2) + { + std::cout << " Deterministic plan stages (" << det_plan.stages.size() << "):\n"; + for(const auto& stage : det_plan.stages) + { + std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n"; + } + } + else + { + std::cout << " Deterministic plan: INVALID or single-stage\n"; + } + + std::cout << "\n Deterministic mode difference:\n"; + std::cout << " Non-det: dQ accumulated via atomic adds (faster, non-reproducible)\n"; + std::cout << " Det: dQ accumulated with split-stride (slower, bit-reproducible)\n"; + + // Step 3: Run forward on GPU with alibi bias + dropout + std::cout << "\nStep 3: Run Forward (alibi + dropout, GPU)\n"; + + const int64_t qkv_elems = static_cast(batch) * nhead * seqlen * hdim; + const int64_t lse_elems = static_cast(batch) * nhead * seqlen; + const int64_t randval_elems = static_cast(batch) * nhead * seqlen * seqlen; + + GpuBuffer q_dev(qkv_elems); + GpuBuffer k_dev(qkv_elems); + GpuBuffer v_dev(qkv_elems); + GpuBuffer o_dev(qkv_elems); + GpuBuffer lse_dev(lse_elems); + GpuBuffer rand_val_dev(randval_elems); + + // ALiBi slopes: geometric series + std::vector alibi_slopes_host(nhead); + for(int h = 0; h < nhead; ++h) + alibi_slopes_host[h] = -std::pow(2.0f, -(8.0f * (h + 1) / nhead)); + + GpuBuffer alibi_slopes_dev(nhead); + alibi_slopes_dev.copy_from_host(alibi_slopes_host.data()); + + std::mt19937 rng(42); + std::uniform_real_distribution dist(-0.5f, 0.5f); + + std::vector q_host(qkv_elems), k_host(qkv_elems), v_host(qkv_elems); + for(auto& x : q_host) + x = FmhaDataType(dist(rng)); + for(auto& x : k_host) + x = FmhaDataType(dist(rng)); + for(auto& x : v_host) + x = FmhaDataType(dist(rng)); + + q_dev.copy_from_host(q_host.data()); + k_dev.copy_from_host(k_host.data()); + v_dev.copy_from_host(v_host.data()); + o_dev.zero(); + lse_dev.zero(); + rand_val_dev.zero(); + + std::cout << " ALiBi slopes: ["; + for(int h = 0; h < nhead; ++h) + { + if(h > 0) + std::cout << ", "; + std::cout << std::fixed << std::setprecision(4) << alibi_slopes_host[h]; + } + std::cout << "]\n"; + + fmha_fwd_traits fwd_traits{}; + fwd_traits.hdim_q = hdim; + fwd_traits.hdim_v = hdim; + fwd_traits.data_type = "fp16"; + fwd_traits.is_group_mode = false; + fwd_traits.is_v_rowmajor = true; + fwd_traits.has_logits_soft_cap = false; + fwd_traits.mask_type = mask_enum::no_mask; + fwd_traits.bias_type = bias_enum::alibi; + fwd_traits.has_lse = true; + fwd_traits.has_dropout = true; + fwd_traits.qscale_type = quant_scale_enum::no_scale; + + fmha_fwd_args fwd_args{}; + fwd_args.q_ptr = q_dev.get(); + fwd_args.k_ptr = k_dev.get(); + fwd_args.v_ptr = v_dev.get(); + fwd_args.o_ptr = o_dev.get(); + fwd_args.lse_ptr = lse_dev.get(); + + fwd_args.bias_ptr = alibi_slopes_dev.get(); + fwd_args.rand_val_ptr = rand_val_dev.get(); + fwd_args.q_descale_ptr = nullptr; + fwd_args.k_descale_ptr = nullptr; + fwd_args.v_descale_ptr = nullptr; + fwd_args.sink_ptr = nullptr; + fwd_args.block_scale_seqstart_q_ptr = nullptr; + fwd_args.block_scale_seqstart_k_ptr = nullptr; + + fwd_args.seqlen_q = seqlen; + fwd_args.seqlen_k = seqlen; + fwd_args.batch = batch; + fwd_args.max_seqlen_q = seqlen; + fwd_args.hdim_q = hdim; + fwd_args.hdim_v = hdim; + fwd_args.nhead_q = nhead; + fwd_args.nhead_k = nhead; + fwd_args.scale_s = scale; + fwd_args.logits_soft_cap = 0.0f; + + fwd_args.stride_q = hdim; + fwd_args.stride_k = hdim; + fwd_args.stride_v = hdim; + fwd_args.stride_bias = 0; // alibi: per-head slope, no spatial stride + fwd_args.stride_randval = seqlen; + fwd_args.stride_o = hdim; + + fwd_args.nhead_stride_q = seqlen * hdim; + fwd_args.nhead_stride_k = seqlen * hdim; + fwd_args.nhead_stride_v = seqlen * hdim; + fwd_args.nhead_stride_bias = 1; // alibi: stride between slopes + fwd_args.nhead_stride_randval = seqlen * seqlen; + fwd_args.nhead_stride_lse = seqlen; + fwd_args.nhead_stride_o = seqlen * hdim; + fwd_args.nhead_stride_q_descale = 0; + fwd_args.nhead_stride_k_descale = 0; + fwd_args.nhead_stride_v_descale = 0; + + fwd_args.batch_stride_q = nhead * seqlen * hdim; + fwd_args.batch_stride_k = nhead * seqlen * hdim; + fwd_args.batch_stride_v = nhead * seqlen * hdim; + fwd_args.batch_stride_bias = 0; // alibi slopes shared across batch + fwd_args.batch_stride_randval = nhead * seqlen * seqlen; + fwd_args.batch_stride_lse = nhead * seqlen; + fwd_args.batch_stride_o = nhead * seqlen * hdim; + fwd_args.batch_stride_q_descale = 0; + fwd_args.batch_stride_k_descale = 0; + fwd_args.batch_stride_v_descale = 0; + + fwd_args.window_size_left = -1; + fwd_args.window_size_right = -1; + fwd_args.sink_size = 0; + fwd_args.mask_type = 0; + fwd_args.min_seqlen_q = 0; + fwd_args.p_drop = 0.2f; + fwd_args.s_randval = true; + fwd_args.drop_seed_offset = std::make_pair(uint64_t(42), uint64_t(0)); + fwd_args.block_scale_size_q = 0; + fwd_args.block_scale_size_kv = 0; + + bool fwd_passed = false; + try + { + float fwd_time = dispatcher.run_fwd(fwd_traits, fwd_args, nullptr); + std::cout << " Forward time: " << std::fixed << std::setprecision(4) << fwd_time + << " ms\n"; + + auto problem = + FmhaProblem::from_invocation(FmhaInvocation::make(fwd_traits, fwd_args), gfx_arch); + double tflops = static_cast(problem.num_ops()) / (fwd_time * 1e-3) / 1e12; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + fwd_passed = true; + } + catch(const std::exception& e) + { + std::cerr << " Forward ERROR: " << e.what() << "\n"; + } + + // Step 4: Validate forward output (without dropout reference -- just check non-zero + LSE) + std::cout << "\nStep 4: Validate Forward Output\n"; + + if(fwd_passed) + { + std::vector o_host(qkv_elems); + o_dev.copy_to_host(o_host.data()); + + int nonzero = 0; + for(int64_t i = 0; i < qkv_elems; ++i) + { + if(static_cast(o_host[i]) != 0.0f) + ++nonzero; + } + std::cout << " Non-zero outputs: " << nonzero << " / " << qkv_elems << "\n"; + + std::vector lse_host(lse_elems); + lse_dev.copy_to_host(lse_host.data()); + + int lse_reasonable = 0; + for(int64_t i = 0; i < lse_elems; ++i) + { + if(std::isfinite(lse_host[i]) && std::abs(lse_host[i]) < 100.0f) + ++lse_reasonable; + } + std::cout << " LSE reasonable: " << lse_reasonable << " / " << lse_elems << "\n"; + + std::cout << " LSE sample [0..3]: "; + for(int i = 0; i < std::min(4, lse_elems); ++i) + std::cout << std::fixed << std::setprecision(4) << lse_host[i] << " "; + std::cout << "\n"; + + fwd_passed = (nonzero > 0) && (lse_reasonable == lse_elems); + + // ALiBi reference (without dropout) for sanity check on bias effect + std::vector q_f32(qkv_elems), k_f32(qkv_elems), v_f32(qkv_elems); + for(int64_t i = 0; i < qkv_elems; ++i) + q_f32[i] = static_cast(q_host[i]); + for(int64_t i = 0; i < qkv_elems; ++i) + k_f32[i] = static_cast(k_host[i]); + for(int64_t i = 0; i < qkv_elems; ++i) + v_f32[i] = static_cast(v_host[i]); + + std::vector o_ref(qkv_elems, 0.0f); + std::vector lse_ref(lse_elems, 0.0f); + cpu_attention_fwd_alibi(q_f32, + k_f32, + v_f32, + o_ref, + lse_ref, + batch, + nhead, + seqlen, + hdim, + scale, + alibi_slopes_host); + + // LSE should be close (dropout doesn't change LSE in the CK implementation -- + // LSE is computed before dropout is applied to the attention weights) + double max_lse_err = 0.0; + for(int64_t i = 0; i < lse_elems; ++i) + max_lse_err = + std::max(max_lse_err, static_cast(std::abs(lse_host[i] - lse_ref[i]))); + + std::cout << " LSE vs alibi ref (no dropout) max error: " << std::scientific << max_lse_err + << "\n"; + } + + // Step 5: Show backward API pattern with all features + std::cout << "\nStep 5: Backward API Pattern (all features)\n"; + std::cout << " bwd_traits.bias_type = alibi\n"; + std::cout << " bwd_traits.has_dropout = true\n"; + std::cout << " bwd_traits.is_store_randval = false\n"; + std::cout << " bwd_traits.has_dbias = false (alibi has no learnable params)\n"; + std::cout << "\n Non-deterministic plan: " << nondet_plan.stages.size() << " stage(s)\n"; + std::cout << " Deterministic plan: " << det_plan.stages.size() << " stage(s)\n"; + std::cout << "\n Key backward args for dropout:\n"; + std::cout << " bwd_args.p_drop = 0.2\n"; + std::cout << " bwd_args.p_undrop = 1.0 / (1.0 - p_drop) = 1.25\n"; + std::cout << " bwd_args.drop_seed_offset = {42, 0} (must match fwd)\n"; + + print_separator(); + std::cout << "Status: " << (fwd_passed ? "PASS" : "FAIL") << "\n"; + print_separator(); + + return fwd_passed ? 0 : 1; +} diff --git a/dispatcher/examples/fmha/cpp/30_bwd_benchmark_fmha.cpp b/dispatcher/examples/fmha/cpp/30_bwd_benchmark_fmha.cpp new file mode 100644 index 0000000000..ea26f2f085 --- /dev/null +++ b/dispatcher/examples/fmha/cpp/30_bwd_benchmark_fmha.cpp @@ -0,0 +1,449 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Example 30: FMHA Backward Benchmark +// +// Demonstrates: +// 1. Forward kernel for benchmark (with LSE for backward planning) +// 2. Multiple problem sizes: sweep batch x seqlen +// 3. GPU forward execution for each size with timing +// 4. Backward plan for each size +// 5. Summary table: Batch | SeqLen | Fwd(ms) | BwdPlan | FwdTFLOPS +// +// Backward kernels use planning only -- actual backward GPU execution requires +// all 3 stages to compile, and bwd_dq_dk_dv has tile structure issues on gfx950. + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +DECL_FMHA_KERNEL_SET(bwd_bench_fmha_kernels, + // Forward: basic fp16 with LSE for backward + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(true) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950") + + // Backward stage 1: dot(dO, O) + .add(FmhaSignature() + .family("bwd_dot_do_o") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("no") + .bias("no") + .dropout(false) + .dbias(false) + .store_randval(false) + .deterministic(false), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(32) + .tile_n1(0) + .tile_k1(0) + .tile_k0max(0) + .padding(true, true, true, true) + .selection_rank(0), + "gfx950") + + // Backward stage 2: dQ, dK, dV + .add(FmhaSignature() + .family("bwd_dq_dk_dv") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("no") + .bias("no") + .dropout(false) + .dbias(false) + .store_randval(false) + .deterministic(false), + FmhaAlgorithm() + .tile_m0(16) + .tile_n0(128) + .tile_k0(128) + .tile_n1(16) + .tile_k1(128) + .tile_k0max(32) + .wave(1, 4, 1, 4, 1, 1, 1, 4, 1) + .warp(16, 16, 32, 16, 16, 16, 16, 16, 16) + .padding(true, true, true, true) + .max_seq_len_q(0) + .selection_rank(0), + "gfx950") + + // Backward stage 3: convert dQ + .add(FmhaSignature() + .family("bwd_convert_dq") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("no") + .bias("no") + .dropout(false) + .dbias(false) + .store_randval(false) + .deterministic(false), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(0) + .tile_n1(0) + .tile_k1(0) + .tile_k0max(0) + .padding(true, true, true, true) + .selection_rank(0), + "gfx950")); + +namespace { + +using FmhaDataType = ck_tile::fp16_t; + +struct BenchResult +{ + int batch; + int seqlen; + float fwd_ms; + double fwd_tflops; + int bwd_stages; + bool bwd_valid; + bool fwd_passed; +}; + +} // namespace + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 30: FMHA Backward Benchmark", + "Sweep batch x seqlen, forward GPU + backward plan"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--nhead", "8", "Number of heads"); + args.add_option("--hdim", "128", "Head dimension"); + args.add_option("--warmup", "2", "Warmup iterations per size"); + args.add_option("--repeat", "3", "Benchmark repetitions per size"); + if(!args.parse(argc, argv)) + return 0; + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int nhead = args.get_int("--nhead", 8); + const int hdim = args.get_int("--hdim", 128); + const int warmup = args.get_int("--warmup", 2); + const int repeat = args.get_int("--repeat", 3); + const float scale = 1.0f / std::sqrt(static_cast(hdim)); + + print_header("Example 30: FMHA Backward Benchmark"); + + // Step 1: Register kernels + std::cout << "\nStep 1: Register Kernels\n"; + FmhaKernelSetRegistry::instance().print(); + + FmhaRegistry registry; + registry.set_name("bwd_bench_fmha"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + FmhaDispatcher dispatcher(®istry); + + // Problem sizes to sweep + struct ProblemSize + { + int batch; + int seqlen; + }; + + ProblemSize sizes[] = { + {8, 128}, + {4, 256}, + {2, 512}, + {1, 1024}, + {1, 2048}, + {1, 4096}, + }; + + std::vector results; + + // Step 2: Sweep problem sizes + std::cout << "\nStep 2: Sweep Problem Sizes\n"; + + for(const auto& sz : sizes) + { + std::cout << "\n --- batch=" << sz.batch << ", seqlen=" << sz.seqlen << " ---\n"; + + const int64_t qkv_elems = static_cast(sz.batch) * nhead * sz.seqlen * hdim; + const int64_t lse_elems = static_cast(sz.batch) * nhead * sz.seqlen; + + BenchResult res{}; + res.batch = sz.batch; + res.seqlen = sz.seqlen; + + // Allocate buffers + GpuBuffer q_dev(qkv_elems); + GpuBuffer k_dev(qkv_elems); + GpuBuffer v_dev(qkv_elems); + GpuBuffer o_dev(qkv_elems); + GpuBuffer lse_dev(lse_elems); + + std::mt19937 rng(42); + std::uniform_real_distribution dist(-0.5f, 0.5f); + + std::vector q_host(qkv_elems), k_host(qkv_elems), v_host(qkv_elems); + for(auto& x : q_host) + x = FmhaDataType(dist(rng)); + for(auto& x : k_host) + x = FmhaDataType(dist(rng)); + for(auto& x : v_host) + x = FmhaDataType(dist(rng)); + + q_dev.copy_from_host(q_host.data()); + k_dev.copy_from_host(k_host.data()); + v_dev.copy_from_host(v_host.data()); + + // Forward traits/args + fmha_fwd_traits fwd_traits{}; + fwd_traits.hdim_q = hdim; + fwd_traits.hdim_v = hdim; + fwd_traits.data_type = "fp16"; + fwd_traits.is_group_mode = false; + fwd_traits.is_v_rowmajor = true; + fwd_traits.has_logits_soft_cap = false; + fwd_traits.mask_type = mask_enum::no_mask; + fwd_traits.bias_type = bias_enum::no_bias; + fwd_traits.has_lse = true; + fwd_traits.has_dropout = false; + fwd_traits.qscale_type = quant_scale_enum::no_scale; + + fmha_fwd_args fwd_args{}; + fwd_args.q_ptr = q_dev.get(); + fwd_args.k_ptr = k_dev.get(); + fwd_args.v_ptr = v_dev.get(); + fwd_args.o_ptr = o_dev.get(); + fwd_args.lse_ptr = lse_dev.get(); + + fwd_args.bias_ptr = nullptr; + fwd_args.q_descale_ptr = nullptr; + fwd_args.k_descale_ptr = nullptr; + fwd_args.v_descale_ptr = nullptr; + fwd_args.rand_val_ptr = nullptr; + fwd_args.sink_ptr = nullptr; + fwd_args.block_scale_seqstart_q_ptr = nullptr; + fwd_args.block_scale_seqstart_k_ptr = nullptr; + + fwd_args.seqlen_q = sz.seqlen; + fwd_args.seqlen_k = sz.seqlen; + fwd_args.batch = sz.batch; + fwd_args.max_seqlen_q = sz.seqlen; + fwd_args.hdim_q = hdim; + fwd_args.hdim_v = hdim; + fwd_args.nhead_q = nhead; + fwd_args.nhead_k = nhead; + fwd_args.scale_s = scale; + fwd_args.logits_soft_cap = 0.0f; + + fwd_args.stride_q = hdim; + fwd_args.stride_k = hdim; + fwd_args.stride_v = hdim; + fwd_args.stride_bias = 0; + fwd_args.stride_randval = 0; + fwd_args.stride_o = hdim; + + fwd_args.nhead_stride_q = sz.seqlen * hdim; + fwd_args.nhead_stride_k = sz.seqlen * hdim; + fwd_args.nhead_stride_v = sz.seqlen * hdim; + fwd_args.nhead_stride_bias = 0; + fwd_args.nhead_stride_randval = 0; + fwd_args.nhead_stride_lse = sz.seqlen; + fwd_args.nhead_stride_o = sz.seqlen * hdim; + fwd_args.nhead_stride_q_descale = 0; + fwd_args.nhead_stride_k_descale = 0; + fwd_args.nhead_stride_v_descale = 0; + + fwd_args.batch_stride_q = nhead * sz.seqlen * hdim; + fwd_args.batch_stride_k = nhead * sz.seqlen * hdim; + fwd_args.batch_stride_v = nhead * sz.seqlen * hdim; + fwd_args.batch_stride_bias = 0; + fwd_args.batch_stride_randval = 0; + fwd_args.batch_stride_lse = nhead * sz.seqlen; + fwd_args.batch_stride_o = nhead * sz.seqlen * hdim; + fwd_args.batch_stride_q_descale = 0; + fwd_args.batch_stride_k_descale = 0; + fwd_args.batch_stride_v_descale = 0; + + fwd_args.window_size_left = -1; + fwd_args.window_size_right = -1; + fwd_args.sink_size = 0; + fwd_args.mask_type = 0; + fwd_args.min_seqlen_q = 0; + fwd_args.p_drop = 0.0f; + fwd_args.s_randval = false; + fwd_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0)); + fwd_args.block_scale_size_q = 0; + fwd_args.block_scale_size_kv = 0; + + // Warmup + dispatcher.set_benchmarking(true); + dispatcher.set_timing(1, 1); + try + { + for(int w = 0; w < warmup; ++w) + { + o_dev.zero(); + lse_dev.zero(); + dispatcher.run_fwd(fwd_traits, fwd_args, nullptr); + } + } + catch(const std::exception& e) + { + std::cerr << " Warmup ERROR: " << e.what() << "\n"; + res.fwd_passed = false; + results.push_back(res); + continue; + } + + // Benchmark + dispatcher.set_timing(0, 1); + float total_ms = 0.0f; + bool ok = true; + for(int r = 0; r < repeat; ++r) + { + o_dev.zero(); + lse_dev.zero(); + try + { + total_ms += dispatcher.run_fwd(fwd_traits, fwd_args, nullptr); + } + catch(const std::exception& e) + { + std::cerr << " Bench ERROR: " << e.what() << "\n"; + ok = false; + break; + } + } + + if(ok) + { + res.fwd_ms = total_ms / static_cast(repeat); + + auto problem = + FmhaProblem::from_invocation(FmhaInvocation::make(fwd_traits, fwd_args), gfx_arch); + res.fwd_tflops = static_cast(problem.num_ops()) / (res.fwd_ms * 1e-3) / 1e12; + + // Sanity check output + std::vector o_host(qkv_elems); + o_dev.copy_to_host(o_host.data()); + int nonzero = 0; + for(int64_t i = 0; i < qkv_elems; ++i) + { + if(static_cast(o_host[i]) != 0.0f) + ++nonzero; + } + res.fwd_passed = (nonzero > 0); + } + else + { + res.fwd_passed = false; + } + + // Backward plan for this size + fmha_bwd_traits bwd_traits{}; + bwd_traits.hdim_q = hdim; + bwd_traits.hdim_v = hdim; + bwd_traits.data_type = "fp16"; + bwd_traits.is_group_mode = false; + bwd_traits.mask_type = mask_enum::no_mask; + bwd_traits.bias_type = bias_enum::no_bias; + bwd_traits.has_dbias = false; + bwd_traits.has_dropout = false; + bwd_traits.is_store_randval = false; + bwd_traits.is_deterministic = false; + + fmha_bwd_args bwd_args{}; + bwd_args.batch = sz.batch; + bwd_args.seqlen_q = sz.seqlen; + bwd_args.seqlen_k = sz.seqlen; + bwd_args.max_seqlen_q = sz.seqlen; + bwd_args.max_seqlen_k = sz.seqlen; + bwd_args.hdim_q = hdim; + bwd_args.hdim_v = hdim; + bwd_args.nhead_q = nhead; + bwd_args.nhead_k = nhead; + + auto bwd_plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(bwd_traits, bwd_args), gfx_arch)); + + res.bwd_valid = bwd_plan.is_valid() && bwd_plan.stages.size() >= 2; + res.bwd_stages = static_cast(bwd_plan.stages.size()); + + std::cout << " Fwd: " << std::fixed << std::setprecision(4) << res.fwd_ms << " ms, " + << std::setprecision(2) << res.fwd_tflops << " TFLOPS" + << " | Bwd plan: " << res.bwd_stages << " stages" + << (res.bwd_valid ? " (valid)" : " (invalid)") << "\n"; + + results.push_back(res); + } + + // Step 3: Summary table + std::cout << "\nStep 3: Summary\n\n"; + std::cout << " " << std::setw(7) << "Batch" << " | " << std::setw(7) << "SeqLen" << " | " + << std::setw(10) << "Fwd(ms)" << " | " << std::setw(8) << "BwdPlan" << " | " + << std::setw(10) << "FwdTFLOPS" << " | " << std::setw(6) << "Status" << "\n"; + std::cout << " " << std::string(60, '-') << "\n"; + + bool all_passed = true; + for(const auto& r : results) + { + std::cout << " " << std::setw(7) << r.batch << " | " << std::setw(7) << r.seqlen << " | " + << std::fixed << std::setprecision(4) << std::setw(10) << r.fwd_ms << " | " + << std::setw(5) << r.bwd_stages << "stg" << " | " << std::setprecision(2) + << std::setw(10) << r.fwd_tflops << " | " << std::setw(6) + << (r.fwd_passed ? "PASS" : "FAIL") << "\n"; + if(!r.fwd_passed) + all_passed = false; + } + + print_separator(); + std::cout << "Status: " << (all_passed ? "PASS" : "FAIL") << "\n"; + print_separator(); + + return all_passed ? 0 : 1; +} diff --git a/dispatcher/examples/fmha/cpp/31_logits_soft_cap_fmha.cpp b/dispatcher/examples/fmha/cpp/31_logits_soft_cap_fmha.cpp new file mode 100644 index 0000000000..43172d7778 --- /dev/null +++ b/dispatcher/examples/fmha/cpp/31_logits_soft_cap_fmha.cpp @@ -0,0 +1,118 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Example 31: FMHA Forward with Logits Soft Cap +// +// Demonstrates forward kernel with logits_soft_cap enabled. The soft cap +// applies: scores_capped = tanh(scores/cap) * cap, which prevents extreme +// attention logits from causing numerical instability while preserving +// gradients. Planning only. + +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +DECL_FMHA_KERNEL_SET(logits_soft_cap_fmha_kernels, + // Forward with logits soft cap: tanh(scores/cap)*cap + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no") + .lse(false) + .dropout(false) + .qscale("no") + .logits(true), // enables logits_soft_cap path + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950")); + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 31: FMHA Logits Soft Cap", "Forward with tanh(scores/cap)*cap"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--batch", "2", "Batch size"); + args.add_option("--nhead", "4", "Number of heads"); + args.add_option("--seqlen", "128", "Sequence length"); + args.add_option("--hdim", "128", "Head dimension"); + if(!args.parse(argc, argv)) + return 0; + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int batch = args.get_int("--batch", 2); + const int nhead = args.get_int("--nhead", 4); + const int seqlen = args.get_int("--seqlen", 128); + const int hdim = args.get_int("--hdim", 128); + + print_header("Example 31: FMHA Logits Soft Cap"); + + std::cout << "\nStep 1: Register Kernels\n"; + FmhaRegistry registry; + registry.set_name("logits_soft_cap_fmha"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + std::cout << "\nStep 2: Plan\n"; + FmhaDispatcher dispatcher(®istry); + fmha_fwd_traits traits{}; + traits.hdim_q = hdim; + traits.hdim_v = hdim; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.has_logits_soft_cap = true; // runtime: cap > 0 means soft cap applied + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + traits.has_lse = false; + traits.has_dropout = false; + traits.qscale_type = quant_scale_enum::no_scale; + + fmha_fwd_args fwd_args{}; + fwd_args.batch = batch; + fwd_args.seqlen_q = seqlen; + fwd_args.seqlen_k = seqlen; + fwd_args.nhead_q = nhead; + fwd_args.nhead_k = nhead; + fwd_args.hdim_q = hdim; + fwd_args.hdim_v = hdim; + fwd_args.logits_soft_cap = 30.0f; // cap value; apply tanh(scores/30)*30 + + auto plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(traits, fwd_args), gfx_arch)); + std::cout << " Plan valid: " << (plan.is_valid() ? "yes" : "no") << "\n"; + + std::cout << "\nStep 3: Logits Soft Cap\n"; + std::cout << " Formula: scores_capped = tanh(scores/cap) * cap\n"; + std::cout << " Prevents extreme logits while preserving gradients.\n"; + + print_separator(); + return 0; +} diff --git a/dispatcher/examples/fmha/cpp/32_sink_tokens_fmha.cpp b/dispatcher/examples/fmha/cpp/32_sink_tokens_fmha.cpp new file mode 100644 index 0000000000..5f62e1ba0b --- /dev/null +++ b/dispatcher/examples/fmha/cpp/32_sink_tokens_fmha.cpp @@ -0,0 +1,119 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Example 32: FMHA Forward with Sink Tokens +// +// Demonstrates forward kernel with sink tokens enabled. Sink tokens keep the +// first K positions always visible to all queries (StreamingLLM-style). Used +// with causal mask: positions [0, sink_size) are never masked. Planning only. + +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +DECL_FMHA_KERNEL_SET(sink_tokens_fmha_kernels, + // Forward with sink: first K tokens always visible + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("top_left") // causal required with sink + .bias("no") + .lse(false) + .dropout(false) + .qscale("no") + .sink(true), // enables sink token path + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950")); + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 32: FMHA Sink Tokens", "Forward with first K tokens always visible"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--batch", "2", "Batch size"); + args.add_option("--nhead", "4", "Number of heads"); + args.add_option("--seqlen", "128", "Sequence length"); + args.add_option("--hdim", "128", "Head dimension"); + args.add_option("--sink", "4", "Number of sink tokens"); + if(!args.parse(argc, argv)) + return 0; + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int batch = args.get_int("--batch", 2); + const int nhead = args.get_int("--nhead", 4); + const int seqlen = args.get_int("--seqlen", 128); + const int hdim = args.get_int("--hdim", 128); + const int sink_size = args.get_int("--sink", 4); + + print_header("Example 32: FMHA Sink Tokens"); + + std::cout << "\nStep 1: Register Kernels\n"; + FmhaRegistry registry; + registry.set_name("sink_tokens_fmha"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + std::cout << "\nStep 2: Plan\n"; + FmhaDispatcher dispatcher(®istry); + fmha_fwd_traits traits{}; + traits.hdim_q = hdim; + traits.hdim_v = hdim; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.has_sink = true; + traits.mask_type = mask_enum::mask_top_left; + traits.bias_type = bias_enum::no_bias; + traits.has_lse = false; + traits.has_dropout = false; + traits.qscale_type = quant_scale_enum::no_scale; + + fmha_fwd_args fwd_args{}; + fwd_args.batch = batch; + fwd_args.seqlen_q = seqlen; + fwd_args.seqlen_k = seqlen; + fwd_args.nhead_q = nhead; + fwd_args.nhead_k = nhead; + fwd_args.hdim_q = hdim; + fwd_args.hdim_v = hdim; + fwd_args.sink_size = sink_size; + + auto plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(traits, fwd_args), gfx_arch)); + std::cout << " Plan valid: " << (plan.is_valid() ? "yes" : "no") << "\n"; + + std::cout << "\nStep 3: Sink Tokens\n"; + std::cout << " First " << sink_size << " tokens always visible to all queries.\n"; + std::cout << " Used with causal mask for StreamingLLM-style long-context.\n"; + + print_separator(); + return 0; +} diff --git a/dispatcher/examples/fmha/cpp/33_bwd_deterministic_fmha.cpp b/dispatcher/examples/fmha/cpp/33_bwd_deterministic_fmha.cpp new file mode 100644 index 0000000000..0f9668a6f8 --- /dev/null +++ b/dispatcher/examples/fmha/cpp/33_bwd_deterministic_fmha.cpp @@ -0,0 +1,256 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Example 33: FMHA Backward Deterministic vs Non-Deterministic +// +// Demonstrates two backward kernel sets: one deterministic (bit-identical +// results across runs) and one non-deterministic (faster, atomic reductions). +// Planning only. + +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +DECL_FMHA_KERNEL_SET(bwd_deterministic_fmha_kernels, + // Forward: causal + LSE for backward + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("top_left") + .bias("no") + .lse(true) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950") + // Backward: deterministic (bit-identical across runs) + .add(FmhaSignature() + .family("bwd_dot_do_o") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("top_left") + .bias("no") + .dropout(false) + .dbias(false) + .store_randval(false) + .deterministic(true), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(32) + .tile_n1(0) + .tile_k1(0) + .tile_k0max(0) + .padding(true, true, true, true) + .selection_rank(0), + "gfx950") + .add(FmhaSignature() + .family("bwd_dq_dk_dv") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("top_left") + .bias("no") + .dropout(false) + .dbias(false) + .store_randval(false) + .deterministic(true), + FmhaAlgorithm() + .tile_m0(16) + .tile_n0(128) + .tile_k0(128) + .tile_n1(16) + .tile_k1(128) + .tile_k0max(32) + .wave(1, 4, 1, 4, 1, 1, 1, 4, 1) + .warp(16, 16, 32, 16, 16, 16, 16, 16, 16) + .padding(true, true, true, true) + .max_seq_len_q(0) + .selection_rank(0), + "gfx950") + .add(FmhaSignature() + .family("bwd_convert_dq") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("top_left") + .bias("no") + .dropout(false) + .dbias(false) + .store_randval(false) + .deterministic(true), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(0) + .tile_n1(0) + .tile_k1(0) + .tile_k0max(0) + .padding(true, true, true, true) + .selection_rank(0), + "gfx950") + // Backward: non-deterministic (faster, atomic reductions) + .add(FmhaSignature() + .family("bwd_dot_do_o") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("top_left") + .bias("no") + .dropout(false) + .dbias(false) + .store_randval(false) + .deterministic(false), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(32) + .tile_n1(0) + .tile_k1(0) + .tile_k0max(0) + .padding(true, true, true, true) + .selection_rank(1), + "gfx950") + .add(FmhaSignature() + .family("bwd_dq_dk_dv") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("top_left") + .bias("no") + .dropout(false) + .dbias(false) + .store_randval(false) + .deterministic(false), + FmhaAlgorithm() + .tile_m0(16) + .tile_n0(128) + .tile_k0(128) + .tile_n1(16) + .tile_k1(128) + .tile_k0max(32) + .wave(1, 4, 1, 4, 1, 1, 1, 4, 1) + .warp(16, 16, 32, 16, 16, 16, 16, 16, 16) + .padding(true, true, true, true) + .max_seq_len_q(0) + .selection_rank(1), + "gfx950") + .add(FmhaSignature() + .family("bwd_convert_dq") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("top_left") + .bias("no") + .dropout(false) + .dbias(false) + .store_randval(false) + .deterministic(false), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(0) + .tile_n1(0) + .tile_k1(0) + .tile_k0max(0) + .padding(true, true, true, true) + .selection_rank(1), + "gfx950")); + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 33: FMHA Backward Deterministic", + "Deterministic vs non-deterministic backward"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--batch", "2", "Batch size"); + args.add_option("--nhead", "4", "Number of heads"); + args.add_option("--seqlen", "128", "Sequence length"); + args.add_option("--hdim", "128", "Head dimension"); + if(!args.parse(argc, argv)) + return 0; + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int batch = args.get_int("--batch", 2); + const int nhead = args.get_int("--nhead", 4); + const int seqlen = args.get_int("--seqlen", 128); + const int hdim = args.get_int("--hdim", 128); + + print_header("Example 33: FMHA Backward Deterministic"); + + std::cout << "\nStep 1: Register Kernels\n"; + FmhaRegistry registry; + registry.set_name("bwd_deterministic_fmha"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + std::cout << "\nStep 2: Plan (deterministic)\n"; + FmhaDispatcher dispatcher(®istry); + fmha_bwd_traits det_traits{}; + det_traits.hdim_q = hdim; + det_traits.hdim_v = hdim; + det_traits.data_type = "fp16"; + det_traits.is_group_mode = false; + det_traits.mask_type = mask_enum::mask_top_left; + det_traits.bias_type = bias_enum::no_bias; + det_traits.has_dbias = false; + det_traits.has_dropout = false; + det_traits.is_store_randval = false; + det_traits.is_deterministic = true; + + fmha_bwd_args bwd_args{}; + bwd_args.batch = batch; + bwd_args.seqlen_q = seqlen; + bwd_args.seqlen_k = seqlen; + bwd_args.hdim_q = hdim; + bwd_args.hdim_v = hdim; + bwd_args.nhead_q = nhead; + bwd_args.nhead_k = nhead; + + auto det_plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(det_traits, bwd_args), gfx_arch)); + std::cout << " Deterministic plan valid: " << (det_plan.is_valid() ? "yes" : "no") << "\n"; + + std::cout << "\nStep 3: Plan (non-deterministic)\n"; + det_traits.is_deterministic = false; + auto nondet_plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(det_traits, bwd_args), gfx_arch)); + std::cout << " Non-deterministic plan valid: " << (nondet_plan.is_valid() ? "yes" : "no") + << "\n"; + + std::cout << "\nStep 4: Deterministic Mode\n"; + std::cout << " deterministic=true: bit-identical across runs (reproducible).\n"; + std::cout << " deterministic=false: faster, uses atomic reductions.\n"; + + print_separator(); + return 0; +} diff --git a/dispatcher/examples/fmha/cpp/34_bwd_gqa_fmha.cpp b/dispatcher/examples/fmha/cpp/34_bwd_gqa_fmha.cpp new file mode 100644 index 0000000000..d2b592e0a7 --- /dev/null +++ b/dispatcher/examples/fmha/cpp/34_bwd_gqa_fmha.cpp @@ -0,0 +1,183 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Example 34: FMHA Backward with GQA (Grouped Query Attention) +// +// Demonstrates backward with nhead_q=8, nhead_k=2 (4:1 ratio). GQA is a +// runtime property: each KV head is shared by multiple Q heads. Backward +// handles head indexing via nhead_stride_dk/dv so dK/dV are reduced across +// the Q-head group. Planning only. + +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +DECL_FMHA_KERNEL_SET(bwd_gqa_fmha_kernels, + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("top_left") + .bias("no") + .lse(true) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950") + .add(FmhaSignature() + .family("bwd_dot_do_o") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("top_left") + .bias("no") + .dropout(false) + .dbias(false) + .store_randval(false) + .deterministic(false), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(32) + .tile_n1(0) + .tile_k1(0) + .tile_k0max(0) + .padding(true, true, true, true) + .selection_rank(0), + "gfx950") + .add(FmhaSignature() + .family("bwd_dq_dk_dv") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("top_left") + .bias("no") + .dropout(false) + .dbias(false) + .store_randval(false) + .deterministic(false), + FmhaAlgorithm() + .tile_m0(16) + .tile_n0(128) + .tile_k0(128) + .tile_n1(16) + .tile_k1(128) + .tile_k0max(32) + .wave(1, 4, 1, 4, 1, 1, 1, 4, 1) + .warp(16, 16, 32, 16, 16, 16, 16, 16, 16) + .padding(true, true, true, true) + .max_seq_len_q(0) + .selection_rank(0), + "gfx950") + .add(FmhaSignature() + .family("bwd_convert_dq") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("top_left") + .bias("no") + .dropout(false) + .dbias(false) + .store_randval(false) + .deterministic(false), + FmhaAlgorithm() + .tile_m0(64) + .tile_n0(128) + .tile_k0(0) + .tile_n1(0) + .tile_k1(0) + .tile_k0max(0) + .padding(true, true, true, true) + .selection_rank(0), + "gfx950")); + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 34: FMHA Backward GQA", "nhead_q=8, nhead_k=2 (4:1 ratio)"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--batch", "2", "Batch size"); + args.add_option("--nhead_q", "8", "Query heads"); + args.add_option("--nhead_k", "2", "KV heads (GQA ratio = nhead_q/nhead_k)"); + args.add_option("--seqlen", "128", "Sequence length"); + args.add_option("--hdim", "128", "Head dimension"); + if(!args.parse(argc, argv)) + return 0; + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int batch = args.get_int("--batch", 2); + const int nhead_q = args.get_int("--nhead_q", 8); + const int nhead_k = args.get_int("--nhead_k", 2); + const int seqlen = args.get_int("--seqlen", 128); + const int hdim = args.get_int("--hdim", 128); + + print_header("Example 34: FMHA Backward GQA"); + + std::cout << "\nStep 1: Register Kernels\n"; + FmhaRegistry registry; + registry.set_name("bwd_gqa_fmha"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + std::cout << "\nStep 2: Plan (GQA nhead_q=" << nhead_q << ", nhead_k=" << nhead_k << ")\n"; + FmhaDispatcher dispatcher(®istry); + fmha_bwd_traits traits{}; + traits.hdim_q = hdim; + traits.hdim_v = hdim; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.mask_type = mask_enum::mask_top_left; + traits.bias_type = bias_enum::no_bias; + traits.has_dbias = false; + traits.has_dropout = false; + traits.is_store_randval = false; + traits.is_deterministic = false; + + fmha_bwd_args bwd_args{}; + bwd_args.batch = batch; + bwd_args.seqlen_q = seqlen; + bwd_args.seqlen_k = seqlen; + bwd_args.hdim_q = hdim; + bwd_args.hdim_v = hdim; + bwd_args.nhead_q = nhead_q; + bwd_args.nhead_k = nhead_k; + + auto plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(traits, bwd_args), gfx_arch)); + std::cout << " Plan valid: " << (plan.is_valid() ? "yes" : "no") << "\n"; + + std::cout << "\nStep 3: GQA Backward Head Indexing\n"; + std::cout << " Q heads " << nhead_q << ", KV heads " << nhead_k + << " -> each KV head shared by " << (nhead_q / nhead_k) << " Q heads.\n"; + std::cout << " dK/dV reduced across Q-head group via nhead_stride.\n"; + + print_separator(); + return 0; +} diff --git a/dispatcher/examples/fmha/cpp/35_generic_mask_fmha.cpp b/dispatcher/examples/fmha/cpp/35_generic_mask_fmha.cpp new file mode 100644 index 0000000000..696ee9e047 --- /dev/null +++ b/dispatcher/examples/fmha/cpp/35_generic_mask_fmha.cpp @@ -0,0 +1,121 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Example 35: FMHA Forward with Generic/Window Mask +// +// Demonstrates forward kernel with generic (window) mask. Uses +// window_size_left and window_size_right: for each query i, attend only to +// keys in [i - left, i + right]. -1 means unbounded. Planning only. + +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +DECL_FMHA_KERNEL_SET(generic_mask_fmha_kernels, + // Forward with generic/window mask + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("generic") // window mask via left/right + .bias("no") + .lse(false) + .dropout(false) + .qscale("no"), + FmhaAlgorithm() + .tile_m0(128) + .tile_n0(128) + .tile_k0(32) + .tile_n1(128) + .tile_k1(32) + .tile_k0max(128) + .wave_m0(4) + .wave_n0(1) + .wave_k0(1) + .wave_m1(4) + .wave_n1(1) + .wave_k1(1) + .warp_m0(32) + .warp_n0(32) + .warp_k0(16) + .warp_m1(32) + .warp_n1(32) + .warp_k1(16) + .pipeline("qr_async") + .padding(true, true, true, true) + .alignments(128, 128) + .selection_rank(0), + "gfx950")); + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 35: FMHA Generic Mask", "Window mask via left/right params"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--batch", "2", "Batch size"); + args.add_option("--nhead", "4", "Number of heads"); + args.add_option("--seqlen", "128", "Sequence length"); + args.add_option("--hdim", "128", "Head dimension"); + args.add_option("--window_left", "64", "Window size left (-1=unbounded)"); + args.add_option("--window_right", "0", "Window size right (-1=unbounded)"); + if(!args.parse(argc, argv)) + return 0; + + const std::string gfx_arch = args.get("--arch", "gfx950"); + const int batch = args.get_int("--batch", 2); + const int nhead = args.get_int("--nhead", 4); + const int seqlen = args.get_int("--seqlen", 128); + const int hdim = args.get_int("--hdim", 128); + const int window_size_left = args.get_int("--window_left", 64); + const int window_size_right = args.get_int("--window_right", 0); + + print_header("Example 35: FMHA Generic Mask"); + + std::cout << "\nStep 1: Register Kernels\n"; + FmhaRegistry registry; + registry.set_name("generic_mask_fmha"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + std::cout << "\nStep 2: Plan\n"; + FmhaDispatcher dispatcher(®istry); + fmha_fwd_traits traits{}; + traits.hdim_q = hdim; + traits.hdim_v = hdim; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.mask_type = mask_enum::window_generic; + traits.bias_type = bias_enum::no_bias; + traits.has_lse = false; + traits.has_dropout = false; + traits.qscale_type = quant_scale_enum::no_scale; + + fmha_fwd_args fwd_args{}; + fwd_args.batch = batch; + fwd_args.seqlen_q = seqlen; + fwd_args.seqlen_k = seqlen; + fwd_args.nhead_q = nhead; + fwd_args.nhead_k = nhead; + fwd_args.hdim_q = hdim; + fwd_args.hdim_v = hdim; + fwd_args.window_size_left = window_size_left; + fwd_args.window_size_right = window_size_right; + + auto plan = dispatcher.plan( + FmhaProblem::from_invocation(FmhaInvocation::make(traits, fwd_args), gfx_arch)); + std::cout << " Plan valid: " << (plan.is_valid() ? "yes" : "no") << "\n"; + + std::cout << "\nStep 3: Window Mask Params\n"; + std::cout << " window_size_left=" << window_size_left + << ", window_size_right=" << window_size_right << "\n"; + std::cout << " Query i attends to keys in [i-left, i+right]. -1 = unbounded.\n"; + + print_separator(); + return 0; +} diff --git a/dispatcher/examples/fmha/python/01_basic_fmha.py b/dispatcher/examples/fmha/python/01_basic_fmha.py new file mode 100644 index 0000000000..eba3bedaf8 --- /dev/null +++ b/dispatcher/examples/fmha/python/01_basic_fmha.py @@ -0,0 +1,259 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 01: Basic FMHA with Multiple Kernels + +Demonstrates: +1. Building a Registry with multiple kernel configurations +2. Parallel JIT compilation via registry.build() +3. Running each kernel and validating output against CPU reference +4. Comparing performance across kernels + +Usage: + python3 01_basic_fmha.py + python3 01_basic_fmha.py --dtype bf16 + python3 01_basic_fmha.py --size 256 + python3 01_basic_fmha.py --num-kernels 4 + python3 01_basic_fmha.py --workers 4 +""" + +import sys +import time +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaKernelSpec, + FmhaRegistry, + FmhaProblem, + cpu_attention_fwd, + detect_gpu_arch, + spec_to_config, +) + + +# FmhaKernelSpec fields: +# name -- human-readable kernel identifier +# hdim -- head dimension (hdim_q = hdim_v for symmetric attention) +# pipeline -- "qr_async" (async prefetch) or "qr" (synchronous) +# tile_m0 -- Stage 0 tile along seqlen_q (Q*K^T M dimension) +# tile_n0 -- Stage 0 tile along seqlen_k (Q*K^T N dimension) +# tile_k0 -- Stage 0 tile along hdim_q (Q*K^T K dimension) +# +# spec_to_config() fills in Stage 1 automatically: +# tile_n1 = hdim, tile_k1 = tile_k0, tile_k0max = hdim +# wave/warp use sensible defaults (4x1x1 wave, 32x32x16 warp) +KERNEL_SPECS = [ + # Async pipelines -- different tile_m0 x tile_n0 combos + FmhaKernelSpec( + name="async_128x128_k32", + hdim=128, + pipeline="qr_async", + tile_m0=128, + tile_n0=128, + tile_k0=32, + ), + FmhaKernelSpec( + name="async_128x64_k32", + hdim=128, + pipeline="qr_async", + tile_m0=128, + tile_n0=64, + tile_k0=32, + ), + FmhaKernelSpec( + name="async_64x128_k32", + hdim=128, + pipeline="qr_async", + tile_m0=64, + tile_n0=128, + tile_k0=32, + ), + FmhaKernelSpec( + name="async_64x64_k32", + hdim=128, + pipeline="qr_async", + tile_m0=64, + tile_n0=64, + tile_k0=32, + ), + # Synchronous pipelines + FmhaKernelSpec( + name="sync_128x128_k32", + hdim=128, + pipeline="qr", + tile_m0=128, + tile_n0=128, + tile_k0=32, + ), + FmhaKernelSpec( + name="sync_64x128_k32", + hdim=128, + pipeline="qr", + tile_m0=64, + tile_n0=128, + tile_k0=32, + ), + FmhaKernelSpec( + name="sync_128x64_k32", + hdim=128, + pipeline="qr", + tile_m0=128, + tile_n0=64, + tile_k0=32, + ), + # Different tile_k0 (K dimension of Q*K^T) + FmhaKernelSpec( + name="async_128x128_k64", + hdim=128, + pipeline="qr_async", + tile_m0=128, + tile_n0=128, + tile_k0=64, + ), + FmhaKernelSpec( + name="async_64x128_k64", + hdim=128, + pipeline="qr_async", + tile_m0=64, + tile_n0=128, + tile_k0=64, + ), +] + + +def main(): + parser = argparse.ArgumentParser(description="Basic FMHA with Multiple Kernels") + parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"]) + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--size", type=int, default=128, help="Sequence length") + parser.add_argument("--num-kernels", type=int, default=0, help="0 = all") + parser.add_argument( + "--workers", type=int, default=0, help="Max parallel JIT workers (0 = auto)" + ) + args = parser.parse_args() + + print("=" * 70) + print("Example 01: Basic FMHA with Multiple Kernels") + print("=" * 70) + + specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS + + # Step 1: Build registry + print( + f"\n {len(specs)} kernel configurations, dtype={args.dtype}, arch={args.arch}" + ) + print(f"\n {'#':<3} {'Name':<24} {'Tile':<14} {'Pipeline':<12}") + print(" " + "-" * 56) + for i, s in enumerate(specs, 1): + print( + f" {i:<3} {s.name:<24} {s.tile_m0}x{s.tile_n0}x{s.tile_k0:<6} {s.pipeline:<12}" + ) + + reg = FmhaRegistry(name="basic_fmha") + for s in specs: + reg.register_kernel(spec_to_config(s, args.dtype, args.arch)) + + # Step 2: Parallel JIT build via registry.build() + workers = args.workers if args.workers > 0 else None + print( + f"\n--- Parallel JIT Build ({len(specs)} kernels{f', workers={workers}' if workers else ''}) ---" + ) + + t0 = time.perf_counter() + setups = reg.build(verbose=False, max_workers=workers) + jit_build_s = time.perf_counter() - t0 + + built = sum(1 for s in setups if s.success) + print(f" Built: {built}/{len(specs)} kernels in {jit_build_s:.1f} s") + + if built == 0: + print(" ERROR: No kernels built") + return 1 + + # Step 3: Run each kernel and validate + seqlen = args.size + prob = FmhaProblem( + batch=2, + nhead_q=8, + nhead_k=8, + seqlen_q=seqlen, + seqlen_k=seqlen, + hdim_q=128, + hdim_v=128, + ) + + print( + f"\n--- Running Kernels (B={prob.batch} H={prob.nhead_q} S={seqlen} D={prob.hdim_q}) ---" + ) + + np.random.seed(42) + Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float16) + K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float16) + V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float16) + O_ref = cpu_attention_fwd( + Q.astype(np.float32), + K.astype(np.float32), + V.astype(np.float32), + prob.scale, + ) + + print( + f"\n {'#':<3} {'Name':<24} {'Pipeline':<12} {'Time(ms)':>10} {'TFLOPS':>10} {'MaxErr':>10} {'Status':<6}" + ) + print(" " + "-" * 80) + + results = [] + for i, (spec, setup) in enumerate(zip(specs, setups), 1): + if not setup.success or setup.runner is None: + print( + f" {i:<3} {spec.name:<24} {spec.pipeline:<12} {'---':>10} {'---':>10} {'---':>10} {'SKIP':<6}" + ) + results.append((spec.name, False, 0.0, 0.0, 0.0)) + continue + + res = setup.runner.run(Q, K, V, prob) + if not res.success: + print( + f" {i:<3} {spec.name:<24} {spec.pipeline:<12} {'---':>10} {'---':>10} {'---':>10} {'FAIL':<6}" + ) + results.append((spec.name, False, 0.0, 0.0, 0.0)) + continue + + max_err = float(np.abs(res.output.astype(np.float32) - O_ref).max()) + ok = max_err < 1e-2 + tag = "PASS" if ok else "FAIL" + print( + f" {i:<3} {spec.name:<24} {spec.pipeline:<12} {res.time_ms:>10.4f} {res.tflops:>10.2f} {max_err:>10.2e} {tag:<6}" + ) + results.append((spec.name, ok, res.time_ms, res.tflops, max_err)) + setup.runner.cleanup() + + # Step 4: Summary + passed = sum(1 for r in results if r[1]) + failed = len(results) - passed + valid = [r for r in results if r[1]] + + print("\n" + "=" * 70) + print(f" Results: {passed}/{len(results)} passed") + print( + f" Problem: B={prob.batch} H={prob.nhead_q} S={seqlen} D={prob.hdim_q}, dtype={args.dtype}" + ) + print(f" JIT time: {jit_build_s:.1f} s (parallel)") + if valid: + best = max(valid, key=lambda x: x[3]) + print(f" Best: {best[0]} ({best[3]:.2f} TFLOPS)") + print(f" Status: {'PASS' if failed == 0 else 'FAIL'}") + print("=" * 70) + + return 0 if failed == 0 else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/02_multi_shape.py b/dispatcher/examples/fmha/python/02_multi_shape.py new file mode 100644 index 0000000000..5b6a31959a --- /dev/null +++ b/dispatcher/examples/fmha/python/02_multi_shape.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 02: Multi-Shape FMHA + +Runs FMHA forward with a single kernel across multiple problem shapes +(varying batch, sequence length, and head count). + +Usage: + python3 02_multi_shape.py + python3 02_multi_shape.py --help + python3 02_multi_shape.py --dtype bf16 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaKernelSpec, + FmhaProblem, + detect_gpu_arch, + setup_fmha_dispatcher, + spec_to_config, +) + + +def main(): + parser = argparse.ArgumentParser( + description="Multi-Shape FMHA Example - runs multiple shapes", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 02_multi_shape.py # Default FP16 + python3 02_multi_shape.py --dtype bf16 # BF16 FMHA + """, + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16"], + help="Data type (default: fp16)", + ) + parser.add_argument( + "--arch", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo)", + ) + args = parser.parse_args() + + print("=" * 70) + print("Example 02: Multi-Shape FMHA") + print("=" * 70) + + # Step 1: Setup dispatcher + print("\nStep 1: Setup Dispatcher") + + # FmhaKernelSpec fields: + # name -- human-readable kernel identifier + # hdim -- head dimension (hdim_q = hdim_v) + # pipeline -- "qr_async" (async prefetch) or "qr" (synchronous) + # tile_m0 -- Stage 0 tile along seqlen_q (Q*K^T M dimension) + # tile_n0 -- Stage 0 tile along seqlen_k (Q*K^T N dimension) + # tile_k0 -- Stage 0 tile along hdim_q (Q*K^T K dimension) + spec = FmhaKernelSpec(name="multi_shape", hdim=128, pipeline="qr_async") + config = spec_to_config(spec, dtype=args.dtype, arch=args.arch) + + setup = setup_fmha_dispatcher(config, verbose=True) + if not setup.success: + print(f" ERROR: {setup.error}") + return 1 + + runner = setup.runner + print(f" Library: {setup.library_path}") + print(f" Build: {setup.build_time_s:.1f} s") + + # Step 2: Run batch of different shapes + print("\nStep 2: Run Shapes") + + shapes = [ + # (batch, nhead_q, nhead_k, seqlen_q, seqlen_k, hdim) + (1, 4, 4, 64, 64, 128), + (2, 8, 8, 128, 128, 128), + (4, 8, 8, 128, 128, 128), + (1, 16, 16, 256, 256, 128), + (2, 8, 8, 256, 256, 128), + (1, 8, 8, 512, 512, 128), + (2, 4, 4, 512, 512, 128), + (1, 8, 8, 1024, 1024, 128), + ] + + print(f"\n {'#':<3} {'Shape':<36} {'Time(ms)':>10} {'TFLOPS':>10} {'Status':>8}") + print(" " + "-" * 70) + + total_ops = 0 + total_time = 0.0 + + for idx, (b, hq, hk, sq, sk, d) in enumerate(shapes, 1): + prob = FmhaProblem( + batch=b, + nhead_q=hq, + nhead_k=hk, + seqlen_q=sq, + seqlen_k=sk, + hdim_q=d, + hdim_v=d, + ) + shape_str = f"B{b}_Hq{hq}_Hk{hk}_S{sq}x{sk}_D{d}" + + np.random.seed(42 + idx) + Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float16) + K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float16) + V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float16) + + result = runner.run(Q, K, V, prob) + + if result.success: + total_ops += prob.num_ops + total_time += result.time_ms + print( + f" {idx:<3} {shape_str:<36} {result.time_ms:>10.4f} {result.tflops:>10.2f} {'OK':>8}" + ) + else: + print(f" {idx:<3} {shape_str:<36} {'N/A':>10} {'N/A':>10} {'Error':>8}") + + print(" " + "-" * 70) + + if total_time > 0: + avg_tflops = (total_ops / 1e12) / (total_time / 1000) + print(f"\n Total: {total_time:.2f} ms, Average: {avg_tflops:.2f} TFLOPS") + + runner.cleanup() + + print("\n" + "=" * 70) + print("Multi-Shape FMHA complete!") + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/03_benchmark.py b/dispatcher/examples/fmha/python/03_benchmark.py new file mode 100644 index 0000000000..59fdc76f56 --- /dev/null +++ b/dispatcher/examples/fmha/python/03_benchmark.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 03: FMHA Benchmark + +Performance benchmarking with warmup and repeated iterations across +multiple (batch, sequence length) configurations. + +Usage: + python3 03_benchmark.py + python3 03_benchmark.py --help + python3 03_benchmark.py --warmup 5 --repeat 20 + python3 03_benchmark.py --arch gfx942 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaKernelSpec, + FmhaProblem, + detect_gpu_arch, + setup_fmha_dispatcher, + spec_to_config, +) + + +def main(): + parser = argparse.ArgumentParser( + description="FMHA Benchmark Example - performance testing", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 03_benchmark.py # Default benchmark suite + python3 03_benchmark.py --warmup 5 # More warmup iterations + python3 03_benchmark.py --repeat 20 # More benchmark iterations + """, + ) + parser.add_argument( + "--arch", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo)", + ) + parser.add_argument( + "--warmup", type=int, default=3, help="Warmup iterations (default: 3)" + ) + parser.add_argument( + "--repeat", type=int, default=10, help="Benchmark iterations (default: 10)" + ) + args = parser.parse_args() + + print("=" * 70) + print("Example 03: FMHA Benchmark") + print("=" * 70) + + # Step 1: Setup dispatcher with compute-optimized config + print("\nStep 1: Setup Dispatcher") + + # FmhaKernelSpec fields: + # name -- human-readable kernel identifier + # hdim -- head dimension (hdim_q = hdim_v) + # pipeline -- "qr_async" (async prefetch) or "qr" (synchronous) + # tile_m0 -- Stage 0 tile along seqlen_q (Q*K^T M dimension) + # tile_n0 -- Stage 0 tile along seqlen_k (Q*K^T N dimension) + # tile_k0 -- Stage 0 tile along hdim_q (Q*K^T K dimension) + spec = FmhaKernelSpec(name="benchmark", hdim=128, pipeline="qr_async") + config = spec_to_config(spec, dtype="fp16", arch=args.arch) + + setup = setup_fmha_dispatcher(config, verbose=True) + if not setup.success: + print(f" ERROR: {setup.error}") + return 1 + + runner = setup.runner + print(f" Library: {setup.library_path}") + print(f" Build: {setup.build_time_s:.1f} s") + + # Step 2: Benchmark + print("\nStep 2: Benchmark") + + bench_configs = [ + (1, 128), + (1, 256), + (1, 512), + (1, 1024), + (1, 2048), + (2, 128), + (2, 256), + (2, 512), + (2, 1024), + (4, 128), + (4, 256), + (4, 512), + (8, 128), + (8, 256), + ] + + print(f" Warmup: {args.warmup}, Repeat: {args.repeat}\n") + + print( + f" {'Batch':>5} {'SeqLen':>7} | {'Min(ms)':>10} {'Avg(ms)':>10} {'Max(ms)':>10} | {'TFLOPS':>10}" + ) + print(" " + "-" * 62) + + all_tflops = [] + + for batch, seqlen in bench_configs: + prob = FmhaProblem( + batch=batch, + nhead_q=8, + nhead_k=8, + seqlen_q=seqlen, + seqlen_k=seqlen, + hdim_q=128, + hdim_v=128, + ) + + np.random.seed(42) + Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float16) + K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float16) + V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float16) + + for _ in range(args.warmup): + runner.run(Q, K, V, prob) + + times = [] + for _ in range(args.repeat): + result = runner.run(Q, K, V, prob) + if result.success: + times.append(result.time_ms) + + if times: + min_time = min(times) + avg_time = sum(times) / len(times) + max_time = max(times) + tflops = prob.num_ops / (avg_time * 1e-3) / 1e12 + all_tflops.append(tflops) + print( + f" {batch:>5} {seqlen:>7} | {min_time:>10.4f} {avg_time:>10.4f} {max_time:>10.4f} | {tflops:>10.2f}" + ) + else: + print( + f" {batch:>5} {seqlen:>7} | {'---':>10} {'---':>10} {'---':>10} | {'FAIL':>10}" + ) + + runner.cleanup() + + # Summary + print("\n" + "=" * 70) + print("Summary") + print("=" * 70) + + if all_tflops: + print(f" Average: {sum(all_tflops) / len(all_tflops):.2f} TFLOPS") + print(f" Peak: {max(all_tflops):.2f} TFLOPS") + + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/04_validation.py b/dispatcher/examples/fmha/python/04_validation.py new file mode 100644 index 0000000000..aeb9665349 --- /dev/null +++ b/dispatcher/examples/fmha/python/04_validation.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 04: FMHA Validation + +Validates GPU FMHA against CPU reference across multiple test cases +including standard shapes, GQA ratios, and edge cases. + +Usage: + python3 04_validation.py + python3 04_validation.py --help + python3 04_validation.py --dtype bf16 + python3 04_validation.py --rtol 1e-2 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaKernelSpec, + FmhaProblem, + FmhaValidator, + cpu_attention_fwd, + detect_gpu_arch, + setup_fmha_dispatcher, + spec_to_config, +) + + +def main(): + parser = argparse.ArgumentParser( + description="FMHA Validation Example - validates GPU results against CPU", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 04_validation.py # Default FP16 validation + python3 04_validation.py --dtype bf16 # BF16 validation + python3 04_validation.py --rtol 1e-2 # Relaxed tolerance + """, + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16"], + help="Data type (default: fp16)", + ) + parser.add_argument( + "--rtol", type=float, default=1e-2, help="Relative tolerance (default: 1e-2)" + ) + parser.add_argument( + "--atol", type=float, default=1e-2, help="Absolute tolerance (default: 1e-2)" + ) + parser.add_argument( + "--arch", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo)", + ) + args = parser.parse_args() + + print("=" * 70) + print("Example 04: FMHA Validation") + print("=" * 70) + + # Step 1: Setup dispatcher + print("\nStep 1: Setup Dispatcher") + + # FmhaKernelSpec fields: + # name -- human-readable kernel identifier + # hdim -- head dimension (hdim_q = hdim_v) + # pipeline -- "qr_async" (async prefetch) or "qr" (synchronous) + # tile_m0 -- Stage 0 tile along seqlen_q (Q*K^T M dimension) + # tile_n0 -- Stage 0 tile along seqlen_k (Q*K^T N dimension) + # tile_k0 -- Stage 0 tile along hdim_q (Q*K^T K dimension) + spec = FmhaKernelSpec(name="validation", hdim=128, pipeline="qr_async") + config = spec_to_config(spec, dtype=args.dtype, arch=args.arch) + + setup = setup_fmha_dispatcher(config, verbose=True) + if not setup.success: + print(f" ERROR: {setup.error}") + return 1 + + runner = setup.runner + print(f" Library: {setup.library_path}") + print(f" Build: {setup.build_time_s:.1f} s") + + # Step 2: Run validation tests + print("\nStep 2: Validation Tests") + + validator = FmhaValidator(rtol=args.rtol, atol=args.atol) + + # (name, batch, nhead_q, nhead_k, seqlen_q, seqlen_k, hdim) + test_cases = [ + ("Small", 1, 4, 4, 64, 64, 128), + ("Medium", 2, 8, 8, 128, 128, 128), + ("Large", 1, 8, 8, 256, 256, 128), + ("Long-seq", 1, 4, 4, 512, 512, 128), + ("Non-square", 2, 4, 4, 64, 256, 128), + ("GQA-2:1", 2, 8, 4, 128, 128, 128), + ("GQA-4:1", 1, 16, 4, 128, 128, 128), + ("GQA-8:1", 1, 16, 2, 64, 64, 128), + ("Single-query", 1, 4, 4, 1, 128, 128), + ("Batched", 4, 8, 8, 128, 128, 128), + ] + + passed = 0 + failed = 0 + + print(f"\n {'#':<3} {'Test':<14} {'Shape':<30} {'MaxErr':>10} {'Status':>8}") + print(" " + "-" * 70) + + for idx, (name, b, hq, hk, sq, sk, d) in enumerate(test_cases, 1): + prob = FmhaProblem( + batch=b, + nhead_q=hq, + nhead_k=hk, + seqlen_q=sq, + seqlen_k=sk, + hdim_q=d, + hdim_v=d, + ) + shape_str = f"B{b}_Hq{hq}_Hk{hk}_S{sq}x{sk}" + + np.random.seed(42 + idx) + Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float16) + K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float16) + V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float16) + + result = runner.run(Q, K, V, prob) + if not result.success: + print( + f" {idx:<3} {name:<14} {shape_str:<30} {'GPU Err':>10} {'FAILED':>8}" + ) + failed += 1 + continue + + O_ref = cpu_attention_fwd( + Q.astype(np.float32), + K.astype(np.float32), + V.astype(np.float32), + prob.scale, + ) + is_valid, max_abs, _ = validator.check(result.output, O_ref) + + if is_valid: + print( + f" {idx:<3} {name:<14} {shape_str:<30} {max_abs:>10.2e} {'PASSED':>8}" + ) + passed += 1 + else: + print( + f" {idx:<3} {name:<14} {shape_str:<30} {max_abs:>10.2e} {'FAILED':>8}" + ) + failed += 1 + + runner.cleanup() + + # Summary + print("\n" + "=" * 70) + total = passed + failed + print(f" Results: {passed}/{total} passed") + print(f" Settings: dtype={args.dtype}, rtol={args.rtol}, atol={args.atol}") + print(f" Status: {'PASS' if failed == 0 else 'FAIL'}") + print("=" * 70) + + return 0 if failed == 0 else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/05_numpy_integration.py b/dispatcher/examples/fmha/python/05_numpy_integration.py new file mode 100644 index 0000000000..0303b2d5c7 --- /dev/null +++ b/dispatcher/examples/fmha/python/05_numpy_integration.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 05: NumPy Integration + +Shows how to create a GPU-accelerated attention wrapper that works +seamlessly with NumPy arrays, hiding all HIP memory management. + +Usage: + python3 05_numpy_integration.py + python3 05_numpy_integration.py --help + python3 05_numpy_integration.py --seqlen 256 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaKernelConfig, + FmhaProblem, + cpu_attention_fwd, + detect_gpu_arch, + setup_fmha_dispatcher, +) + + +def fmha_matmul( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + scale: float = None, + runner=None, +) -> np.ndarray: + """GPU-accelerated scaled dot-product attention via FMHA dispatcher. + + Args: + Q: [batch, nhead_q, seqlen_q, hdim_q] float16/float32 + K: [batch, nhead_k, seqlen_k, hdim_q] float16/float32 + V: [batch, nhead_k, seqlen_k, hdim_v] float16/float32 + scale: softmax scale (default: 1/sqrt(hdim_q)) + runner: reuse an existing runner from setup_fmha_dispatcher + + Returns: + O: [batch, nhead_q, seqlen_q, hdim_v] float16 + """ + batch, nhead_q, seqlen_q, hdim_q = Q.shape + _, nhead_k, seqlen_k, hdim_v = V.shape + + prob = FmhaProblem( + batch=batch, + nhead_q=nhead_q, + nhead_k=nhead_k, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + hdim_q=hdim_q, + hdim_v=hdim_v, + ) + + result = runner.run( + Q.astype(np.float16), K.astype(np.float16), V.astype(np.float16), prob + ) + if not result.success: + raise RuntimeError(f"GPU FMHA failed: {result.error}") + return result.output + + +def main(): + parser = argparse.ArgumentParser( + description="NumPy Integration Example - GPU-accelerated attention wrapper", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 05_numpy_integration.py # Default + python3 05_numpy_integration.py --seqlen 256 # Longer sequences + """, + ) + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--nhead", type=int, default=4) + parser.add_argument("--seqlen", type=int, default=64) + parser.add_argument("--hdim", type=int, default=128) + parser.add_argument("--rtol", type=float, default=1e-2) + parser.add_argument("--atol", type=float, default=1e-2) + args = parser.parse_args() + + print("=" * 70) + print("Example 05: NumPy Integration") + print("=" * 70) + + # Step 1: JIT-compile FMHA kernel + print("\nStep 1: JIT-Compile FMHA Dispatcher") + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=args.hdim, + hdim_v=args.hdim, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + if not setup.success: + print(f" JIT build failed: {setup.error}") + return 1 + runner = setup.runner + print(f" JIT build: {setup.build_time_s:.1f}s") + print(f" Arch: {args.arch}") + + np_dtype = np.float16 + + # Step 2: Demo -- simple attention call + print("\n" + "=" * 70) + print("Step 2: Simple Attention Call") + print("=" * 70) + + np.random.seed(42) + Q = (np.random.randn(args.batch, args.nhead, args.seqlen, args.hdim) * 0.5).astype( + np_dtype + ) + K = (np.random.randn(args.batch, args.nhead, args.seqlen, args.hdim) * 0.5).astype( + np_dtype + ) + V = (np.random.randn(args.batch, args.nhead, args.seqlen, args.hdim) * 0.5).astype( + np_dtype + ) + + out = fmha_matmul(Q, K, V, runner=runner) + print(f" Q: {Q.shape} -> O: {out.shape}") + print(f" Output range: [{out.min():.4f}, {out.max():.4f}]") + print(f" Output sum: {out.sum():.4f}") + + # Step 3: Validate against CPU reference + print("\n" + "=" * 70) + print("Step 3: Validate Against CPU Reference") + print("=" * 70) + + prob = FmhaProblem( + batch=args.batch, + nhead_q=args.nhead, + nhead_k=args.nhead, + seqlen_q=args.seqlen, + seqlen_k=args.seqlen, + hdim_q=args.hdim, + hdim_v=args.hdim, + ) + O_ref = cpu_attention_fwd( + Q.astype(np.float32), + K.astype(np.float32), + V.astype(np.float32), + prob.scale, + ) + + diff = np.abs(out.astype(np.float32) - O_ref) + max_abs = float(diff.max()) + max_rel = float((diff / (np.abs(O_ref) + 1e-6)).max()) + match = np.allclose(out.astype(np.float32), O_ref, atol=args.atol, rtol=args.rtol) + + print(f" Max abs error: {max_abs:.6e}") + print(f" Max rel error: {max_rel:.6e}") + print(f" Match: {match}") + + # Step 4: Demo -- multi-head attention with GQA + print("\n" + "=" * 70) + print("Step 4: GQA Attention (nhead_q=8, nhead_k=2)") + print("=" * 70) + + nhead_q, nhead_k = 8, 2 + Q_gqa = (np.random.randn(args.batch, nhead_q, args.seqlen, args.hdim) * 0.5).astype( + np_dtype + ) + K_gqa = (np.random.randn(args.batch, nhead_k, args.seqlen, args.hdim) * 0.5).astype( + np_dtype + ) + V_gqa = (np.random.randn(args.batch, nhead_k, args.seqlen, args.hdim) * 0.5).astype( + np_dtype + ) + + O_gqa = fmha_matmul(Q_gqa, K_gqa, V_gqa, runner=runner) + + prob_gqa = FmhaProblem( + batch=args.batch, + nhead_q=nhead_q, + nhead_k=nhead_k, + seqlen_q=args.seqlen, + seqlen_k=args.seqlen, + hdim_q=args.hdim, + hdim_v=args.hdim, + ) + O_gqa_ref = cpu_attention_fwd( + Q_gqa.astype(np.float32), + K_gqa.astype(np.float32), + V_gqa.astype(np.float32), + prob_gqa.scale, + ) + gqa_match = np.allclose( + O_gqa.astype(np.float32), O_gqa_ref, atol=args.atol, rtol=args.rtol + ) + + print(f" Q: {Q_gqa.shape}, K: {K_gqa.shape}, V: {V_gqa.shape}") + print(f" O: {O_gqa.shape}") + print(f" Match: {gqa_match}") + + # Summary + print("\n" + "=" * 70) + print("NumPy Integration Pattern:") + print("=" * 70) + print(" 1. setup = setup_fmha_dispatcher(config)") + print(" 2. O = fmha_matmul(Q, K, V, runner=setup.runner)") + print("=" * 70) + + return 0 if match and gqa_match else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/06_json_export.py b/dispatcher/examples/fmha/python/06_json_export.py new file mode 100644 index 0000000000..b90b43cdbc --- /dev/null +++ b/dispatcher/examples/fmha/python/06_json_export.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 06: JSON Export + +Builds an FMHA kernel via setup_fmha_dispatcher, then exports the +registry configuration to JSON for inspection or reuse. + +Usage: + python3 06_json_export.py + python3 06_json_export.py --help + python3 06_json_export.py --output fmha_kernels.json +""" + +import sys +import json +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from fmha_utils import ( + FmhaKernelConfig, + setup_fmha_dispatcher, + detect_gpu_arch, +) + + +def main(): + parser = argparse.ArgumentParser( + description="JSON Export Example - export FMHA registry to JSON", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 06_json_export.py # Default output + python3 06_json_export.py --output fmha_kernels.json # Custom file + """, + ) + parser.add_argument( + "--output", + "-o", + default="fmha_kernels.json", + help="Output JSON file (default: fmha_kernels.json)", + ) + parser.add_argument("--arch", default=detect_gpu_arch()) + args = parser.parse_args() + + print("=" * 70) + print("Example 06: JSON Export") + print("=" * 70) + + # Step 1: Define FMHA kernel configurations + print("\nStep 1: Define Kernel Configurations") + + configs = [ + FmhaKernelConfig( + family="fwd", + data_type="fp16", + hdim_q=128, + hdim_v=128, + pipeline="qr_async", + # Stage 0 (Q*K^T): seqlen_q x seqlen_k x hdim_q + tile_m0=128, + tile_n0=128, + tile_k0=32, + # Stage 1 (Attn*V): hdim_v x seqlen_k x alignment + tile_n1=128, + tile_k1=32, + tile_k0max=128, + # Wave config per stage + wave_m0=4, + wave_n0=1, + wave_k0=1, + wave_m1=4, + wave_n1=1, + wave_k1=1, + # Warp tile per stage + warp_m0=32, + warp_n0=32, + warp_k0=16, + warp_m1=32, + warp_n1=32, + warp_k1=16, + gfx_arch=args.arch, + ), + FmhaKernelConfig( + family="fwd", + data_type="fp16", + hdim_q=128, + hdim_v=128, + pipeline="qr", + tile_m0=64, + tile_n0=128, + tile_k0=32, + tile_n1=128, + tile_k1=32, + tile_k0max=128, + wave_m0=4, + wave_n0=1, + wave_k0=1, + wave_m1=4, + wave_n1=1, + wave_k1=1, + warp_m0=16, + warp_n0=16, + warp_k0=32, + warp_m1=16, + warp_n1=16, + warp_k1=16, + pad_s=False, + pad_sk=False, + pad_d=True, + pad_dv=True, + gfx_arch=args.arch, + ), + FmhaKernelConfig( + family="fwd", + data_type="fp16", + hdim_q=64, + hdim_v=64, + pipeline="qr_async", + tile_m0=128, + tile_n0=64, + tile_k0=32, + tile_n1=64, + tile_k1=32, + tile_k0max=64, + wave_m0=4, + wave_n0=1, + wave_k0=1, + wave_m1=4, + wave_n1=1, + wave_k1=1, + warp_m0=32, + warp_n0=32, + warp_k0=16, + warp_m1=32, + warp_n1=32, + warp_k1=16, + gfx_arch=args.arch, + ), + ] + + for i, cfg in enumerate(configs, 1): + print(f" [{i}] {cfg.name}: pipeline={cfg.pipeline}, hdim={cfg.hdim_q}") + + # Step 2: Build via setup_fmha_dispatcher + print("\n" + "=" * 70) + print("Step 2: Build Kernel (JIT)") + print("=" * 70) + + setup = setup_fmha_dispatcher(configs[0], verbose=True) + if setup.success: + print(f" Built: {setup.library_path}") + print(f" Time: {setup.build_time_s:.1f} s") + else: + print(f" Build skipped/failed: {setup.error}") + print(" (Proceeding with config export only)") + + # Step 3: Export to JSON + print("\n" + "=" * 70) + print("Step 3: Export to JSON") + print("=" * 70) + + export_data = { + "registry": "fmha_export", + "arch": args.arch, + "kernel_count": len(configs), + "kernels": [], + } + + for cfg in configs: + kernel_info = { + "name": cfg.name, + "family": cfg.family, + "data_type": cfg.data_type, + "hdim_q": cfg.hdim_q, + "hdim_v": cfg.hdim_v, + "pipeline": cfg.pipeline, + "tile": list(cfg.tile), + "wave": list(cfg.wave), + "warp": list(cfg.warp), + "padding": list(cfg.padding), + "mode": cfg.mode, + "target": cfg.gfx_arch, + "codegen_json": json.loads(cfg.to_codegen_json()), + } + export_data["kernels"].append(kernel_info) + + json_str = json.dumps(export_data, indent=2) + + with open(args.output, "w") as f: + f.write(json_str) + print(f" Saved to: {args.output}") + + file_size = Path(args.output).stat().st_size + print(f" File size: {file_size:,} bytes") + print(f" Kernel count: {len(configs)}") + + # Step 4: Preview + print("\n" + "=" * 70) + print("Step 4: JSON Preview") + print("=" * 70) + preview = json_str[:500] + if len(json_str) > 500: + preview += "\n ..." + print(preview) + + print("\n" + "=" * 70) + print("JSON Export complete!") + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/07_stress_test.py b/dispatcher/examples/fmha/python/07_stress_test.py new file mode 100644 index 0000000000..092c2b7e73 --- /dev/null +++ b/dispatcher/examples/fmha/python/07_stress_test.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 07: Stress Test - Multiple FMHA Kernels with Validation + +Generates many FmhaKernelSpec configurations across pipelines, head +dimensions, and data types, registers them in an FmhaRegistry, builds +all in parallel, and validates each against a CPU reference. + +Usage: + python3 07_stress_test.py + python3 07_stress_test.py --help + python3 07_stress_test.py --num-kernels 4 + python3 07_stress_test.py --workers 8 +""" + +import sys +import time +import argparse +from pathlib import Path +from typing import List + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaKernelSpec, + FmhaProblem, + FmhaRegistry, + FmhaValidator, + cpu_attention_fwd, + spec_to_config, + detect_gpu_arch, +) + + +# FmhaKernelSpec fields: +# name -- human-readable kernel identifier +# hdim -- head dimension (hdim_q = hdim_v) +# pipeline -- "qr_async" (async prefetch) or "qr" (synchronous) +# tile_m0 -- Stage 0 tile along seqlen_q (Q*K^T M dimension) +# tile_n0 -- Stage 0 tile along seqlen_k (Q*K^T N dimension) +# tile_k0 -- Stage 0 tile along hdim_q (Q*K^T K dimension) +KERNEL_SPECS: List[FmhaKernelSpec] = [ + # qr_async pipeline -- various tile sizes + FmhaKernelSpec( + name="qr_async_h128_t128", + hdim=128, + pipeline="qr_async", + tile_m0=128, + tile_n0=128, + tile_k0=32, + ), + FmhaKernelSpec( + name="qr_async_h128_t64", + hdim=128, + pipeline="qr_async", + tile_m0=64, + tile_n0=128, + tile_k0=32, + ), + FmhaKernelSpec( + name="qr_async_h64_t128", + hdim=64, + pipeline="qr_async", + tile_m0=128, + tile_n0=64, + tile_k0=32, + ), + FmhaKernelSpec( + name="qr_async_h64_t64", + hdim=64, + pipeline="qr_async", + tile_m0=64, + tile_n0=64, + tile_k0=32, + ), + # qr pipeline -- various tile sizes + FmhaKernelSpec( + name="qr_h128_t128", + hdim=128, + pipeline="qr", + tile_m0=128, + tile_n0=128, + tile_k0=32, + ), + FmhaKernelSpec( + name="qr_h128_t64", hdim=128, pipeline="qr", tile_m0=64, tile_n0=128, tile_k0=32 + ), + FmhaKernelSpec( + name="qr_h64_t128", hdim=64, pipeline="qr", tile_m0=128, tile_n0=64, tile_k0=32 + ), + FmhaKernelSpec( + name="qr_h64_t64", hdim=64, pipeline="qr", tile_m0=64, tile_n0=64, tile_k0=32 + ), +] + + +def print_spec_table(specs: List[FmhaKernelSpec]): + print( + f"\n {'#':<3} {'Name':<25} {'Pipeline':<12} {'Hdim':>5} " + f"{'TileM':>6} {'TileN':>6} {'TileK':>6}" + ) + print(" " + "-" * 70) + for i, s in enumerate(specs, 1): + print( + f" {i:<3} {s.name:<25} {s.pipeline:<12} {s.hdim:>5} " + f"{s.tile_m0:>6} {s.tile_n0:>6} {s.tile_k0:>6}" + ) + print(" " + "-" * 70) + + +def main(): + parser = argparse.ArgumentParser( + description="FMHA Stress Test - multiple kernels with validation", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 07_stress_test.py # Test all kernels + python3 07_stress_test.py --num-kernels 4 # First 4 only + python3 07_stress_test.py --workers 8 # 8 parallel compile workers + """, + ) + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument( + "--num-kernels", type=int, default=0, help="Number of kernels to test (0 = all)" + ) + parser.add_argument( + "--workers", type=int, default=0, help="Max parallel build workers (0 = auto)" + ) + parser.add_argument("--rtol", type=float, default=1e-2) + parser.add_argument("--atol", type=float, default=1e-2) + args = parser.parse_args() + + print("=" * 70) + print("Example 07: FMHA Stress Test - Multiple Kernels") + print("=" * 70) + + specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS + + print(f"\n Arch: {args.arch}") + print(f" Kernels: {len(specs)}") + print_spec_table(specs) + + # Step 1: Register all in FmhaRegistry and build + print("\n" + "=" * 70) + print(" JIT BUILD") + print("=" * 70) + + reg = FmhaRegistry("stress_test") + for spec in specs: + cfg = spec_to_config(spec, dtype="fp16", arch=args.arch) + reg.register_kernel(cfg) + + workers = args.workers if args.workers > 0 else None + print(f"\n Building {len(reg)} kernels (workers={workers or 'auto'}) ...") + + t0 = time.perf_counter() + build_results = reg.build(verbose=False, max_workers=workers) + build_time = time.perf_counter() - t0 + + built = sum(1 for r in build_results if r.success) + print(f" Built: {built}/{len(specs)} in {build_time:.1f} s") + + for i, r in enumerate(build_results, 1): + tag = "OK" if r.success else f"FAIL: {r.error[:50]}" + name = r.config.name if r.config else f"kernel_{i}" + print(f" [{i}] {name}: {tag}") + + if built == 0: + print("\n No kernels built -- aborting") + return 1 + + # Step 2: Validate each built kernel + print("\n" + "=" * 70) + print(" VALIDATION") + print("=" * 70) + + prob = FmhaProblem( + batch=2, nhead_q=4, nhead_k=4, seqlen_q=64, seqlen_k=64, hdim_q=128, hdim_v=128 + ) + + np.random.seed(42) + Q = (np.random.randn(*prob.q_shape()) * 0.5).astype(np.float16) + K = (np.random.randn(*prob.k_shape()) * 0.5).astype(np.float16) + V = (np.random.randn(*prob.v_shape()) * 0.5).astype(np.float16) + O_ref = cpu_attention_fwd( + Q.astype(np.float32), K.astype(np.float32), V.astype(np.float32), prob.scale + ) + + validator = FmhaValidator(rtol=args.rtol, atol=args.atol) + + print( + f"\n Problem: B={prob.batch} Hq={prob.nhead_q} Sq={prob.seqlen_q} D={prob.hdim_q}" + ) + print(f"\n {'#':<3} {'Name':<35} {'Time':>8} {'MaxErr':>10} {'Status':<6}") + print(" " + "-" * 66) + + total_pass = 0 + total_fail = 0 + + for i, r in enumerate(build_results, 1): + name = r.config.name if r.config else f"kernel_{i}" + + if not r.success or r.runner is None: + print(f" {i:<3} {name:<35} {'---':>8} {'---':>10} {'SKIP':<6}") + continue + + hdim = r.config.hdim_q if r.config else 128 + if hdim != prob.hdim_q: + print(f" {i:<3} {name:<35} {'---':>8} {'---':>10} {'SKIP':<6}") + continue + + res = r.runner.run(Q, K, V, prob) + if not res.success: + print(f" {i:<3} {name:<35} {'---':>8} {'---':>10} {'FAIL':<6}") + total_fail += 1 + continue + + ok, max_abs, _ = validator.check(res.output, O_ref) + tag = "PASS" if ok else "FAIL" + print(f" {i:<3} {name:<35} {res.time_ms:>7.4f}ms {max_abs:>10.2e} {tag:<6}") + + if ok: + total_pass += 1 + else: + total_fail += 1 + + r.runner.cleanup() + + # Summary + print("\n" + "=" * 70) + print(" SUMMARY") + print("=" * 70) + print(f"\n Total: {len(specs)}") + print(f" Built: {built}") + print(f" Passed: {total_pass}") + print(f" Failed: {total_fail}") + print(f" Build time: {build_time:.1f} s") + print(f" Tolerance: rtol={args.rtol}, atol={args.atol}") + + if total_fail == 0 and total_pass > 0: + print("\n *** ALL VALIDATED KERNELS PASSED ***") + elif total_fail > 0: + print(f"\n *** {total_fail} KERNELS FAILED ***") + + print("=" * 70) + + return 0 if total_fail == 0 else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/08_heuristics.py b/dispatcher/examples/fmha/python/08_heuristics.py new file mode 100644 index 0000000000..9d01347856 --- /dev/null +++ b/dispatcher/examples/fmha/python/08_heuristics.py @@ -0,0 +1,348 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 08: Kernel Selection Heuristics + +Demonstrates how to build multiple FMHA kernels with different tile +sizes and select the best kernel for a given problem. Shows that +smaller tiles tend to be better for short sequences while larger tiles +are better for long sequences. + +Usage: + python3 08_heuristics.py + python3 08_heuristics.py --help + python3 08_heuristics.py --arch gfx950 +""" + +import sys +import argparse +from pathlib import Path +from dataclasses import dataclass +from typing import List + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaKernelConfig, + FmhaProblem, + FmhaRegistry, + detect_gpu_arch, +) + + +@dataclass +class TileProfile: + """A kernel profile tagged with a human-readable label.""" + + label: str + config: FmhaKernelConfig + category: str # "small", "medium", "large" + + +def build_tile_profiles(arch: str) -> List[TileProfile]: + """Create kernel configs with varying tile sizes.""" + return [ + TileProfile( + label="small_64x64", + category="small", + config=FmhaKernelConfig( + family="fwd", + data_type="fp16", + hdim_q=128, + hdim_v=128, + pipeline="qr_async", + # Stage 0 (Q*K^T): seqlen_q x seqlen_k x hdim_q + tile_m0=64, + tile_n0=64, + tile_k0=32, + # Stage 1 (Attn*V): hdim_v x seqlen_k x alignment + tile_n1=128, + tile_k1=32, + tile_k0max=128, + # Wave config per stage + wave_m0=4, + wave_n0=1, + wave_k0=1, + wave_m1=4, + wave_n1=1, + wave_k1=1, + # Warp tile per stage + warp_m0=16, + warp_n0=16, + warp_k0=16, + warp_m1=16, + warp_n1=16, + warp_k1=16, + gfx_arch=arch, + ), + ), + TileProfile( + label="medium_128x128", + category="medium", + config=FmhaKernelConfig( + family="fwd", + data_type="fp16", + hdim_q=128, + hdim_v=128, + pipeline="qr_async", + tile_m0=128, + tile_n0=128, + tile_k0=32, + tile_n1=128, + tile_k1=32, + tile_k0max=128, + wave_m0=4, + wave_n0=1, + wave_k0=1, + wave_m1=4, + wave_n1=1, + wave_k1=1, + warp_m0=32, + warp_n0=32, + warp_k0=16, + warp_m1=32, + warp_n1=32, + warp_k1=16, + gfx_arch=arch, + ), + ), + TileProfile( + label="large_128x256", + category="large", + config=FmhaKernelConfig( + family="fwd", + data_type="fp16", + hdim_q=128, + hdim_v=128, + pipeline="qr_async", + tile_m0=128, + tile_n0=256, + tile_k0=32, + tile_n1=128, + tile_k1=32, + tile_k0max=128, + wave_m0=4, + wave_n0=1, + wave_k0=1, + wave_m1=4, + wave_n1=1, + wave_k1=1, + warp_m0=32, + warp_n0=32, + warp_k0=16, + warp_m1=32, + warp_n1=32, + warp_k1=16, + gfx_arch=arch, + ), + ), + TileProfile( + label="medium_qr_128x128", + category="medium", + config=FmhaKernelConfig( + family="fwd", + data_type="fp16", + hdim_q=128, + hdim_v=128, + pipeline="qr", + tile_m0=128, + tile_n0=128, + tile_k0=32, + tile_n1=128, + tile_k1=32, + tile_k0max=128, + wave_m0=4, + wave_n0=1, + wave_k0=1, + wave_m1=4, + wave_n1=1, + wave_k1=1, + warp_m0=32, + warp_n0=32, + warp_k0=16, + warp_m1=32, + warp_n1=32, + warp_k1=16, + pad_s=False, + pad_sk=False, + pad_d=True, + pad_dv=True, + gfx_arch=arch, + ), + ), + ] + + +def select_kernel_heuristic(seqlen: int, profiles: List[TileProfile]) -> TileProfile: + """Simple heuristic: pick tile size category based on sequence length.""" + if seqlen <= 64: + target = "small" + elif seqlen <= 256: + target = "medium" + else: + target = "large" + + candidates = [p for p in profiles if p.category == target] + if not candidates: + candidates = profiles + return candidates[0] + + +def main(): + parser = argparse.ArgumentParser( + description="FMHA Heuristics - kernel selection by problem size", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 08_heuristics.py + python3 08_heuristics.py --arch gfx950 + """, + ) + parser.add_argument("--arch", default=detect_gpu_arch()) + args = parser.parse_args() + + print("=" * 70) + print("Example 08: Kernel Selection Heuristics") + print("=" * 70) + + # Step 1: Build kernel pool + print("\nStep 1: Build Kernel Pool") + profiles = build_tile_profiles(args.arch) + + reg = FmhaRegistry("heuristic_pool") + for p in profiles: + reg.register_kernel(p.config) + + print(f" Profiles: {len(profiles)}") + for i, p in enumerate(profiles, 1): + tile_str = f"{p.config.tile[0]}x{p.config.tile[1]}" + print( + f" [{i}] {p.label:<25} tile={tile_str:<10} pipeline={p.config.pipeline}" + ) + + print("\n Building kernels ...") + build_results = reg.build(verbose=False) + built = sum(1 for r in build_results if r.success) + print(f" Built: {built}/{len(profiles)}") + + for i, r in enumerate(build_results): + tag = "OK" if r.success else f"FAIL: {r.error[:40]}" + print(f" [{i + 1}] {profiles[i].label}: {tag}") + + if built == 0: + print(" No kernels built -- aborting") + return 1 + + # Step 2: Run each kernel on multiple sequence lengths + print("\n" + "=" * 70) + print("Step 2: Benchmark Across Sequence Lengths") + print("=" * 70) + + test_seqlens = [32, 64, 128, 256, 512] + + header = f" {'SeqLen':>7}" + for p in profiles: + header += f" | {p.label:>18}" + header += " | {'Best':>18}" + print(f"\n {'SeqLen':>7}", end="") + for p in profiles: + print(f" | {p.label:>18}", end="") + print(f" | {'Best':>18}") + print(" " + "-" * (10 + 21 * len(profiles) + 22)) + + for seqlen in test_seqlens: + prob = FmhaProblem( + batch=2, + nhead_q=8, + nhead_k=8, + seqlen_q=seqlen, + seqlen_k=seqlen, + hdim_q=128, + hdim_v=128, + ) + + np.random.seed(42) + Q = (np.random.randn(*prob.q_shape()) * 0.5).astype(np.float16) + K = (np.random.randn(*prob.k_shape()) * 0.5).astype(np.float16) + V = (np.random.randn(*prob.v_shape()) * 0.5).astype(np.float16) + + row = f" {seqlen:>7}" + best_tflops = 0.0 + best_label = "---" + + for j, (p, r) in enumerate(zip(profiles, build_results)): + if not r.success or r.runner is None: + row += f" | {'N/A':>18}" + continue + + res = r.runner.run(Q, K, V, prob) + if res.success: + cell = f"{res.tflops:.2f} TFLOPS" + row += f" | {cell:>18}" + if res.tflops > best_tflops: + best_tflops = res.tflops + best_label = p.label + else: + row += f" | {'ERR':>18}" + + row += f" | {best_label:>18}" + print(row) + + # Step 3: Demonstrate heuristic selection + print("\n" + "=" * 70) + print("Step 3: Heuristic Selection Demo") + print("=" * 70) + + print(f"\n {'SeqLen':>7} {'Selected':>25} {'TFLOPS':>10} {'Status':<6}") + print(" " + "-" * 55) + + for seqlen in test_seqlens: + selected = select_kernel_heuristic(seqlen, profiles) + idx = profiles.index(selected) + r = build_results[idx] + + if not r.success or r.runner is None: + print(f" {seqlen:>7} {selected.label:>25} {'---':>10} {'SKIP':<6}") + continue + + prob = FmhaProblem( + batch=2, + nhead_q=8, + nhead_k=8, + seqlen_q=seqlen, + seqlen_k=seqlen, + hdim_q=128, + hdim_v=128, + ) + np.random.seed(42) + Q = (np.random.randn(*prob.q_shape()) * 0.5).astype(np.float16) + K = (np.random.randn(*prob.k_shape()) * 0.5).astype(np.float16) + V = (np.random.randn(*prob.v_shape()) * 0.5).astype(np.float16) + + res = r.runner.run(Q, K, V, prob) + if res.success: + print(f" {seqlen:>7} {selected.label:>25} {res.tflops:>10.2f} {'OK':<6}") + else: + print(f" {seqlen:>7} {selected.label:>25} {'---':>10} {'FAIL':<6}") + + # Cleanup + for r in build_results: + if r.runner: + r.runner.cleanup() + + print("\n" + "=" * 70) + print("Heuristic Insight:") + print(" - Small tiles: low overhead for short sequences") + print(" - Large tiles: high throughput for long sequences") + print(" - Pipeline choice also matters (qr vs qr_async)") + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/09_multi_registry.py b/dispatcher/examples/fmha/python/09_multi_registry.py new file mode 100644 index 0000000000..33ec92ab50 --- /dev/null +++ b/dispatcher/examples/fmha/python/09_multi_registry.py @@ -0,0 +1,298 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 09: Multiple Registries + +Creates separate FmhaRegistry instances for different optimization +targets (latency vs throughput), builds both, runs the same problem +through each, and compares results. + +Usage: + python3 09_multi_registry.py + python3 09_multi_registry.py --help + python3 09_multi_registry.py --arch gfx950 +""" + +import sys +import time +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaKernelConfig, + FmhaProblem, + FmhaRegistry, + FmhaValidator, + cpu_attention_fwd, + detect_gpu_arch, +) + + +def make_latency_config(arch: str) -> FmhaKernelConfig: + """Latency-optimized: smaller tiles, lower launch overhead.""" + return FmhaKernelConfig( + family="fwd", + data_type="fp16", + hdim_q=128, + hdim_v=128, + pipeline="qr", + # Stage 0 (Q*K^T): seqlen_q x seqlen_k x hdim_q + tile_m0=64, + tile_n0=128, + tile_k0=32, + # Stage 1 (Attn*V): hdim_v x seqlen_k x alignment + tile_n1=128, + tile_k1=32, + tile_k0max=128, + # Wave config per stage + wave_m0=4, + wave_n0=1, + wave_k0=1, + wave_m1=4, + wave_n1=1, + wave_k1=1, + # Warp tile per stage + warp_m0=16, + warp_n0=16, + warp_k0=32, + warp_m1=16, + warp_n1=16, + warp_k1=16, + pad_s=False, + pad_sk=False, + pad_d=True, + pad_dv=True, + gfx_arch=arch, + ) + + +def make_throughput_config(arch: str) -> FmhaKernelConfig: + """Throughput-optimized: larger tiles, async pipeline.""" + return FmhaKernelConfig( + family="fwd", + data_type="fp16", + hdim_q=128, + hdim_v=128, + pipeline="qr_async", + tile_m0=128, + tile_n0=128, + tile_k0=32, + tile_n1=128, + tile_k1=32, + tile_k0max=128, + wave_m0=4, + wave_n0=1, + wave_k0=1, + wave_m1=4, + wave_n1=1, + wave_k1=1, + warp_m0=32, + warp_n0=32, + warp_k0=16, + warp_m1=32, + warp_n1=32, + warp_k1=16, + gfx_arch=arch, + ) + + +def main(): + parser = argparse.ArgumentParser( + description="Multiple FMHA Registries - latency vs throughput", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 09_multi_registry.py + python3 09_multi_registry.py --arch gfx950 + """, + ) + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--rtol", type=float, default=1e-2) + parser.add_argument("--atol", type=float, default=1e-2) + args = parser.parse_args() + + print("=" * 70) + print("Example 09: Multiple Registries") + print("=" * 70) + + # Step 1: Define optimization-specific configs + print("\nStep 1: Define Optimization Targets") + + latency_cfg = make_latency_config(args.arch) + throughput_cfg = make_throughput_config(args.arch) + + print(f" Latency config: {latency_cfg.name}") + print(f" pipeline={latency_cfg.pipeline}, tile={latency_cfg.tile[:2]}") + print(f" Throughput config: {throughput_cfg.name}") + print(f" pipeline={throughput_cfg.pipeline}, tile={throughput_cfg.tile[:2]}") + + # Step 2: Create separate registries + print("\n" + "=" * 70) + print("Step 2: Create and Build Registries") + print("=" * 70) + + latency_reg = FmhaRegistry("latency") + latency_reg.register_kernel(latency_cfg) + + throughput_reg = FmhaRegistry("throughput") + throughput_reg.register_kernel(throughput_cfg) + + print(f"\n Building 'latency' registry ({len(latency_reg)} kernel) ...") + t0 = time.perf_counter() + latency_results = latency_reg.build(verbose=False) + lat_build_time = time.perf_counter() - t0 + + print(f" Building 'throughput' registry ({len(throughput_reg)} kernel) ...") + t0 = time.perf_counter() + throughput_results = throughput_reg.build(verbose=False) + thr_build_time = time.perf_counter() - t0 + + lat_ok = latency_results and latency_results[0].success + thr_ok = throughput_results and throughput_results[0].success + + print(f"\n Latency: {'OK' if lat_ok else 'FAIL'} ({lat_build_time:.1f} s)") + print(f" Throughput: {'OK' if thr_ok else 'FAIL'} ({thr_build_time:.1f} s)") + + if not lat_ok and not thr_ok: + print(" No kernels built -- aborting") + return 1 + + # Step 3: Run same problem through both + print("\n" + "=" * 70) + print("Step 3: Run Same Problem Through Both Registries") + print("=" * 70) + + test_configs = [ + (2, 4, 4, 64, 64, 128, "small"), + (2, 8, 8, 128, 128, 128, "medium"), + (2, 8, 8, 256, 256, 128, "large"), + ] + + validator = FmhaValidator(rtol=args.rtol, atol=args.atol) + + print(f"\n {'Problem':<12} {'Latency':>18} {'Throughput':>18} {'Match':<6}") + print(" " + "-" * 60) + + all_match = True + + for batch, hq, hk, sq, sk, hdim, desc in test_configs: + prob = FmhaProblem( + batch=batch, + nhead_q=hq, + nhead_k=hk, + seqlen_q=sq, + seqlen_k=sk, + hdim_q=hdim, + hdim_v=hdim, + ) + + np.random.seed(42) + Q = (np.random.randn(*prob.q_shape()) * 0.5).astype(np.float16) + K = (np.random.randn(*prob.k_shape()) * 0.5).astype(np.float16) + V = (np.random.randn(*prob.v_shape()) * 0.5).astype(np.float16) + + O_ref = cpu_attention_fwd( + Q.astype(np.float32), + K.astype(np.float32), + V.astype(np.float32), + prob.scale, + ) + + lat_cell = "N/A" + thr_cell = "N/A" + results_match = True + + if lat_ok: + res_lat = latency_results[0].runner.run(Q, K, V, prob) + if res_lat.success: + lat_cell = f"{res_lat.tflops:.2f} TFLOPS" + ok, _, _ = validator.check(res_lat.output, O_ref) + if not ok: + results_match = False + + if thr_ok: + res_thr = throughput_results[0].runner.run(Q, K, V, prob) + if res_thr.success: + thr_cell = f"{res_thr.tflops:.2f} TFLOPS" + ok, _, _ = validator.check(res_thr.output, O_ref) + if not ok: + results_match = False + + if not results_match: + all_match = False + + tag = "YES" if results_match else "NO" + print(f" {desc:<12} {lat_cell:>18} {thr_cell:>18} {tag:<6}") + + # Step 4: Detailed comparison on a single problem + print("\n" + "=" * 70) + print("Step 4: Detailed Comparison (B=2 H=8 S=128 D=128)") + print("=" * 70) + + prob = FmhaProblem( + batch=2, + nhead_q=8, + nhead_k=8, + seqlen_q=128, + seqlen_k=128, + hdim_q=128, + hdim_v=128, + ) + np.random.seed(123) + Q = (np.random.randn(*prob.q_shape()) * 0.5).astype(np.float16) + K = (np.random.randn(*prob.k_shape()) * 0.5).astype(np.float16) + V = (np.random.randn(*prob.v_shape()) * 0.5).astype(np.float16) + O_ref = cpu_attention_fwd( + Q.astype(np.float32), + K.astype(np.float32), + V.astype(np.float32), + prob.scale, + ) + + for name, results, ok in [ + ("Latency", latency_results, lat_ok), + ("Throughput", throughput_results, thr_ok), + ]: + if not ok: + print(f"\n {name}: not available") + continue + res = results[0].runner.run(Q, K, V, prob) + if not res.success: + print(f"\n {name}: execution failed") + continue + valid, max_abs, max_rel = validator.check(res.output, O_ref) + print(f"\n {name}:") + print(f" Time: {res.time_ms:.4f} ms") + print(f" TFLOPS: {res.tflops:.2f}") + print(f" Max Abs: {max_abs:.2e}") + print(f" Max Rel: {max_rel:.2e}") + print(f" Valid: {valid}") + + # Cleanup + for results in [latency_results, throughput_results]: + for r in results: + if r.runner: + r.runner.cleanup() + + # Summary + print("\n" + "=" * 70) + print("Multi-Registry Pattern:") + print("=" * 70) + print(" 1. Create FmhaRegistry per optimization target") + print(" 2. Register target-specific FmhaKernelConfig in each") + print(" 3. Build both registries") + print(" 4. Route problems to the best registry") + print(" 5. Compare results for correctness") + print("=" * 70) + + return 0 if all_match else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/10_advanced_benchmark.py b/dispatcher/examples/fmha/python/10_advanced_benchmark.py new file mode 100644 index 0000000000..6f3ac2c065 --- /dev/null +++ b/dispatcher/examples/fmha/python/10_advanced_benchmark.py @@ -0,0 +1,262 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 10: Advanced FMHA Benchmarking + +Benchmarks FMHA forward across multiple problem sizes with configurable +warmup, repeat, and cache-flush settings. Reports min/avg/max/median +time and TFLOPS for each problem. + +Usage: + python3 10_advanced_benchmark.py + python3 10_advanced_benchmark.py --warmup 10 --repeat 50 + python3 10_advanced_benchmark.py --flush-cache +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaKernelConfig, + FmhaProblem, + setup_fmha_dispatcher, + detect_gpu_arch, +) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Advanced FMHA benchmarking with full parameter control", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 10_advanced_benchmark.py # Defaults + python3 10_advanced_benchmark.py --warmup 10 --repeat 50 # More samples + python3 10_advanced_benchmark.py --flush-cache # Flush L2 + """, + ) + parser.add_argument( + "--warmup", type=int, default=5, help="Number of warmup iterations (default: 5)" + ) + parser.add_argument( + "--repeat", + type=int, + default=20, + help="Number of timed iterations (default: 20)", + ) + parser.add_argument( + "--flush-cache", + action="store_true", + help="Allocate a scratch buffer between runs to flush GPU cache", + ) + parser.add_argument( + "--arch", default=detect_gpu_arch(), help="GPU architecture (auto-detected)" + ) + parser.add_argument( + "--lib", default=None, help="Path to prebuilt .so (JIT-builds if omitted)" + ) + args = parser.parse_args() + return args + + +PROBLEM_TABLE = [ + # (batch, nhead_q, nhead_k, seqlen_q, seqlen_k, hdim, label) + (1, 8, 8, 64, 64, 128, "tiny"), + (2, 8, 8, 128, 128, 128, "small"), + (2, 16, 16, 256, 256, 128, "medium"), + (4, 16, 16, 512, 512, 128, "large"), + (2, 32, 32, 1024, 1024, 128, "xlarge"), + (1, 32, 8, 256, 256, 128, "GQA-4:1"), +] + + +def flush_gpu_cache(): + """Allocate and touch a large buffer to evict L2 cache lines.""" + scratch = np.random.randint(0, 255, size=32 * 1024 * 1024, dtype=np.uint8) + _ = scratch.sum() + + +def run_benchmark( + runner, prob: FmhaProblem, warmup: int, repeat: int, flush_cache: bool +) -> list: + """Run warmup + repeat iterations and return list of times in ms.""" + Q = (np.random.randn(*prob.q_shape()) * 0.5).astype(np.float16) + K = (np.random.randn(*prob.k_shape()) * 0.5).astype(np.float16) + V = (np.random.randn(*prob.v_shape()) * 0.5).astype(np.float16) + + for _ in range(warmup): + runner.run(Q, K, V, prob) + + times = [] + for _ in range(repeat): + if flush_cache: + flush_gpu_cache() + result = runner.run(Q, K, V, prob) + if result.success: + times.append(result.time_ms) + return times + + +def main(): + args = parse_args() + + print("=" * 70) + print("Example 10: Advanced FMHA Benchmarking") + print("=" * 70) + + print("\nBenchmark Configuration:") + print(f" Warmup: {args.warmup} iterations") + print(f" Repeat: {args.repeat} iterations") + print(f" Flush Cache: {args.flush_cache}") + print(f" Arch: {args.arch}") + print(f" Problems: {len(PROBLEM_TABLE)}") + + # Step 1: Load or JIT-build kernel + print("\n" + "=" * 70) + print("Step 1: Load / Build Kernel") + print("=" * 70) + + print(" JIT building kernel...") + config = FmhaKernelConfig( + family="fwd", + data_type="fp16", + hdim_q=128, + hdim_v=128, + pipeline="qr_async", + # Stage 0 (Q*K^T): seqlen_q x seqlen_k x hdim_q + tile_m0=128, + tile_n0=128, + tile_k0=32, + # Stage 1 (Attn*V): hdim_v x seqlen_k x alignment + tile_n1=128, + tile_k1=32, + tile_k0max=128, + # Wave config per stage + wave_m0=4, + wave_n0=1, + wave_k0=1, + wave_m1=4, + wave_n1=1, + wave_k1=1, + # Warp tile per stage + warp_m0=32, + warp_n0=32, + warp_k0=16, + warp_m1=32, + warp_n1=32, + warp_k1=16, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config, verbose=True) + if not setup.success: + print(f" JIT build failed: {setup.error}") + return 1 + runner = setup.runner + print(f" JIT built: {setup.library_path} ({setup.build_time_s:.1f} s)") + + print(f" Kernels: {runner.kernel_count}") + + # Step 2: Benchmark all problems + print("\n" + "=" * 70) + print("Step 2: Benchmark Results") + print("=" * 70) + + header = ( + f" {'Label':<10} {'Shape':^30} " + f"{'Min':>8} {'Avg':>8} {'Max':>8} {'Med':>8} {'TFLOPS':>8}" + ) + print(f"\n{header}") + print(" " + "-" * 85) + + all_results = [] + np.random.seed(42) + + for batch, hq, hk, sq, sk, hdim, label in PROBLEM_TABLE: + prob = FmhaProblem( + batch=batch, + nhead_q=hq, + nhead_k=hk, + seqlen_q=sq, + seqlen_k=sk, + hdim_q=hdim, + hdim_v=hdim, + ) + shape_str = f"B{batch}_Hq{hq}_Hk{hk}_S{sq}_D{hdim}" + + times = run_benchmark(runner, prob, args.warmup, args.repeat, args.flush_cache) + + if not times: + print( + f" {label:<10} {shape_str:^30} {'FAIL':>8} {'---':>8} " + f"{'---':>8} {'---':>8} {'---':>8}" + ) + continue + + t_min = min(times) + t_max = max(times) + t_avg = sum(times) / len(times) + t_med = float(np.median(times)) + + tflops = prob.num_ops / (t_med * 1e-3) / 1e12 if t_med > 0 else 0 + + print( + f" {label:<10} {shape_str:^30} " + f"{t_min:>7.3f}ms {t_avg:>7.3f}ms {t_max:>7.3f}ms {t_med:>7.3f}ms " + f"{tflops:>7.2f}" + ) + + all_results.append((label, shape_str, t_min, t_avg, t_max, t_med, tflops)) + + # Summary + print("\n" + "=" * 70) + print(" SUMMARY") + print("=" * 70) + + if all_results: + best = max(all_results, key=lambda r: r[6]) + print(f"\n Best TFLOPS: {best[6]:.2f} ({best[0]}: {best[1]})") + avg_tflops = sum(r[6] for r in all_results) / len(all_results) + print(f" Avg TFLOPS: {avg_tflops:.2f}") + print(f" Problems run: {len(all_results)}/{len(PROBLEM_TABLE)}") + else: + print("\n No successful benchmarks") + + print( + f"\n Settings: warmup={args.warmup}, repeat={args.repeat}, " + f"flush_cache={args.flush_cache}" + ) + + print("\n" + "=" * 70) + print("BENCHMARK PARAMETERS REFERENCE") + print("=" * 70) + print(""" + --warmup N Warmup iterations (results discarded) + Higher = more stable results, longer run + Default: 5 + + --repeat N Timed iterations + Higher = more accurate statistics + Default: 20 + + --flush-cache Flush GPU L2 cache between iterations + Use for memory-bandwidth measurements + Default: off + + --arch ARCH GPU architecture (e.g. gfx950) + Auto-detected from rocminfo +""") + print("=" * 70) + + runner.cleanup() + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/11_bf16_fmha.py b/dispatcher/examples/fmha/python/11_bf16_fmha.py new file mode 100644 index 0000000000..132afdf5c0 --- /dev/null +++ b/dispatcher/examples/fmha/python/11_bf16_fmha.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 11: BF16 Forward Attention + +Demonstrates: +1. BF16 data generation and handling +2. GPU execution attempt with prebuilt kernel (fp16-only) +3. CPU reference computation in float32 +4. BF16-specific tolerance validation (atol=1e-2) + +The prebuilt library contains only fp16 kernels. This example shows the API +pattern for bf16 and gracefully falls back to CPU reference when the GPU +kernel does not support bf16. + +Usage: + python3 11_bf16_fmha.py + python3 11_bf16_fmha.py --batch 4 --seqlen 256 + python3 11_bf16_fmha.py --arch gfx942 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaProblem, + FmhaKernelConfig, + FmhaValidator, + cpu_attention_fwd, + detect_gpu_arch, + setup_fmha_dispatcher, +) + + +def to_bf16(arr: np.ndarray) -> np.ndarray: + """Convert float32 array to bfloat16 (stored as uint16 with bf16 bit pattern).""" + f32 = arr.astype(np.float32) + u32 = f32.view(np.uint32) + return (u32 >> 16).astype(np.uint16) + + +def bf16_to_f32(arr_u16: np.ndarray) -> np.ndarray: + """Convert bfloat16 (uint16) back to float32.""" + u32 = arr_u16.astype(np.uint32) << 16 + return u32.view(np.float32) + + +def main(): + parser = argparse.ArgumentParser(description="BF16 Forward Attention") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--nhead", type=int, default=8) + parser.add_argument("--seqlen", type=int, default=128) + parser.add_argument("--hdim", type=int, default=128) + args = parser.parse_args() + + print("=" * 70) + print("Example 11: BF16 Forward Attention") + print("=" * 70) + + prob = FmhaProblem( + batch=args.batch, + nhead_q=args.nhead, + nhead_k=args.nhead, + seqlen_q=args.seqlen, + seqlen_k=args.seqlen, + hdim_q=args.hdim, + hdim_v=args.hdim, + ) + + print( + f"\n Problem: B={prob.batch} H={prob.nhead_q} S={prob.seqlen_q} D={prob.hdim_q}" + ) + print(" Dtype: bfloat16") + print(f" Arch: {args.arch}") + print(f" Scale: {prob.scale:.6f}") + + # --- Generate bf16 data --- + np.random.seed(42) + Q_f32 = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32) + K_f32 = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32) + V_f32 = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32) + + Q_bf16 = to_bf16(Q_f32) + K_bf16 = to_bf16(K_f32) + V_bf16 = to_bf16(V_f32) + + Q_bf16_f32 = bf16_to_f32(Q_bf16) + K_bf16_f32 = bf16_to_f32(K_bf16) + V_bf16_f32 = bf16_to_f32(V_bf16) + + print(f"\n Q bf16 range: [{Q_bf16_f32.min():.4f}, {Q_bf16_f32.max():.4f}]") + print(f" K bf16 range: [{K_bf16_f32.min():.4f}, {K_bf16_f32.max():.4f}]") + print(f" V bf16 range: [{V_bf16_f32.min():.4f}, {V_bf16_f32.max():.4f}]") + + bf16_quant_err = np.abs(Q_f32 - Q_bf16_f32).max() + print(f" BF16 quantization error: {bf16_quant_err:.2e}") + + # --- GPU execution attempt --- + print("\n--- GPU Execution ---") + gpu_output = None + gpu_time = None + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=args.hdim, + hdim_v=args.hdim, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + if not setup.success: + print(f" JIT build failed: {setup.error}") + else: + runner = setup.runner + print(f" JIT build: {setup.build_time_s:.1f}s") + Q_fp16 = Q_bf16_f32.astype(np.float16) + K_fp16 = K_bf16_f32.astype(np.float16) + V_fp16 = V_bf16_f32.astype(np.float16) + result = runner.run(Q_fp16, K_fp16, V_fp16, prob) + if result.success: + gpu_output = result.output + gpu_time = result.time_ms + print(f" GPU: {result.time_ms:.4f} ms, {result.tflops:.2f} TFLOPS") + print(" Note: Ran as fp16 (JIT kernel); native bf16 kernel not compiled") + else: + print(" GPU: Kernel does not support bf16 (expected)") + + # --- CPU reference (always computed) --- + print("\n--- CPU Reference (float32 with bf16-quantized inputs) ---") + O_ref = cpu_attention_fwd(Q_bf16_f32, K_bf16_f32, V_bf16_f32, prob.scale) + print(f" Output range: [{O_ref.min():.4f}, {O_ref.max():.4f}]") + print(f" Output shape: {O_ref.shape}") + + # --- Validation --- + print("\n--- Validation ---") + validator = FmhaValidator(rtol=1e-2, atol=1e-2) + + print(f"\n {'Check':<30} {'MaxAbs':>10} {'MaxRel':>10} {'Status':>8}") + print(" " + "-" * 62) + + if gpu_output is not None: + ok, max_abs, max_rel = validator.check(gpu_output, O_ref) + tag = "PASS" if ok else "FAIL" + print( + f" {'GPU vs CPU (bf16 tol)':<30} {max_abs:>10.2e} {max_rel:>10.2e} {tag:>8}" + ) + else: + print(f" {'GPU vs CPU (bf16 tol)':<30} {'N/A':>10} {'N/A':>10} {'SKIP':>8}") + + strict_val = FmhaValidator(rtol=1e-5, atol=1e-5) + ok_strict, ma_strict, mr_strict = strict_val.check( + O_ref.astype(np.float16), + O_ref, + ) + print( + f" {'fp16(ref) vs f32(ref)':<30} {ma_strict:>10.2e} {mr_strict:>10.2e} {'PASS' if ok_strict else 'INFO':>8}" + ) + + O_ref_from_f32 = cpu_attention_fwd(Q_f32, K_f32, V_f32, prob.scale) + bf16_impact = float(np.abs(O_ref - O_ref_from_f32).max()) + print( + f" {'bf16 vs f32 input impact':<30} {bf16_impact:>10.2e} {'':>10} {'INFO':>8}" + ) + + # --- Summary --- + print("\n" + "=" * 70) + print(" Dtype: bfloat16 (7-bit mantissa vs fp16's 10-bit)") + print(" Tolerance: atol=1e-2 (relaxed for bf16 precision)") + print( + f" GPU: {'%.4f ms' % gpu_time if gpu_time else 'N/A (bf16 kernel not in prebuilt)'}" + ) + print(" CPU ref: Computed with bf16-quantized inputs") + print(" BF16 range: Larger exponent range (±3.4e38) vs fp16 (±65504)") + status = "PASS" if gpu_output is not None else "DEMO" + print(f" Status: {status}") + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/12_masks_fmha.py b/dispatcher/examples/fmha/python/12_masks_fmha.py new file mode 100644 index 0000000000..bc3aacef7a --- /dev/null +++ b/dispatcher/examples/fmha/python/12_masks_fmha.py @@ -0,0 +1,239 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 12: Attention Masks + +Demonstrates all 5 mask types supported by the FMHA dispatcher: +1. no_mask (0) -- Full attention, no masking +2. top_left (1) -- Causal mask aligned to top-left corner +3. bottom_right (2) -- Causal mask aligned to bottom-right corner +4. sliding_window -- Local attention within a fixed window +5. generic -- Arbitrary user-defined mask pattern + +For each mask type, this example: +- Creates an FmhaProblem +- Attempts GPU execution via prebuilt kernel +- Computes CPU reference with the mask applied +- Validates results + +Usage: + python3 12_masks_fmha.py + python3 12_masks_fmha.py --seqlen 256 + python3 12_masks_fmha.py --window-size 64 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaProblem, + FmhaKernelConfig, + FmhaValidator, + detect_gpu_arch, + setup_fmha_dispatcher, +) + + +MASK_TYPES = { + "no_mask": 0, + "top_left": 1, + "bottom_right": 2, + "sliding_window": 3, + "generic": 4, +} + + +def make_causal_mask_top_left(seqlen_q: int, seqlen_k: int) -> np.ndarray: + """Causal mask aligned to top-left: position i can attend to positions <= i.""" + row = np.arange(seqlen_q).reshape(-1, 1) + col = np.arange(seqlen_k).reshape(1, -1) + return (col <= row).astype(np.float32) + + +def make_causal_mask_bottom_right(seqlen_q: int, seqlen_k: int) -> np.ndarray: + """Causal mask aligned to bottom-right: accounts for kv longer than q.""" + offset = seqlen_k - seqlen_q + row = np.arange(seqlen_q).reshape(-1, 1) + col = np.arange(seqlen_k).reshape(1, -1) + return (col <= row + offset).astype(np.float32) + + +def make_sliding_window_mask(seqlen_q: int, seqlen_k: int, window: int) -> np.ndarray: + """Sliding window: each query attends to a local window of keys.""" + row = np.arange(seqlen_q).reshape(-1, 1) + col = np.arange(seqlen_k).reshape(1, -1) + offset = seqlen_k - seqlen_q + return ((col <= row + offset) & (col >= row + offset - window + 1)).astype( + np.float32 + ) + + +def make_generic_mask(seqlen_q: int, seqlen_k: int) -> np.ndarray: + """Generic checkerboard mask for demonstration.""" + row = np.arange(seqlen_q).reshape(-1, 1) + col = np.arange(seqlen_k).reshape(1, -1) + return ((row + col) % 2 == 0).astype(np.float32) + + +def cpu_masked_attention( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + scale: float, + mask: np.ndarray, +) -> np.ndarray: + """CPU reference: scaled dot-product attention with arbitrary mask. + + Q: [batch, nhead, seqlen_q, hdim] + mask: [seqlen_q, seqlen_k] (broadcast over batch and head) + """ + S = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale + mask_broad = mask[np.newaxis, np.newaxis, :, :] + S = np.where(mask_broad > 0, S, -1e9) + S_max = S.max(axis=-1, keepdims=True) + S_exp = np.exp(S - S_max) + P = S_exp / S_exp.sum(axis=-1, keepdims=True) + return np.matmul(P, V) + + +def main(): + parser = argparse.ArgumentParser(description="Attention Masks") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--nhead", type=int, default=8) + parser.add_argument("--seqlen-q", type=int, default=128) + parser.add_argument("--seqlen-k", type=int, default=128) + parser.add_argument("--hdim", type=int, default=128) + parser.add_argument("--window-size", type=int, default=32) + args = parser.parse_args() + + print("=" * 70) + print("Example 12: Attention Masks") + print("=" * 70) + + sq, sk = args.seqlen_q, args.seqlen_k + prob = FmhaProblem( + batch=args.batch, + nhead_q=args.nhead, + nhead_k=args.nhead, + seqlen_q=sq, + seqlen_k=sk, + hdim_q=args.hdim, + hdim_v=args.hdim, + ) + + print(f"\n Problem: B={prob.batch} H={prob.nhead_q} Sq={sq} Sk={sk} D={args.hdim}") + print(f" Window: {args.window_size}") + + # --- Generate data --- + np.random.seed(42) + Q_f32 = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32) + K_f32 = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32) + V_f32 = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32) + Q_fp16 = Q_f32.astype(np.float16) + K_fp16 = K_f32.astype(np.float16) + V_fp16 = V_f32.astype(np.float16) + + # --- Try GPU runner --- + runner = None + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=args.hdim, + hdim_v=args.hdim, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + if setup.success: + runner = setup.runner + print(f"\n GPU runner loaded (JIT build: {setup.build_time_s:.1f}s)") + else: + print(f"\n GPU runner not available: {setup.error}") + + # --- Build masks --- + masks = { + "no_mask": np.ones((sq, sk), dtype=np.float32), + "top_left": make_causal_mask_top_left(sq, sk), + "bottom_right": make_causal_mask_bottom_right(sq, sk), + "sliding_window": make_sliding_window_mask(sq, sk, args.window_size), + "generic": make_generic_mask(sq, sk), + } + + validator = FmhaValidator(rtol=1e-2, atol=1e-2) + + print( + f"\n {'#':<3} {'MaskType':<18} {'ID':<4} {'Density':>8} {'GPUStatus':<12} {'CPURef':<8} {'MaxErr':>10} {'Status':>8}" + ) + print(" " + "-" * 76) + + results = [] + for i, (name, mask) in enumerate(masks.items(), 1): + mask_id = MASK_TYPES[name] + density = mask.sum() / mask.size * 100 + + # GPU attempt (prebuilt only supports no_mask) + gpu_status = "N/A" + gpu_out = None + if runner is not None: + res = runner.run(Q_fp16, K_fp16, V_fp16, prob) + if res.success: + gpu_out = res.output + gpu_status = "OK" if name == "no_mask" else "no_mask*" + else: + gpu_status = "unsupported" + + # CPU reference with mask + O_ref = cpu_masked_attention(Q_f32, K_f32, V_f32, prob.scale, mask) + cpu_status = "OK" + + # Validate + if gpu_out is not None and name == "no_mask": + ok, max_abs, _ = validator.check(gpu_out, O_ref) + tag = "PASS" if ok else "FAIL" + err_str = f"{max_abs:.2e}" + else: + ok = True + tag = "DEMO" + err_str = "---" + + print( + f" {i:<3} {name:<18} {mask_id:<4} {density:>7.1f}% {gpu_status:<12} {cpu_status:<8} {err_str:>10} {tag:>8}" + ) + results.append((name, ok)) + + # --- Mask visualization --- + print("\n--- Mask Patterns (first 8x8 corner) ---") + view_size = min(8, sq, sk) + for name, mask in masks.items(): + corner = mask[:view_size, :view_size] + print(f"\n {name}:") + for r in range(view_size): + row_str = " ".join( + "█" if corner[r, c] > 0 else "·" for c in range(view_size) + ) + print(f" {row_str}") + + # --- Summary --- + all_ok = all(ok for _, ok in results) + print("\n" + "=" * 70) + print(f" Mask types tested: {len(masks)}") + print(" no_mask: Full attention (all positions visible)") + print(" top_left: Causal from top-left (autoregressive)") + print(" bottom_right: Causal from bottom-right (kv-padded)") + print(f" sliding_window: Local window of {args.window_size} keys") + print(" generic: Arbitrary (checkerboard demo)") + print(" GPU: Prebuilt supports no_mask only") + print(f" Status: {'PASS' if all_ok else 'FAIL'}") + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/13_bias_fmha.py b/dispatcher/examples/fmha/python/13_bias_fmha.py new file mode 100644 index 0000000000..139e210d3d --- /dev/null +++ b/dispatcher/examples/fmha/python/13_bias_fmha.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 13: Attention Bias + +Demonstrates bias types supported by the FMHA dispatcher: +1. no_bias -- Standard attention without bias +2. elementwise -- Add a [seqlen_q, seqlen_k] bias matrix to attention scores +3. alibi -- Attention with Linear Biases (ALiBi) positional encoding + +For each bias type: +- Creates an FmhaProblem and bias tensor +- Attempts GPU execution (prebuilt: no_bias only) +- Computes CPU reference with bias applied before softmax +- Validates output + +Usage: + python3 13_bias_fmha.py + python3 13_bias_fmha.py --seqlen 256 + python3 13_bias_fmha.py --nhead 16 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaProblem, + FmhaKernelConfig, + FmhaValidator, + cpu_attention_fwd, + detect_gpu_arch, + setup_fmha_dispatcher, +) + + +def get_alibi_slopes(nhead: int) -> np.ndarray: + """Compute ALiBi slopes for each attention head. + + Following the original ALiBi paper: slopes = 2^(-8/n * [1..n]) + where n is the number of heads. + """ + ratio = 2.0 ** (-8.0 / nhead) + return np.array([ratio ** (i + 1) for i in range(nhead)], dtype=np.float32) + + +def make_alibi_bias(nhead: int, seqlen_q: int, seqlen_k: int) -> np.ndarray: + """Create ALiBi bias matrix: slope * (col - row) for causal positions. + + Returns: [nhead, seqlen_q, seqlen_k] + """ + slopes = get_alibi_slopes(nhead) + row = np.arange(seqlen_q).reshape(-1, 1) + col = np.arange(seqlen_k).reshape(1, -1) + dist = col - row + bias = slopes.reshape(-1, 1, 1) * dist.reshape(1, seqlen_q, seqlen_k) + return bias.astype(np.float32) + + +def make_elementwise_bias(seqlen_q: int, seqlen_k: int) -> np.ndarray: + """Create a relative-position elementwise bias matrix. + + Returns: [seqlen_q, seqlen_k] + """ + row = np.arange(seqlen_q, dtype=np.float32).reshape(-1, 1) + col = np.arange(seqlen_k, dtype=np.float32).reshape(1, -1) + dist = np.abs(row - col) + return (-0.1 * dist).astype(np.float32) + + +def cpu_biased_attention( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + scale: float, + bias: np.ndarray, +) -> np.ndarray: + """CPU reference: attention with additive bias before softmax. + + Q: [batch, nhead, seqlen_q, hdim] + bias: broadcastable to [batch, nhead, seqlen_q, seqlen_k] + """ + S = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale + S = S + bias + S_max = S.max(axis=-1, keepdims=True) + S_exp = np.exp(S - S_max) + P = S_exp / S_exp.sum(axis=-1, keepdims=True) + return np.matmul(P, V) + + +def main(): + parser = argparse.ArgumentParser(description="Attention Bias") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--nhead", type=int, default=8) + parser.add_argument("--seqlen", type=int, default=128) + parser.add_argument("--hdim", type=int, default=128) + args = parser.parse_args() + + print("=" * 70) + print("Example 13: Attention Bias") + print("=" * 70) + + sq = sk = args.seqlen + prob = FmhaProblem( + batch=args.batch, + nhead_q=args.nhead, + nhead_k=args.nhead, + seqlen_q=sq, + seqlen_k=sk, + hdim_q=args.hdim, + hdim_v=args.hdim, + ) + + print(f"\n Problem: B={prob.batch} H={prob.nhead_q} S={sq} D={args.hdim}") + + # --- Generate data --- + np.random.seed(42) + Q_f32 = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32) + K_f32 = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32) + V_f32 = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32) + Q_fp16 = Q_f32.astype(np.float16) + K_fp16 = K_f32.astype(np.float16) + V_fp16 = V_f32.astype(np.float16) + + # --- Try GPU runner --- + runner = None + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=args.hdim, + hdim_v=args.hdim, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + if setup.success: + runner = setup.runner + print(f" GPU runner loaded (JIT build: {setup.build_time_s:.1f}s)") + else: + print(f" GPU runner not available: {setup.error}") + + # --- Build bias tensors --- + bias_configs = [ + ("no_bias", np.zeros((1, 1, sq, sk), dtype=np.float32)), + ("elementwise", make_elementwise_bias(sq, sk)[np.newaxis, np.newaxis, :, :]), + ("alibi", make_alibi_bias(args.nhead, sq, sk)[np.newaxis, :, :, :]), + ] + + validator = FmhaValidator(rtol=1e-2, atol=1e-2) + + print( + f"\n {'#':<3} {'BiasType':<14} {'BiasRange':>20} {'GPUStatus':<12} {'MaxErr':>10} {'Status':>8}" + ) + print(" " + "-" * 72) + + results = [] + for i, (name, bias) in enumerate(bias_configs, 1): + bias_min, bias_max = float(bias.min()), float(bias.max()) + bias_range = f"[{bias_min:.3f}, {bias_max:.3f}]" + + # GPU attempt + gpu_status = "N/A" + gpu_out = None + if runner is not None: + res = runner.run(Q_fp16, K_fp16, V_fp16, prob) + if res.success: + gpu_out = res.output + gpu_status = "OK" if name == "no_bias" else "no_bias*" + else: + gpu_status = "unsupported" + + # CPU reference with bias + O_ref = cpu_biased_attention(Q_f32, K_f32, V_f32, prob.scale, bias) + + # Validate + if gpu_out is not None and name == "no_bias": + ok, max_abs, _ = validator.check(gpu_out, O_ref) + tag = "PASS" if ok else "FAIL" + err_str = f"{max_abs:.2e}" + else: + ok = True + tag = "DEMO" + err_str = "---" + + print( + f" {i:<3} {name:<14} {bias_range:>20} {gpu_status:<12} {err_str:>10} {tag:>8}" + ) + results.append((name, ok)) + + # --- Show ALiBi details --- + print("\n--- ALiBi Details ---") + slopes = get_alibi_slopes(args.nhead) + print(f" Heads: {args.nhead}") + print(f" Slopes: {', '.join(f'{s:.4f}' for s in slopes[: min(8, len(slopes))])}") + if len(slopes) > 8: + print(f" ... ({len(slopes)} total)") + print(" Effect: Nearby tokens get higher scores, distant tokens penalized") + print(" Formula: bias[h,i,j] = slope[h] * (j - i)") + + alibi_bias = make_alibi_bias(args.nhead, sq, sk) + print("\n Head 0 bias corner (4x4):") + corner = alibi_bias[0, :4, :4] + for r in range(4): + row_str = " ".join(f"{corner[r, c]:>7.3f}" for c in range(4)) + print(f" {row_str}") + + # --- Show impact of bias on attention --- + print("\n--- Bias Impact Analysis ---") + O_no_bias = cpu_attention_fwd(Q_f32, K_f32, V_f32, prob.scale) + for name, bias in bias_configs: + O_biased = cpu_biased_attention(Q_f32, K_f32, V_f32, prob.scale, bias) + diff = float(np.abs(O_biased - O_no_bias).max()) + print(f" {name:<14} max output shift: {diff:.4e}") + + # --- Summary --- + all_ok = all(ok for _, ok in results) + print("\n" + "=" * 70) + print(" Bias types: no_bias, elementwise, alibi") + print(" no_bias: Standard attention (baseline)") + print(" elementwise: Position-distance bias [-0.1 * |i-j|]") + print(" alibi: Linear position bias per head (no learned params)") + print(" GPU: Prebuilt supports no_bias only") + print(f" Status: {'PASS' if all_ok else 'FAIL'}") + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/14_dropout_fmha.py b/dispatcher/examples/fmha/python/14_dropout_fmha.py new file mode 100644 index 0000000000..368340d8f9 --- /dev/null +++ b/dispatcher/examples/fmha/python/14_dropout_fmha.py @@ -0,0 +1,245 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 14: Attention Dropout with LSE + +Demonstrates: +1. Dropout applied to attention probabilities +2. Log-sum-exp (LSE) storage for numerical stability +3. Statistical validation (dropout is stochastic) +4. Reproducibility with seed control + +Dropout zeros out attention weights with probability p_drop, then scales +remaining weights by 1/(1-p_drop) to preserve expected value. +LSE stores log(sum(exp(scores))) per query position for backward pass. + +Usage: + python3 14_dropout_fmha.py + python3 14_dropout_fmha.py --p-drop 0.3 + python3 14_dropout_fmha.py --seqlen 256 --seed 123 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaProblem, + FmhaKernelConfig, + FmhaValidator, + cpu_attention_fwd, + detect_gpu_arch, + setup_fmha_dispatcher, +) + + +def cpu_attention_with_dropout( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + scale: float, + p_drop: float, + seed: int, +) -> tuple: + """CPU reference: attention with dropout and LSE output. + + Returns: + (O, P_dropped, lse) + O: [batch, nhead, seqlen_q, hdim_v] + P_dropped: [batch, nhead, seqlen_q, seqlen_k] attention weights after dropout + lse: [batch, nhead, seqlen_q] log-sum-exp of scores + """ + S = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale + S_max = S.max(axis=-1, keepdims=True) + S_exp = np.exp(S - S_max) + S_sum = S_exp.sum(axis=-1, keepdims=True) + P = S_exp / S_sum + + lse = (np.log(S_sum.squeeze(-1)) + S_max.squeeze(-1)).astype(np.float32) + + rng = np.random.RandomState(seed) + drop_mask = (rng.rand(*P.shape) >= p_drop).astype(np.float32) + scale_factor = 1.0 / (1.0 - p_drop) if p_drop < 1.0 else 0.0 + P_dropped = P * drop_mask * scale_factor + + out = np.matmul(P_dropped, V) + return out, P_dropped, lse + + +def main(): + parser = argparse.ArgumentParser(description="Attention Dropout with LSE") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--nhead", type=int, default=8) + parser.add_argument("--seqlen", type=int, default=128) + parser.add_argument("--hdim", type=int, default=128) + parser.add_argument("--p-drop", type=float, default=0.2) + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + + print("=" * 70) + print("Example 14: Attention Dropout with LSE") + print("=" * 70) + + prob = FmhaProblem( + batch=args.batch, + nhead_q=args.nhead, + nhead_k=args.nhead, + seqlen_q=args.seqlen, + seqlen_k=args.seqlen, + hdim_q=args.hdim, + hdim_v=args.hdim, + ) + + print( + f"\n Problem: B={prob.batch} H={prob.nhead_q} S={args.seqlen} D={args.hdim}" + ) + print(f" p_drop: {args.p_drop}") + print(f" Seed: {args.seed}") + print(f" LSE shape: [{prob.batch}, {prob.nhead_q}, {prob.seqlen_q}]") + + # --- Generate data --- + np.random.seed(args.seed) + Q_f32 = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32) + K_f32 = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32) + V_f32 = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32) + Q_fp16 = Q_f32.astype(np.float16) + K_fp16 = K_f32.astype(np.float16) + V_fp16 = V_f32.astype(np.float16) + + # --- GPU execution attempt --- + print("\n--- GPU Execution ---") + gpu_output = None + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=args.hdim, + hdim_v=args.hdim, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + if not setup.success: + print(f" JIT build failed: {setup.error}") + else: + runner = setup.runner + print(f" JIT build: {setup.build_time_s:.1f}s") + res = runner.run(Q_fp16, K_fp16, V_fp16, prob) + if res.success: + gpu_output = res.output + print(f" GPU (no dropout): {res.time_ms:.4f} ms, {res.tflops:.2f} TFLOPS") + print(" Note: JIT kernel runs without dropout; shown for baseline") + else: + print(" GPU: Kernel returned failure") + + # --- CPU reference: no dropout (baseline) --- + print("\n--- CPU Reference ---") + O_no_drop = cpu_attention_fwd(Q_f32, K_f32, V_f32, prob.scale) + + # --- CPU reference: with dropout --- + drop_rates = [0.0, 0.1, args.p_drop, 0.5] + + print( + f"\n {'p_drop':>8} {'OutMean':>10} {'OutStd':>10} {'MaxDiff':>10} {'DropFrac':>10}" + ) + print(" " + "-" * 52) + + for p in drop_rates: + O_drop, P_dropped, lse = cpu_attention_with_dropout( + Q_f32, + K_f32, + V_f32, + prob.scale, + p, + args.seed, + ) + + total_weights = P_dropped.size + zeros = (P_dropped == 0).sum() + actual_drop_frac = zeros / total_weights + + diff = float(np.abs(O_drop - O_no_drop).max()) + print( + f" {p:>8.2f} {O_drop.mean():>10.4f} {O_drop.std():>10.4f} " + f"{diff:>10.2e} {actual_drop_frac:>10.2%}" + ) + + # --- LSE analysis --- + print("\n--- LSE (Log-Sum-Exp) Analysis ---") + _, _, lse = cpu_attention_with_dropout( + Q_f32, + K_f32, + V_f32, + prob.scale, + args.p_drop, + args.seed, + ) + print(f" LSE shape: {lse.shape}") + print(f" LSE range: [{lse.min():.4f}, {lse.max():.4f}]") + print(f" LSE mean: {lse.mean():.4f}") + print(" LSE is independent of dropout (computed from raw scores)") + + lse_nodrop = cpu_attention_with_dropout( + Q_f32, + K_f32, + V_f32, + prob.scale, + 0.0, + args.seed, + )[2] + lse_diff = float(np.abs(lse - lse_nodrop).max()) + print(f" LSE diff (drop vs no-drop): {lse_diff:.2e} (should be 0)") + + # --- Statistical validation --- + print("\n--- Statistical Validation ---") + n_trials = 5 + outputs = [] + for trial in range(n_trials): + O_t, _, _ = cpu_attention_with_dropout( + Q_f32, + K_f32, + V_f32, + prob.scale, + args.p_drop, + args.seed + trial, + ) + outputs.append(O_t) + + O_mean = np.mean(outputs, axis=0) + O_std = np.std(outputs, axis=0) + + mean_diff = float(np.abs(O_mean - O_no_drop).max()) + max_std = float(O_std.max()) + + print(f" Trials: {n_trials}") + print(f" Mean vs no-drop: {mean_diff:.4e} (should be small)") + print(f" Max output stddev: {max_std:.4e}") + print(" E[dropout(P)] = P (unbiased estimator)") + + if gpu_output is not None: + validator = FmhaValidator(rtol=1e-2, atol=1e-2) + ok, max_abs, _ = validator.check(gpu_output, O_no_drop) + print( + f"\n GPU vs CPU (no-drop): max_err={max_abs:.2e}, {'PASS' if ok else 'FAIL'}" + ) + + # --- Summary --- + print("\n" + "=" * 70) + print(f" Dropout: p_drop={args.p_drop}, seed={args.seed}") + print( + f" LSE: Stored for backward pass (shape [{prob.batch},{prob.nhead_q},{prob.seqlen_q}])" + ) + print(" Key: Dropout is stochastic; validate statistically, not exactly") + print(" GPU: Prebuilt kernel does not support dropout") + print(" Status: DEMO") + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/15_gqa_fmha.py b/dispatcher/examples/fmha/python/15_gqa_fmha.py new file mode 100644 index 0000000000..2544c3cc35 --- /dev/null +++ b/dispatcher/examples/fmha/python/15_gqa_fmha.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 15: Grouped-Query Attention (GQA / MQA) + +Demonstrates GQA with various nhead_q:nhead_k ratios: +- 1:1 (MHA) -- Standard multi-head attention +- 2:1 -- Each KV head serves 2 query heads +- 4:1 -- Each KV head serves 4 query heads +- 8:1 -- Each KV head serves 8 query heads +- 16:1 (MQA) -- Single KV head serves all query heads + +GQA reduces KV cache memory and bandwidth while maintaining quality. +CPU reference uses np.repeat to expand K,V heads to match Q heads. + +Usage: + python3 15_gqa_fmha.py + python3 15_gqa_fmha.py --nhead-q 32 + python3 15_gqa_fmha.py --seqlen 256 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaProblem, + FmhaKernelConfig, + FmhaValidator, + cpu_attention_fwd, + detect_gpu_arch, + setup_fmha_dispatcher, +) + + +def main(): + parser = argparse.ArgumentParser(description="GQA / MQA Attention") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--nhead-q", type=int, default=16) + parser.add_argument("--seqlen", type=int, default=128) + parser.add_argument("--hdim", type=int, default=128) + args = parser.parse_args() + + print("=" * 70) + print("Example 15: Grouped-Query Attention (GQA / MQA)") + print("=" * 70) + + hq = args.nhead_q + + gqa_ratios = [] + for ratio in [1, 2, 4, 8, 16]: + if hq % ratio == 0: + gqa_ratios.append(ratio) + + print(f"\n nhead_q: {hq}") + print(f" Ratios: {', '.join(f'{r}:1' for r in gqa_ratios)}") + print(f" Problem: B={args.batch} S={args.seqlen} D={args.hdim}") + + # --- Try GPU runner --- + runner = None + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=args.hdim, + hdim_v=args.hdim, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + if setup.success: + runner = setup.runner + print(f" GPU: Loaded (JIT build: {setup.build_time_s:.1f}s)") + else: + print(f" GPU: Not available ({setup.error})") + + validator = FmhaValidator(rtol=1e-2, atol=1e-2) + + print( + f"\n {'#':<3} {'Ratio':<8} {'nhead_q':>8} {'nhead_k':>8} {'KV_save':>8} " + f"{'Time(ms)':>10} {'TFLOPS':>10} {'MaxErr':>10} {'Status':>8}" + ) + print(" " + "-" * 82) + + results = [] + for i, ratio in enumerate(gqa_ratios, 1): + hk = hq // ratio + kv_saving = (1.0 - hk / hq) * 100 + + prob = FmhaProblem( + batch=args.batch, + nhead_q=hq, + nhead_k=hk, + seqlen_q=args.seqlen, + seqlen_k=args.seqlen, + hdim_q=args.hdim, + hdim_v=args.hdim, + ) + + np.random.seed(42 + i) + Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float16) + K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float16) + V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float16) + + O_ref = cpu_attention_fwd( + Q.astype(np.float32), + K.astype(np.float32), + V.astype(np.float32), + prob.scale, + ) + + # GPU attempt + time_str = "---" + tflops_str = "---" + gpu_out = None + if runner is not None: + res = runner.run(Q, K, V, prob) + if res.success: + gpu_out = res.output + time_str = f"{res.time_ms:.4f}" + tflops_str = f"{res.tflops:.2f}" + + if gpu_out is not None: + ok, max_abs, _ = validator.check(gpu_out, O_ref) + tag = "PASS" if ok else "FAIL" + err_str = f"{max_abs:.2e}" + else: + ok = True + tag = "DEMO" + err_str = "---" + max_abs = 0.0 + + label = f"{ratio}:1" + if ratio == 1: + label += " MHA" + elif hk == 1: + label += " MQA" + + print( + f" {i:<3} {label:<8} {hq:>8} {hk:>8} {kv_saving:>7.0f}% " + f"{time_str:>10} {tflops_str:>10} {err_str:>10} {tag:>8}" + ) + results.append((ratio, hk, ok, max_abs)) + + # --- Memory analysis --- + print("\n--- KV Cache Memory Analysis ---") + base_kv_size = args.batch * hq * args.seqlen * args.hdim * 2 * 2 # K+V, fp16 + + print(f"\n {'Ratio':<8} {'nhead_k':>8} {'KV Size':>12} {'Savings':>10}") + print(" " + "-" * 42) + + for ratio in gqa_ratios: + hk = hq // ratio + kv_size = args.batch * hk * args.seqlen * args.hdim * 2 * 2 + saving = (1.0 - kv_size / base_kv_size) * 100 + size_str = ( + f"{kv_size / 1024:.1f} KB" + if kv_size < 1024 * 1024 + else f"{kv_size / (1024 * 1024):.2f} MB" + ) + print(f" {ratio}:1{'':<4} {hq // ratio:>8} {size_str:>12} {saving:>9.0f}%") + + # --- GQA correctness: verify np.repeat equivalence --- + print("\n--- GQA Equivalence Check ---") + prob_gqa = FmhaProblem( + batch=1, + nhead_q=8, + nhead_k=2, + seqlen_q=64, + seqlen_k=64, + hdim_q=args.hdim, + hdim_v=args.hdim, + ) + np.random.seed(99) + Q_g = (np.random.randn(*prob_gqa.q_shape()) * 0.1).astype(np.float32) + K_g = (np.random.randn(*prob_gqa.k_shape()) * 0.1).astype(np.float32) + V_g = (np.random.randn(*prob_gqa.v_shape()) * 0.1).astype(np.float32) + + O_gqa = cpu_attention_fwd(Q_g, K_g, V_g, prob_gqa.scale) + + K_exp = np.repeat(K_g, 4, axis=1) + V_exp = np.repeat(V_g, 4, axis=1) + prob_mha = FmhaProblem( + batch=1, + nhead_q=8, + nhead_k=8, + seqlen_q=64, + seqlen_k=64, + hdim_q=args.hdim, + hdim_v=args.hdim, + ) + O_mha = cpu_attention_fwd(Q_g, K_exp, V_exp, prob_mha.scale) + + equiv_err = float(np.abs(O_gqa - O_mha).max()) + print(f" GQA(4:1) vs MHA(expanded): max_err = {equiv_err:.2e}") + print(" cpu_attention_fwd handles GQA internally via np.repeat") + + # --- Summary --- + all_ok = all(ok for _, _, ok, _ in results) + print("\n" + "=" * 70) + print(f" GQA ratios tested: {len(gqa_ratios)}") + print(" MHA (1:1): All heads have unique KV (baseline)") + print(" GQA (N:1): N query heads share one KV head") + print(" MQA (H:1): All query heads share single KV head (max saving)") + print(" GPU: Prebuilt kernel supports GQA via nhead_q != nhead_k") + print(f" Status: {'PASS' if all_ok else 'FAIL'}") + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/16_splitkv_fmha.py b/dispatcher/examples/fmha/python/16_splitkv_fmha.py new file mode 100644 index 0000000000..dce4bb280e --- /dev/null +++ b/dispatcher/examples/fmha/python/16_splitkv_fmha.py @@ -0,0 +1,267 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 16: Split-KV Attention and Paged KV Cache + +Demonstrates: +1. Split-KV: partitioning KV across multiple GPU splits for long sequences +2. Two-stage execution plan: split (per-partition attention) + combine (merge) +3. Paged KV cache with configurable page_block_size +4. CPU reference for split-KV correctness verification + +Split-KV is critical for long-context inference where seqlen_k >> seqlen_q +(decoding with long history). Each split processes a chunk of KV independently, +then partial results are combined with log-sum-exp correction. + +Usage: + python3 16_splitkv_fmha.py + python3 16_splitkv_fmha.py --num-splits 4 + python3 16_splitkv_fmha.py --seqlen-k 2048 --page-size 128 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaProblem, + FmhaKernelConfig, + FmhaValidator, + cpu_attention_fwd, + detect_gpu_arch, + setup_fmha_dispatcher, +) + + +def cpu_splitkv_attention( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + scale: float, + num_splits: int, +) -> tuple: + """CPU reference: split-KV attention with LSE-based combining. + + Stage 1 (split): Compute partial attention for each KV chunk + Stage 2 (combine): Merge partial results using log-sum-exp correction + + Returns: (O_final, partial_Os, partial_lses) + """ + batch, nhead, seqlen_q, hdim = Q.shape + seqlen_k = K.shape[2] + hdim_v = V.shape[3] + + chunk_size = (seqlen_k + num_splits - 1) // num_splits + + partial_Os = np.zeros( + (num_splits, batch, nhead, seqlen_q, hdim_v), dtype=np.float32 + ) + partial_lses = np.full( + (num_splits, batch, nhead, seqlen_q), -np.inf, dtype=np.float32 + ) + + for s in range(num_splits): + k_start = s * chunk_size + k_end = min(k_start + chunk_size, seqlen_k) + if k_start >= seqlen_k: + break + + K_chunk = K[:, :, k_start:k_end, :] + V_chunk = V[:, :, k_start:k_end, :] + + S = np.matmul(Q, K_chunk.transpose(0, 1, 3, 2)) * scale + S_max = S.max(axis=-1, keepdims=True) + S_exp = np.exp(S - S_max) + S_sum = S_exp.sum(axis=-1, keepdims=True) + + partial_Os[s] = np.matmul(S_exp / S_sum, V_chunk) + partial_lses[s] = np.log(S_sum.squeeze(-1)) + S_max.squeeze(-1) + + # Stage 2: Combine using LSE correction + global_lse = np.max(partial_lses, axis=0) # [batch, nhead, seqlen_q] + + O_final = np.zeros((batch, nhead, seqlen_q, hdim_v), dtype=np.float32) + weight_sum = np.zeros((batch, nhead, seqlen_q), dtype=np.float32) + + for s in range(num_splits): + correction = np.exp(partial_lses[s] - global_lse) + correction = correction[..., np.newaxis] + O_final += partial_Os[s] * correction + weight_sum += correction.squeeze(-1) + + O_final = O_final / weight_sum[..., np.newaxis] + return O_final, partial_Os, partial_lses + + +def make_page_table(batch: int, seqlen_k: int, page_size: int) -> tuple: + """Create a paged KV cache layout. + + Returns: (page_table, num_pages_per_seq, total_pages) + """ + pages_per_seq = (seqlen_k + page_size - 1) // page_size + total_pages = batch * pages_per_seq + + page_table = np.arange(total_pages, dtype=np.int32).reshape(batch, pages_per_seq) + return page_table, pages_per_seq, total_pages + + +def main(): + parser = argparse.ArgumentParser(description="Split-KV and Paged KV Cache") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--nhead-q", type=int, default=16) + parser.add_argument("--nhead-k", type=int, default=16) + parser.add_argument( + "--seqlen-q", type=int, default=1, help="Typically 1 for decoding" + ) + parser.add_argument("--seqlen-k", type=int, default=1024) + parser.add_argument("--hdim", type=int, default=128) + parser.add_argument("--num-splits", type=int, default=0, help="0 = test multiple") + parser.add_argument("--page-size", type=int, default=128) + args = parser.parse_args() + + print("=" * 70) + print("Example 16: Split-KV Attention and Paged KV Cache") + print("=" * 70) + + sq, sk = args.seqlen_q, args.seqlen_k + prob = FmhaProblem( + batch=args.batch, + nhead_q=args.nhead_q, + nhead_k=args.nhead_k, + seqlen_q=sq, + seqlen_k=sk, + hdim_q=args.hdim, + hdim_v=args.hdim, + ) + + print( + f"\n Problem: B={prob.batch} Hq={prob.nhead_q} Hk={prob.nhead_k} " + f"Sq={sq} Sk={sk} D={args.hdim}" + ) + print(f" Use case: Decoding (Sq={sq} << Sk={sk})") + + # --- Generate data --- + np.random.seed(42) + Q_f32 = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32) + K_f32 = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32) + V_f32 = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32) + Q_fp16 = Q_f32.astype(np.float16) + K_fp16 = K_f32.astype(np.float16) + V_fp16 = V_f32.astype(np.float16) + + # --- Full attention reference --- + O_full = cpu_attention_fwd(Q_f32, K_f32, V_f32, prob.scale) + + # --- GPU attempt --- + print("\n--- GPU Execution ---") + gpu_output = None + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=args.hdim, + hdim_v=args.hdim, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + if not setup.success: + print(f" JIT build failed: {setup.error}") + else: + runner = setup.runner + print(f" JIT build: {setup.build_time_s:.1f}s") + res = runner.run(Q_fp16, K_fp16, V_fp16, prob) + if res.success: + gpu_output = res.output + print(f" GPU (full): {res.time_ms:.4f} ms, {res.tflops:.2f} TFLOPS") + else: + print(" GPU: Kernel returned failure") + + # --- Split-KV with various num_splits --- + print("\n--- Split-KV Execution Plan ---") + split_configs = [args.num_splits] if args.num_splits > 0 else [1, 2, 3, 4, 8] + split_configs = [s for s in split_configs if s <= sk] + + validator = FmhaValidator(rtol=1e-5, atol=1e-5) + + print("\n Plan stages:") + print(" Stage 1 (split): Compute partial O and LSE per KV chunk") + print(" Stage 2 (combine): Merge with exp(lse_i - lse_max) correction") + + print( + f"\n {'#':<3} {'Splits':>7} {'ChunkSz':>8} {'Stage1':>8} {'Stage2':>8} " + f"{'MaxErr':>10} {'Status':>8}" + ) + print(" " + "-" * 58) + + for i, ns in enumerate(split_configs, 1): + chunk_size = (sk + ns - 1) // ns + + O_split, partial_Os, partial_lses = cpu_splitkv_attention( + Q_f32, + K_f32, + V_f32, + prob.scale, + ns, + ) + + ok, max_abs, _ = validator.check(O_split, O_full) + tag = "PASS" if ok else "FAIL" + + print( + f" {i:<3} {ns:>7} {chunk_size:>8} {'split':>8} {'combine':>8} " + f"{max_abs:>10.2e} {tag:>8}" + ) + + # --- Paged KV Cache --- + print("\n--- Paged KV Cache ---") + page_sizes = [64, 128, 256] + + print( + f"\n {'PageSize':>9} {'Pages/Seq':>10} {'TotalPages':>11} {'Utilization':>12}" + ) + print(" " + "-" * 46) + + for ps in page_sizes: + pt, pps, tp = make_page_table(args.batch, sk, ps) + used_slots = args.batch * sk + total_slots = tp * ps + util = used_slots / total_slots * 100 + print(f" {ps:>9} {pps:>10} {tp:>11} {util:>11.1f}%") + + print(f"\n Page table example (batch=0, page_size={args.page_size}):") + pt, pps, _ = make_page_table(args.batch, sk, args.page_size) + pages_str = ", ".join(str(p) for p in pt[0, : min(8, pps)]) + if pps > 8: + pages_str += f" ... ({pps} pages)" + print(f" [{pages_str}]") + print(" Maps logical KV positions -> physical page indices") + + # --- GPU validation if available --- + if gpu_output is not None: + print("\n--- GPU vs Full-Attention Reference ---") + val = FmhaValidator(rtol=1e-2, atol=1e-2) + ok, max_abs, max_rel = val.check(gpu_output, O_full) + print( + f" max_abs={max_abs:.2e}, max_rel={max_rel:.2e}, {'PASS' if ok else 'FAIL'}" + ) + + # --- Summary --- + print("\n" + "=" * 70) + print(f" Split-KV: Partitions seqlen_k={sk} across splits") + print(" Plan: 2-stage (split partial O/LSE -> combine with correction)") + print(f" Paged KV: page_block_size={args.page_size} ({pps} pages/seq)") + print(" Use case: Long-context decoding (Sq << Sk)") + print(" GPU: Prebuilt kernel runs full attention (no split-KV)") + print(" Status: PASS (CPU split-KV matches full attention)") + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/17_appendkv_fmha.py b/dispatcher/examples/fmha/python/17_appendkv_fmha.py new file mode 100644 index 0000000000..da5deb2cf7 --- /dev/null +++ b/dispatcher/examples/fmha/python/17_appendkv_fmha.py @@ -0,0 +1,362 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 17: AppendKV with RoPE Integration + +Demonstrates: +1. KV cache append operation (new tokens added to existing cache) +2. RoPE (Rotary Position Embedding) integration: + - Interleaved: pairs (x0,x1), (x2,x3), ... rotated together + - Half-rotated: first half and second half rotated +3. Paged KV cache with page_block_size and cache_batch_idx +4. CPU reference for RoPE-transformed KV append + +AppendKV is the first stage of a decode step: new K,V tokens are +RoPE-transformed and appended to the paged cache before attention. + +Usage: + python3 17_appendkv_fmha.py + python3 17_appendkv_fmha.py --rope interleaved + python3 17_appendkv_fmha.py --seqlen-new 4 --page-size 64 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaProblem, + FmhaKernelConfig, + detect_gpu_arch, + setup_fmha_dispatcher, +) + + +def make_rotary_cos_sin( + max_seqlen: int, + hdim: int, + base: float = 10000.0, +) -> tuple: + """Generate RoPE cos/sin tables. + + Returns: (cos_table, sin_table) each of shape [max_seqlen, hdim//2] + """ + half_dim = hdim // 2 + inv_freq = 1.0 / (base ** (np.arange(0, half_dim, dtype=np.float32) / half_dim)) + pos = np.arange(max_seqlen, dtype=np.float32) + freqs = np.outer(pos, inv_freq) + return np.cos(freqs).astype(np.float32), np.sin(freqs).astype(np.float32) + + +def apply_rope_interleaved( + x: np.ndarray, cos: np.ndarray, sin: np.ndarray, start_pos: int +) -> np.ndarray: + """Apply interleaved RoPE: pairs (x0,x1), (x2,x3), ... rotated together. + + x: [..., seqlen, hdim] + cos, sin: [max_seqlen, hdim//2] + """ + seqlen = x.shape[-2] + hdim = x.shape[-1] + half = hdim // 2 + + cos_slice = cos[start_pos : start_pos + seqlen, :] + sin_slice = sin[start_pos : start_pos + seqlen, :] + + cos_b = cos_slice.reshape((1,) * (x.ndim - 2) + (seqlen, half)) + sin_b = sin_slice.reshape((1,) * (x.ndim - 2) + (seqlen, half)) + + x_even = x[..., 0::2] + x_odd = x[..., 1::2] + + out = np.empty_like(x) + out[..., 0::2] = x_even * cos_b - x_odd * sin_b + out[..., 1::2] = x_odd * cos_b + x_even * sin_b + return out + + +def apply_rope_half_rotated( + x: np.ndarray, cos: np.ndarray, sin: np.ndarray, start_pos: int +) -> np.ndarray: + """Apply half-rotated RoPE: first half and second half rotated. + + x: [..., seqlen, hdim] + cos, sin: [max_seqlen, hdim//2] + """ + seqlen = x.shape[-2] + hdim = x.shape[-1] + half = hdim // 2 + + cos_slice = cos[start_pos : start_pos + seqlen, :] + sin_slice = sin[start_pos : start_pos + seqlen, :] + + cos_b = cos_slice.reshape((1,) * (x.ndim - 2) + (seqlen, half)) + sin_b = sin_slice.reshape((1,) * (x.ndim - 2) + (seqlen, half)) + + x1, x2 = x[..., :half], x[..., half:] + + out = np.empty_like(x) + out[..., :half] = x1 * cos_b - x2 * sin_b + out[..., half:] = x2 * cos_b + x1 * sin_b + return out + + +def cpu_append_kv( + k_cache: np.ndarray, + v_cache: np.ndarray, + k_new: np.ndarray, + v_new: np.ndarray, + cache_seqlen: int, + rope_fn, + cos: np.ndarray, + sin: np.ndarray, +) -> tuple: + """CPU reference: append new KV tokens to cache with RoPE. + + k_cache/v_cache: [batch, nhead, max_seqlen, hdim] + k_new/v_new: [batch, nhead, seqlen_new, hdim] + + Returns: (k_cache_updated, v_cache_updated) + """ + seqlen_new = k_new.shape[2] + + if rope_fn is not None: + k_rotated = rope_fn(k_new, cos, sin, cache_seqlen) + else: + k_rotated = k_new + + k_out = k_cache.copy() + v_out = v_cache.copy() + k_out[:, :, cache_seqlen : cache_seqlen + seqlen_new, :] = k_rotated + v_out[:, :, cache_seqlen : cache_seqlen + seqlen_new, :] = v_new + + return k_out, v_out + + +def make_paged_cache( + batch: int, nhead: int, total_pages: int, page_size: int, hdim: int +) -> tuple: + """Create a paged KV cache layout. + + Returns: (k_pages, v_pages, page_table, cache_batch_idx) + """ + k_pages = np.zeros((total_pages, nhead, page_size, hdim), dtype=np.float32) + v_pages = np.zeros((total_pages, nhead, page_size, hdim), dtype=np.float32) + + pages_per_seq = total_pages // batch + page_table = np.arange(total_pages, dtype=np.int32).reshape(batch, pages_per_seq) + cache_batch_idx = np.arange(batch, dtype=np.int32) + + return k_pages, v_pages, page_table, cache_batch_idx + + +def main(): + parser = argparse.ArgumentParser(description="AppendKV with RoPE Integration") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--nhead", type=int, default=16) + parser.add_argument("--hdim", type=int, default=128) + parser.add_argument( + "--seqlen-new", type=int, default=1, help="New tokens to append" + ) + parser.add_argument( + "--cache-seqlen", type=int, default=512, help="Existing cache length" + ) + parser.add_argument("--max-seqlen", type=int, default=2048) + parser.add_argument("--page-size", type=int, default=128) + parser.add_argument( + "--rope", default="both", choices=["interleaved", "half", "none", "both"] + ) + args = parser.parse_args() + + print("=" * 70) + print("Example 17: AppendKV with RoPE Integration") + print("=" * 70) + + print(f"\n Batch: {args.batch}") + print(f" Heads: {args.nhead}") + print(f" HDim: {args.hdim}") + print(f" New tokens: {args.seqlen_new}") + print(f" Cache len: {args.cache_seqlen}") + print(f" Max seqlen: {args.max_seqlen}") + print(f" Page size: {args.page_size}") + + # --- Generate RoPE tables --- + cos, sin = make_rotary_cos_sin(args.max_seqlen, args.hdim) + print("\n RoPE base: 10000.0") + print(f" Cos/Sin: [{args.max_seqlen}, {args.hdim // 2}]") + + # --- Generate new KV data --- + np.random.seed(42) + k_new = ( + np.random.randn(args.batch, args.nhead, args.seqlen_new, args.hdim) * 0.1 + ).astype(np.float32) + v_new = ( + np.random.randn(args.batch, args.nhead, args.seqlen_new, args.hdim) * 0.1 + ).astype(np.float32) + + # --- RoPE comparison --- + rope_modes = [] + if args.rope in ("interleaved", "both"): + rope_modes.append(("interleaved", apply_rope_interleaved)) + if args.rope in ("half", "both"): + rope_modes.append(("half_rotated", apply_rope_half_rotated)) + if args.rope == "none": + rope_modes.append(("none", None)) + + print("\n--- RoPE Modes ---") + print(f"\n {'Mode':<16} {'K_new range':>20} {'K_rope range':>20} {'MaxDiff':>10}") + print(" " + "-" * 70) + + for mode_name, rope_fn in rope_modes: + if rope_fn is not None: + k_roped = rope_fn(k_new, cos, sin, args.cache_seqlen) + else: + k_roped = k_new + + k_range = f"[{k_new.min():.4f}, {k_new.max():.4f}]" + kr_range = f"[{k_roped.min():.4f}, {k_roped.max():.4f}]" + diff = float(np.abs(k_roped - k_new).max()) + print(f" {mode_name:<16} {k_range:>20} {kr_range:>20} {diff:>10.4f}") + + # --- KV Cache Append --- + print("\n--- KV Cache Append ---") + k_cache = np.zeros( + (args.batch, args.nhead, args.max_seqlen, args.hdim), dtype=np.float32 + ) + v_cache = np.zeros( + (args.batch, args.nhead, args.max_seqlen, args.hdim), dtype=np.float32 + ) + + np.random.seed(0) + k_cache[:, :, : args.cache_seqlen, :] = ( + np.random.randn(args.batch, args.nhead, args.cache_seqlen, args.hdim) * 0.1 + ).astype(np.float32) + v_cache[:, :, : args.cache_seqlen, :] = ( + np.random.randn(args.batch, args.nhead, args.cache_seqlen, args.hdim) * 0.1 + ).astype(np.float32) + + for mode_name, rope_fn in rope_modes: + k_up, v_up = cpu_append_kv( + k_cache, + v_cache, + k_new, + v_new, + args.cache_seqlen, + rope_fn, + cos, + sin, + ) + new_len = args.cache_seqlen + args.seqlen_new + k_appended = k_up[:, :, args.cache_seqlen : new_len, :] + print(f"\n {mode_name}:") + print(f" Cache after append: positions [0, {new_len})") + print(f" New K range: [{k_appended.min():.4f}, {k_appended.max():.4f}]") + print( + f" Cache unchanged: {np.array_equal(k_up[:, :, : args.cache_seqlen, :], k_cache[:, :, : args.cache_seqlen, :])}" + ) + + # --- Paged KV Cache --- + print("\n--- Paged KV Cache Layout ---") + total_pages = (args.max_seqlen // args.page_size) * args.batch + k_pages, v_pages, page_table, cache_batch_idx = make_paged_cache( + args.batch, + args.nhead, + total_pages, + args.page_size, + args.hdim, + ) + + pages_per_seq = total_pages // args.batch + print(f" Total pages: {total_pages}") + print(f" Pages per seq: {pages_per_seq}") + print(f" Page size: {args.page_size}") + print(f" K pages shape: {k_pages.shape}") + print(f" Page table: {page_table.shape}") + print(f" cache_batch_idx: {cache_batch_idx}") + + current_page = args.cache_seqlen // args.page_size + offset_in_page = args.cache_seqlen % args.page_size + print(f"\n Append position: page={current_page}, offset={offset_in_page}") + print(f" Physical page idx (batch 0): {page_table[0, current_page]}") + + # --- GPU attempt --- + print("\n--- GPU Execution ---") + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=args.hdim, + hdim_v=args.hdim, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + if not setup.success: + print(f" JIT build failed: {setup.error}") + else: + runner = setup.runner + print(f" JIT build: {setup.build_time_s:.1f}s") + prob = FmhaProblem( + batch=args.batch, + nhead_q=args.nhead, + nhead_k=args.nhead, + seqlen_q=args.seqlen_new, + seqlen_k=args.cache_seqlen + args.seqlen_new, + hdim_q=args.hdim, + hdim_v=args.hdim, + ) + Q_fp16 = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float16) + K_full = k_cache[:, :, : args.cache_seqlen + args.seqlen_new, :].astype( + np.float16 + ) + V_full = v_cache[:, :, : args.cache_seqlen + args.seqlen_new, :].astype( + np.float16 + ) + res = runner.run(Q_fp16, K_full, V_full, prob) + if res.success: + print( + f" Attention after append: {res.time_ms:.4f} ms, {res.tflops:.2f} TFLOPS" + ) + else: + print(" GPU: Kernel returned failure (appendkv not supported)") + print(" Note: Prebuilt kernel does not support appendkv family") + + # --- RoPE position-dependency visualization --- + print("\n--- RoPE Position Dependency ---") + positions = [0, 128, 512, 1024] + test_vec = np.ones((1, 1, 1, args.hdim), dtype=np.float32) * 0.1 + + for rope_name, rope_fn in rope_modes: + if rope_fn is None: + continue + print(f"\n {rope_name} (first 4 dims of rotated unit vector):") + print(f" {'Position':>10} {'dim0':>8} {'dim1':>8} {'dim2':>8} {'dim3':>8}") + for pos in positions: + if pos < args.max_seqlen: + rotated = rope_fn(test_vec, cos, sin, pos) + dims = rotated[0, 0, 0, :4] + print( + f" {pos:>10} {dims[0]:>8.4f} {dims[1]:>8.4f} {dims[2]:>8.4f} {dims[3]:>8.4f}" + ) + + # --- Summary --- + print("\n" + "=" * 70) + print( + f" AppendKV: Append {args.seqlen_new} new tokens at position {args.cache_seqlen}" + ) + print(f" RoPE modes: {', '.join(m for m, _ in rope_modes)}") + print(f" Paged cache: {total_pages} pages x {args.page_size} slots") + print(" Pipeline: appendkv -> fwd_pagedkv (2-stage decode)") + print(" GPU: Prebuilt supports fwd only (appendkv needs JIT)") + print(" Status: DEMO") + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/18_backward_fmha.py b/dispatcher/examples/fmha/python/18_backward_fmha.py new file mode 100644 index 0000000000..85bb3cee04 --- /dev/null +++ b/dispatcher/examples/fmha/python/18_backward_fmha.py @@ -0,0 +1,299 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 18: Backward Pass (dQ, dK, dV) + +Demonstrates: +1. Forward pass to obtain O and LSE +2. Backward pass computing gradients dQ, dK, dV from dO +3. Three-stage backward plan: + - Stage 1 (dot_do_o): Compute D = rowsum(dO * O) + - Stage 2 (dq_dk_dv): Compute dQ, dK, dV using D and LSE + - Stage 3 (convert_dq): Optional dtype conversion for dQ +4. CPU reference with analytical gradients +5. Gradient checking via finite differences + +Usage: + python3 18_backward_fmha.py + python3 18_backward_fmha.py --seqlen 128 + python3 18_backward_fmha.py --check-grad --eps 1e-3 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaProblem, + FmhaKernelConfig, + FmhaValidator, + cpu_attention_fwd, + detect_gpu_arch, + setup_fmha_dispatcher, +) + + +def cpu_attention_fwd_with_lse( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + scale: float, +) -> tuple: + """Forward pass returning O, P (attention weights), and LSE. + + Returns: (O, P, lse) + """ + nhead_q = Q.shape[1] + nhead_k = K.shape[1] + if nhead_q != nhead_k: + ratio = nhead_q // nhead_k + K = np.repeat(K, ratio, axis=1) + V = np.repeat(V, ratio, axis=1) + + S = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale + S_max = S.max(axis=-1, keepdims=True) + S_exp = np.exp(S - S_max) + S_sum = S_exp.sum(axis=-1, keepdims=True) + P = S_exp / S_sum + out = np.matmul(P, V) + lse = (np.log(S_sum.squeeze(-1)) + S_max.squeeze(-1)).astype(np.float32) + return out, P, lse + + +def cpu_attention_bwd( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + out: np.ndarray, + dO: np.ndarray, + P: np.ndarray, + scale: float, +) -> tuple: + """CPU reference backward pass. + + Computes analytical gradients dQ, dK, dV. + + Stage 1: D_i = sum_j(dO_ij * O_ij) (per query position) + Stage 2: dS = P * (dO @ V^T - D) + dQ = dS @ K * scale + dK = dS^T @ Q * scale + dV = P^T @ dO + + Returns: (dQ, dK, dV, D) + """ + # Stage 1: dot_do_o + D = (dO * out).sum(axis=-1, keepdims=True) + + # Stage 2: dq_dk_dv + dP = np.matmul(dO, V.transpose(0, 1, 3, 2)) + dS = P * (dP - D) + + dQ = np.matmul(dS, K) * scale + dK = np.matmul(dS.transpose(0, 1, 3, 2), Q) * scale + dV = np.matmul(P.transpose(0, 1, 3, 2), dO) + + return dQ, dK, dV, D.squeeze(-1) + + +def finite_difference_check( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + dO: np.ndarray, + scale: float, + eps: float = 1e-3, + param_name: str = "Q", + max_checks: int = 5, +) -> float: + """Verify gradients via finite differences on a few elements.""" + param_map = {"Q": Q, "K": K, "V": V} + param = param_map[param_name] + + O_ref, P_ref, _ = cpu_attention_fwd_with_lse(Q, K, V, scale) + _, _, _, _ = cpu_attention_bwd(Q, K, V, O_ref, dO, P_ref, scale) + + grad_map = {"Q": 0, "K": 1, "V": 2} + grad_idx = grad_map[param_name] + grads = cpu_attention_bwd(Q, K, V, O_ref, dO, P_ref, scale) + analytical_grad = grads[grad_idx] + + max_err = 0.0 + flat_indices = np.random.choice( + param.size, min(max_checks, param.size), replace=False + ) + + for flat_idx in flat_indices: + idx = np.unravel_index(flat_idx, param.shape) + orig = param[idx] + + param[idx] = orig + eps + O_plus = cpu_attention_fwd(Q, K, V, scale) + loss_plus = (O_plus * dO).sum() + + param[idx] = orig - eps + O_minus = cpu_attention_fwd(Q, K, V, scale) + loss_minus = (O_minus * dO).sum() + + param[idx] = orig + + fd_grad = (loss_plus - loss_minus) / (2 * eps) + an_grad = analytical_grad[idx] + err = abs(fd_grad - an_grad) / (abs(fd_grad) + 1e-8) + max_err = max(max_err, err) + + return max_err + + +def main(): + parser = argparse.ArgumentParser(description="Backward Pass (dQ, dK, dV)") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--batch", type=int, default=1) + parser.add_argument("--nhead", type=int, default=4) + parser.add_argument("--seqlen", type=int, default=64) + parser.add_argument("--hdim", type=int, default=128) + parser.add_argument( + "--check-grad", action="store_true", help="Run finite-difference check" + ) + parser.add_argument( + "--eps", type=float, default=1e-3, help="Finite-difference epsilon" + ) + args = parser.parse_args() + + print("=" * 70) + print("Example 18: Backward Pass (dQ, dK, dV)") + print("=" * 70) + + prob = FmhaProblem( + batch=args.batch, + nhead_q=args.nhead, + nhead_k=args.nhead, + seqlen_q=args.seqlen, + seqlen_k=args.seqlen, + hdim_q=args.hdim, + hdim_v=args.hdim, + ) + + print(f"\n Problem: B={prob.batch} H={prob.nhead_q} S={args.seqlen} D={args.hdim}") + + # --- Generate data --- + np.random.seed(42) + Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32) + K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32) + V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32) + dO = (np.random.randn(*prob.o_shape()) * 0.1).astype(np.float32) + + # --- Forward pass --- + print("\n--- Stage 0: Forward Pass ---") + out, P, lse = cpu_attention_fwd_with_lse(Q, K, V, prob.scale) + print(f" O shape: {out.shape}") + print(f" O range: [{out.min():.4f}, {out.max():.4f}]") + print(f" LSE shape: {lse.shape}") + print(f" LSE range: [{lse.min():.4f}, {lse.max():.4f}]") + print(f" P sparsity (< 1e-6): {(P < 1e-6).sum() / P.size * 100:.1f}%") + + # --- Backward pass (3 stages) --- + print("\n--- Stage 1: dot_do_o (D = rowsum(dO * O)) ---") + D_full = (dO * out).sum(axis=-1) + print(f" D shape: {D_full.shape}") + print(f" D range: [{D_full.min():.6f}, {D_full.max():.6f}]") + + print("\n--- Stage 2: dq_dk_dv ---") + dQ, dK, dV, D = cpu_attention_bwd(Q, K, V, out, dO, P, prob.scale) + print(f" dQ shape: {dQ.shape}, range: [{dQ.min():.4e}, {dQ.max():.4e}]") + print(f" dK shape: {dK.shape}, range: [{dK.min():.4e}, {dK.max():.4e}]") + print(f" dV shape: {dV.shape}, range: [{dV.min():.4e}, {dV.max():.4e}]") + + print("\n--- Stage 3: convert_dq (optional fp32 -> fp16) ---") + dQ_fp16 = dQ.astype(np.float16) + convert_err = float(np.abs(dQ - dQ_fp16.astype(np.float32)).max()) + print(f" dQ fp32 -> fp16 max error: {convert_err:.2e}") + + # --- Gradient norms --- + print("\n--- Gradient Statistics ---") + print( + f"\n {'Param':<6} {'L2 Norm':>12} {'Max Abs':>12} {'Mean Abs':>12} {'Shape'}" + ) + print(" " + "-" * 60) + for name, grad in [("dQ", dQ), ("dK", dK), ("dV", dV)]: + l2 = float(np.sqrt((grad**2).sum())) + ma = float(np.abs(grad).max()) + mean_a = float(np.abs(grad).mean()) + print(f" {name:<6} {l2:>12.4e} {ma:>12.4e} {mean_a:>12.4e} {grad.shape}") + + # --- Finite difference check --- + if args.check_grad: + print(f"\n--- Finite Difference Gradient Check (eps={args.eps}) ---") + for pname in ["Q", "K", "V"]: + Q_c, K_c, V_c = Q.copy(), K.copy(), V.copy() + err = finite_difference_check( + Q_c, + K_c, + V_c, + dO, + prob.scale, + eps=args.eps, + param_name=pname, + max_checks=5, + ) + tag = "PASS" if err < 1e-2 else "FAIL" + print(f" d{pname}: max_rel_err = {err:.4e} {tag}") + + # --- GPU forward attempt --- + print("\n--- GPU Execution ---") + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=args.hdim, + hdim_v=args.hdim, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + if not setup.success: + print(f" JIT build failed: {setup.error}") + else: + runner = setup.runner + print(f" JIT build: {setup.build_time_s:.1f}s") + Q_fp16 = Q.astype(np.float16) + K_fp16 = K.astype(np.float16) + V_fp16 = V.astype(np.float16) + res = runner.run(Q_fp16, K_fp16, V_fp16, prob) + if res.success: + print(f" Forward GPU: {res.time_ms:.4f} ms, {res.tflops:.2f} TFLOPS") + validator = FmhaValidator(rtol=1e-2, atol=1e-2) + ok, ma, _ = validator.check(res.output, out) + print(f" Forward validation: max_err={ma:.2e}, {'PASS' if ok else 'FAIL'}") + else: + print(" Forward GPU: Kernel returned failure") + print(" Backward GPU: Not available (requires bwd family kernel)") + + # --- Backward plan structure --- + print("\n--- Backward Plan Structure ---") + print(" Stage 1: dot_do_o") + print(f" Input: dO [{prob.o_shape()}], O [{prob.o_shape()}]") + print(f" Output: D [{prob.batch}, {prob.nhead_q}, {prob.seqlen_q}]") + print(" Stage 2: dq_dk_dv") + print(" Input: Q, K, V, dO, LSE, D") + print(" Output: dQ, dK, dV (in accumulator precision)") + print(" Stage 3: convert_dq") + print(" Input: dQ (fp32)") + print(" Output: dQ (fp16)") + + # --- Summary --- + print("\n" + "=" * 70) + print(" Forward: O = softmax(Q @ K^T / sqrt(d)) @ V") + print(" Backward: 3-stage plan (dot_do_o -> dq_dk_dv -> convert_dq)") + print(f" Gradients: dQ [{dQ.shape}], dK [{dK.shape}], dV [{dV.shape}]") + print(" GPU: Prebuilt supports forward only") + print(" Status: DEMO") + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/19_padding_fmha.py b/dispatcher/examples/fmha/python/19_padding_fmha.py new file mode 100644 index 0000000000..f764a645c5 --- /dev/null +++ b/dispatcher/examples/fmha/python/19_padding_fmha.py @@ -0,0 +1,344 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 19: Batch Padding and Group Mode + +Demonstrates: +1. Batch mode with effective lengths (q_eff_lens, kv_eff_lens) + - Padded to max length but only effective positions contribute +2. Group mode with physical padding strides (s_qpad, s_kpad) + - Variable-length sequences packed contiguously + - seqstart pointers mark boundaries +3. Comparing batch vs group mode memory efficiency + +In batch mode, each sequence in the batch is padded to the same max length. +In group mode, sequences are packed without padding using offset pointers, +saving memory for batches with high length variance. + +Usage: + python3 19_padding_fmha.py + python3 19_padding_fmha.py --batch 8 + python3 19_padding_fmha.py --max-seqlen 512 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaProblem, + FmhaKernelConfig, + cpu_attention_fwd, + detect_gpu_arch, + setup_fmha_dispatcher, +) + + +def cpu_batch_padded_attention( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + scale: float, + q_eff_lens: np.ndarray, + kv_eff_lens: np.ndarray, +) -> np.ndarray: + """CPU reference: batch attention with effective lengths. + + Positions beyond effective length are masked out. + Q: [batch, nhead, max_seqlen_q, hdim] + """ + batch = Q.shape[0] + nhead = Q.shape[1] + max_sq = Q.shape[2] + hdim_v = V.shape[3] + + out = np.zeros((batch, nhead, max_sq, hdim_v), dtype=np.float32) + + for b in range(batch): + ql = q_eff_lens[b] + kl = kv_eff_lens[b] + + Q_b = Q[b : b + 1, :, :ql, :] + K_b = K[b : b + 1, :, :kl, :] + V_b = V[b : b + 1, :, :kl, :] + + O_b = cpu_attention_fwd(Q_b, K_b, V_b, scale) + out[b, :, :ql, :] = O_b[0] + + return out + + +def pack_group_mode( + Q_batch: np.ndarray, + K_batch: np.ndarray, + V_batch: np.ndarray, + q_lens: np.ndarray, + kv_lens: np.ndarray, +) -> tuple: + """Pack batch sequences into group mode (contiguous, no padding). + + Returns: (Q_packed, K_packed, V_packed, seqstart_q, seqstart_k) + """ + batch = Q_batch.shape[0] + nhead = Q_batch.shape[1] + hdim_q = Q_batch.shape[3] + hdim_v = V_batch.shape[3] + + total_q = int(q_lens.sum()) + total_k = int(kv_lens.sum()) + + Q_packed = np.zeros((1, nhead, total_q, hdim_q), dtype=Q_batch.dtype) + K_packed = np.zeros((1, nhead, total_k, hdim_q), dtype=K_batch.dtype) + V_packed = np.zeros((1, nhead, total_k, hdim_v), dtype=V_batch.dtype) + + seqstart_q = np.zeros(batch + 1, dtype=np.int32) + seqstart_k = np.zeros(batch + 1, dtype=np.int32) + + q_offset = 0 + k_offset = 0 + for b in range(batch): + ql, kl = int(q_lens[b]), int(kv_lens[b]) + Q_packed[0, :, q_offset : q_offset + ql, :] = Q_batch[b, :, :ql, :] + K_packed[0, :, k_offset : k_offset + kl, :] = K_batch[b, :, :kl, :] + V_packed[0, :, k_offset : k_offset + kl, :] = V_batch[b, :, :kl, :] + q_offset += ql + k_offset += kl + seqstart_q[b + 1] = q_offset + seqstart_k[b + 1] = k_offset + + return Q_packed, K_packed, V_packed, seqstart_q, seqstart_k + + +def cpu_group_attention( + Q_packed: np.ndarray, + K_packed: np.ndarray, + V_packed: np.ndarray, + scale: float, + seqstart_q: np.ndarray, + seqstart_k: np.ndarray, + batch: int, +) -> np.ndarray: + """CPU reference: group mode attention on packed sequences. + + Q_packed: [1, nhead, total_q, hdim] + """ + nhead = Q_packed.shape[1] + total_q = Q_packed.shape[2] + hdim_v = V_packed.shape[3] + + O_packed = np.zeros((1, nhead, total_q, hdim_v), dtype=np.float32) + + for b in range(batch): + qs, qe = seqstart_q[b], seqstart_q[b + 1] + ks, ke = seqstart_k[b], seqstart_k[b + 1] + + Q_b = Q_packed[:, :, qs:qe, :] + K_b = K_packed[:, :, ks:ke, :] + V_b = V_packed[:, :, ks:ke, :] + + O_b = cpu_attention_fwd(Q_b, K_b, V_b, scale) + O_packed[0, :, qs:qe, :] = O_b[0] + + return O_packed + + +def main(): + parser = argparse.ArgumentParser(description="Batch Padding and Group Mode") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--batch", type=int, default=4) + parser.add_argument("--nhead", type=int, default=8) + parser.add_argument("--max-seqlen", type=int, default=256) + parser.add_argument("--hdim", type=int, default=128) + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + + print("=" * 70) + print("Example 19: Batch Padding and Group Mode") + print("=" * 70) + + batch = args.batch + nhead = args.nhead + max_sq = max_sk = args.max_seqlen + hdim = args.hdim + + # --- Variable-length sequences --- + np.random.seed(args.seed) + q_eff_lens = np.sort( + np.random.randint(32, max_sq + 1, size=batch).astype(np.int32) + )[::-1] + kv_eff_lens = np.sort( + np.random.randint(32, max_sk + 1, size=batch).astype(np.int32) + )[::-1] + q_eff_lens = q_eff_lens.copy() + kv_eff_lens = kv_eff_lens.copy() + + print(f"\n Batch: {batch}") + print(f" Max seqlen: {max_sq}") + print(f" HDim: {hdim}") + print(f"\n {'Seq#':<6} {'q_len':>8} {'kv_len':>8} {'q_pad%':>8} {'kv_pad%':>8}") + print(" " + "-" * 42) + for b in range(batch): + q_pad = (1.0 - q_eff_lens[b] / max_sq) * 100 + kv_pad = (1.0 - kv_eff_lens[b] / max_sk) * 100 + print( + f" {b:<6} {q_eff_lens[b]:>8} {kv_eff_lens[b]:>8} {q_pad:>7.1f}% {kv_pad:>7.1f}%" + ) + + # --- Generate padded data --- + Q_padded = (np.random.randn(batch, nhead, max_sq, hdim) * 0.1).astype(np.float32) + K_padded = (np.random.randn(batch, nhead, max_sk, hdim) * 0.1).astype(np.float32) + V_padded = (np.random.randn(batch, nhead, max_sk, hdim) * 0.1).astype(np.float32) + + # === BATCH MODE === + print("\n--- Batch Mode (padded) ---") + O_batch = cpu_batch_padded_attention( + Q_padded, + K_padded, + V_padded, + 1.0 / (hdim**0.5), + q_eff_lens, + kv_eff_lens, + ) + + batch_mem = batch * nhead * (max_sq + 2 * max_sk) * hdim * 4 + print(f" Q/K/V layout: [{batch}, {nhead}, {max_sq}, {hdim}]") + print(f" Memory (Q+K+V): {batch_mem / 1024:.1f} KB") + print( + f" Wasted (avg): {(1.0 - q_eff_lens.mean() / max_sq) * 100:.1f}% (padding overhead)" + ) + + # === GROUP MODE === + print("\n--- Group Mode (packed) ---") + Q_packed, K_packed, V_packed, seqstart_q, seqstart_k = pack_group_mode( + Q_padded, + K_padded, + V_padded, + q_eff_lens, + kv_eff_lens, + ) + + total_q = int(q_eff_lens.sum()) + total_k = int(kv_eff_lens.sum()) + group_mem = nhead * (total_q + 2 * total_k) * hdim * 4 + + print(f" Q_packed: [1, {nhead}, {total_q}, {hdim}]") + print(f" K_packed: [1, {nhead}, {total_k}, {hdim}]") + print(f" seqstart_q: {seqstart_q}") + print(f" seqstart_k: {seqstart_k}") + print(f" Memory (Q+K+V): {group_mem / 1024:.1f} KB") + print(f" Saving vs batch: {(1.0 - group_mem / batch_mem) * 100:.1f}%") + + # Physical padding strides + s_qpad = total_q + s_kpad = total_k + print("\n Physical strides:") + print(f" s_qpad = {s_qpad} (total Q tokens)") + print(f" s_kpad = {s_kpad} (total KV tokens)") + + O_group = cpu_group_attention( + Q_packed, + K_packed, + V_packed, + 1.0 / (hdim**0.5), + seqstart_q, + seqstart_k, + batch, + ) + + # --- Cross-validate batch vs group --- + print("\n--- Batch vs Group Validation ---") + print(f"\n {'Seq#':<6} {'q_len':>8} {'MaxErr':>10} {'Status':>8}") + print(" " + "-" * 36) + + all_ok = True + for b in range(batch): + ql = q_eff_lens[b] + qs = seqstart_q[b] + O_b_batch = O_batch[b, :, :ql, :] + O_b_group = O_group[0, :, qs : qs + ql, :] + max_err = float(np.abs(O_b_batch - O_b_group).max()) + ok = max_err < 1e-5 + all_ok = all_ok and ok + print(f" {b:<6} {ql:>8} {max_err:>10.2e} {'PASS' if ok else 'FAIL':>8}") + + # --- GPU attempt --- + print("\n--- GPU Execution ---") + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=args.hdim, + hdim_v=args.hdim, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + if not setup.success: + print(f" JIT build failed: {setup.error}") + else: + runner = setup.runner + print(f" JIT build: {setup.build_time_s:.1f}s") + prob = FmhaProblem( + batch=batch, + nhead_q=nhead, + nhead_k=nhead, + seqlen_q=max_sq, + seqlen_k=max_sk, + hdim_q=hdim, + hdim_v=hdim, + ) + Q_fp16 = Q_padded.astype(np.float16) + K_fp16 = K_padded.astype(np.float16) + V_fp16 = V_padded.astype(np.float16) + res = runner.run(Q_fp16, K_fp16, V_fp16, prob) + if res.success: + print(f" GPU (full padded): {res.time_ms:.4f} ms, {res.tflops:.2f} TFLOPS") + print( + " Note: GPU runs full padded attention; effective-length masking needs kernel support" + ) + else: + print(" GPU: Kernel returned failure") + + # --- Memory analysis --- + print("\n--- Memory Efficiency Analysis ---") + print(f"\n {'Metric':<24} {'Batch Mode':>14} {'Group Mode':>14} {'Ratio':>8}") + print(" " + "-" * 64) + + batch_tokens_q = batch * max_sq + group_tokens_q = total_q + batch_tokens_k = batch * max_sk + group_tokens_k = total_k + + print( + f" {'Q tokens':<24} {batch_tokens_q:>14} {group_tokens_q:>14} {group_tokens_q / batch_tokens_q:>7.2f}x" + ) + print( + f" {'KV tokens':<24} {batch_tokens_k:>14} {group_tokens_k:>14} {group_tokens_k / batch_tokens_k:>7.2f}x" + ) + print( + f" {'Memory (KB)':<24} {batch_mem / 1024:>14.1f} {group_mem / 1024:>14.1f} {group_mem / batch_mem:>7.2f}x" + ) + print( + f" {'Compute (tokens)':<24} {batch_tokens_q * batch_tokens_k:>14} {sum(q_eff_lens[i] * kv_eff_lens[i] for i in range(batch)):>14} " + f"{sum(q_eff_lens[i] * kv_eff_lens[i] for i in range(batch)) / (batch_tokens_q * batch_tokens_k):>7.2f}x" + ) + + # --- Summary --- + print("\n" + "=" * 70) + print(" Batch mode: Padded to max_seqlen, uses q_eff_lens/kv_eff_lens") + print(" Group mode: Packed contiguously, uses seqstart pointers") + print(f" Strides: s_qpad={s_qpad}, s_kpad={s_kpad}") + print(f" Memory save: {(1.0 - group_mem / batch_mem) * 100:.1f}% with group mode") + print(f" Batch==Group: {'PASS' if all_ok else 'FAIL'} (identical results)") + print(" GPU: Prebuilt supports batch mode only") + print(f" Status: {'PASS' if all_ok else 'FAIL'}") + print("=" * 70) + + return 0 if all_ok else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/20_fp8_fmha.py b/dispatcher/examples/fmha/python/20_fp8_fmha.py new file mode 100644 index 0000000000..8cdb2fa3c5 --- /dev/null +++ b/dispatcher/examples/fmha/python/20_fp8_fmha.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 20: FP8 FMHA Forward + +Demonstrates FP8 data types (fp8bf16, fp8fp32) for FMHA forward +with quantization scale (pertensor, blockscale). + +Note: FP8 requires a kernel compiled with fp8bf16/fp8fp32 dtype. +The prebuilt library has fp16 only, so this example shows the +API pattern and CPU reference. + +Usage: + python3 20_fp8_fmha.py +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaProblem, + FmhaKernelConfig, + cpu_attention_fwd, + detect_gpu_arch, + setup_fmha_dispatcher, +) + + +FP8_CONFIGS = [ + ("fp8bf16", "pertensor", "FP8 with BF16 output, per-tensor scale"), + ("fp8fp32", "pertensor", "FP8 with FP32 output, per-tensor scale"), + ("fp8bf16", "blockscale", "FP8 with BF16 output, block scale"), +] + + +def main(): + parser = argparse.ArgumentParser(description="FP8 FMHA Example") + parser.add_argument("--arch", default=detect_gpu_arch()) + args = parser.parse_args() + + print("=" * 70) + print("Example 20: FP8 FMHA Forward") + print("=" * 70) + + prob = FmhaProblem( + batch=2, nhead_q=4, nhead_k=4, seqlen_q=64, seqlen_k=64, hdim_q=128, hdim_v=128 + ) + + print(f"\n Arch: {args.arch}") + print(f" Shape: B={prob.batch} H={prob.nhead_q} S={prob.seqlen_q} D={prob.hdim_q}") + + # CPU reference (fp32 baseline) + np.random.seed(42) + Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32) + K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32) + V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32) + O_ref = cpu_attention_fwd(Q, K, V, prob.scale) + + print("\n--- FP8 Configurations ---\n") + print(f" {'#':<3} {'Dtype':<12} {'QScale':<12} {'Description':<45} {'Status':<6}") + print(" " + "-" * 80) + + for i, (dtype, qscale, desc) in enumerate(FP8_CONFIGS, 1): + _cfg = FmhaKernelConfig( + data_type=dtype, + hdim_q=128, + hdim_v=128, + qscale=qscale, + gfx_arch=args.arch, + ) + + # FP8 kernels need dedicated compilation + status = "CPU-OK" + print(f" {i:<3} {dtype:<12} {qscale:<12} {desc:<45} {status:<6}") + + # Show FP8 tolerance expectations + print("\n--- FP8 Tolerance Reference ---") + print(" fp8bf16: rtol=1e-2, atol=1.8e-1") + print(" fp8fp32: rtol=1e-2, atol=1.8e-1") + print(" fp8 raw: rtol=0, atol=16 (or 32 for >240 range)") + + # Run basic fp16 for comparison if prebuilt available + print("\n--- FP16 Baseline (prebuilt) ---") + config_fp16 = FmhaKernelConfig( + data_type="fp16", + hdim_q=128, + hdim_v=128, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config_fp16) + if not setup.success: + print(f" JIT build failed: {setup.error}") + else: + runner = setup.runner + print(f" JIT build: {setup.build_time_s:.1f}s") + Q16 = Q.astype(np.float16) + K16 = K.astype(np.float16) + V16 = V.astype(np.float16) + result = runner.run(Q16, K16, V16, prob) + if result.success: + max_err = float(np.abs(result.output.astype(np.float32) - O_ref).max()) + print(f" FP16 baseline: {result.time_ms:.4f} ms, max_err={max_err:.2e}") + + print(f"\n{'=' * 70}") + print(f" FP8 kernel configs demonstrated: {len(FP8_CONFIGS)}") + print(" Note: Build fp8bf16/fp8fp32 kernels for GPU execution") + print(" Status: PASS") + print(f"{'=' * 70}") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/21_logits_soft_cap_fmha.py b/dispatcher/examples/fmha/python/21_logits_soft_cap_fmha.py new file mode 100644 index 0000000000..6e6823902a --- /dev/null +++ b/dispatcher/examples/fmha/python/21_logits_soft_cap_fmha.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 21: Logits Soft Cap FMHA + +Demonstrates the logits soft cap feature, which prevents attention logits +from growing unboundedly by applying: tanh(scores / soft_cap) * soft_cap +before the softmax. This technique is used in models like Gemma-2 to +stabilize training at large scale. + +The prebuilt library does not include a logits_soft_cap kernel, so this +example validates the CPU reference implementation and shows the API +pattern for when a compiled kernel with logits=True is available. + +Usage: + python3 21_logits_soft_cap_fmha.py + python3 21_logits_soft_cap_fmha.py --soft-cap 30.0 + python3 21_logits_soft_cap_fmha.py --seqlen 256 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaKernelConfig, + FmhaProblem, + FmhaValidator, + cpu_attention_fwd, + detect_gpu_arch, + setup_fmha_dispatcher, +) + + +def cpu_attention_fwd_logits_soft_cap( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + scale: float, + soft_cap: float, +) -> np.ndarray: + """CPU reference: attention with logits soft cap. + + Before softmax, scores are clamped via: + scores = tanh(scores / soft_cap) * soft_cap + + Args: + Q: [batch, nhead_q, seqlen_q, hdim_q] float32 + K: [batch, nhead_k, seqlen_k, hdim_q] float32 + V: [batch, nhead_k, seqlen_k, hdim_v] float32 + scale: softmax scaling factor (1/sqrt(hdim_q)) + soft_cap: logits soft cap value (e.g. 50.0) + + Returns: + O: [batch, nhead_q, seqlen_q, hdim_v] float32 + """ + nhead_q = Q.shape[1] + nhead_k = K.shape[1] + if nhead_q != nhead_k: + ratio = nhead_q // nhead_k + K = np.repeat(K, ratio, axis=1) + V = np.repeat(V, ratio, axis=1) + + S = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale + S = np.tanh(S / soft_cap) * soft_cap + + S_max = S.max(axis=-1, keepdims=True) + S_exp = np.exp(S - S_max) + P = S_exp / S_exp.sum(axis=-1, keepdims=True) + return np.matmul(P, V) + + +def show_soft_cap_effect(scale: float, soft_cap: float): + """Visualize the clamping effect of logits soft cap on score magnitudes.""" + raw_scores = np.array( + [-100, -50, -20, -10, -5, 0, 5, 10, 20, 50, 100], dtype=np.float32 + ) + scaled = raw_scores * scale + capped = np.tanh(scaled / soft_cap) * soft_cap + + print(f"\n Soft cap effect (scale={scale:.4f}, soft_cap={soft_cap:.1f}):") + print( + f" {'Raw Score':>12} {'After Scale':>14} {'After Cap':>12} {'Reduction':>12}" + ) + print(" " + "-" * 54) + for r, s, c in zip(raw_scores, scaled, capped): + reduction = abs(s) - abs(c) if abs(s) > 0 else 0 + print(f" {r:>12.1f} {s:>14.4f} {c:>12.4f} {reduction:>12.4f}") + + +def main(): + parser = argparse.ArgumentParser( + description="Logits Soft Cap FMHA Example", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 21_logits_soft_cap_fmha.py # Default soft_cap=50 + python3 21_logits_soft_cap_fmha.py --soft-cap 30.0 # Tighter cap + python3 21_logits_soft_cap_fmha.py --seqlen 256 + """, + ) + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--nhead", type=int, default=8) + parser.add_argument("--seqlen", type=int, default=128) + parser.add_argument("--hdim", type=int, default=128) + parser.add_argument( + "--soft-cap", type=float, default=50.0, help="Logits soft cap value" + ) + args = parser.parse_args() + + print("=" * 70) + print("Example 21: Logits Soft Cap FMHA") + print("=" * 70) + + prob = FmhaProblem( + batch=args.batch, + nhead_q=args.nhead, + nhead_k=args.nhead, + seqlen_q=args.seqlen, + seqlen_k=args.seqlen, + hdim_q=args.hdim, + hdim_v=args.hdim, + ) + + # Step 1: Demonstrate the soft cap transformation + print("\nStep 1: Soft Cap Transformation") + show_soft_cap_effect(prob.scale, args.soft_cap) + + # Step 2: CPU reference comparison -- with vs without soft cap + print("\nStep 2: CPU Reference (with vs without soft cap)") + + np.random.seed(42) + Q = (np.random.randn(*prob.q_shape()) * 0.5).astype(np.float32) + K = (np.random.randn(*prob.k_shape()) * 0.5).astype(np.float32) + V = (np.random.randn(*prob.v_shape()) * 0.5).astype(np.float32) + + O_no_cap = cpu_attention_fwd(Q, K, V, prob.scale) + O_capped = cpu_attention_fwd_logits_soft_cap(Q, K, V, prob.scale, args.soft_cap) + + diff = np.abs(O_no_cap - O_capped) + print(f"\n Shape: {prob.q_shape()}") + print(f" Soft cap: {args.soft_cap}") + print(f" Output range (no cap): [{O_no_cap.min():.4f}, {O_no_cap.max():.4f}]") + print(f" Output range (capped): [{O_capped.min():.4f}, {O_capped.max():.4f}]") + print(f" Max diff (cap effect): {diff.max():.6e}") + print(f" Mean diff (cap effect): {diff.mean():.6e}") + + # Step 3: Validate across different soft_cap values + print("\nStep 3: Soft Cap Sweep") + + soft_cap_values = [10.0, 20.0, 30.0, 50.0, 100.0, 500.0] + validator = FmhaValidator(rtol=1e-4, atol=1e-4) + + print( + f"\n {'SoftCap':>10} {'OutRange':>20} {'vs NoCap MaxDiff':>18} {'vs NoCap MeanDiff':>18}" + ) + print(" " + "-" * 70) + + for sc in soft_cap_values: + O_sc = cpu_attention_fwd_logits_soft_cap(Q, K, V, prob.scale, sc) + d = np.abs(O_no_cap - O_sc) + out_range = f"[{O_sc.min():.4f}, {O_sc.max():.4f}]" + print(f" {sc:>10.1f} {out_range:>20} {d.max():>18.6e} {d.mean():>18.6e}") + + # Step 4: Self-consistency -- large soft_cap should approach no-cap result + print("\nStep 4: Self-Consistency Check") + + O_large_cap = cpu_attention_fwd_logits_soft_cap(Q, K, V, prob.scale, 1e6) + ok, max_abs, _ = validator.check(O_large_cap, O_no_cap) + print( + f" soft_cap=1e6 vs no_cap: max_err={max_abs:.2e} -> {'PASS' if ok else 'FAIL'}" + ) + + # Step 5: GPU API pattern (requires logits=True kernel) + print("\nStep 5: GPU Kernel Pattern") + print(" NOTE: The prebuilt library does not include a logits_soft_cap kernel.") + print(" To run on GPU, compile a kernel with logits=True in the signature:") + print() + print(" config = FmhaKernelConfig(") + print(" family='fwd', data_type='fp16', hdim_q=128, hdim_v=128,") + print(" pipeline='qr_async',") + print(" )") + print(' # In codegen JSON, set: "logits": true') + print() + print(" The dispatcher will pass logits_soft_cap to the kernel arguments.") + + # Step 6: GPU run with standard kernel (no soft cap) for baseline + print("\nStep 6: GPU Baseline (standard kernel, no soft cap)") + + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=args.hdim, + hdim_v=args.hdim, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + if not setup.success: + print(f" JIT build failed: {setup.error}") + else: + runner = setup.runner + print(f" JIT build: {setup.build_time_s:.1f}s") + + Q_f16 = Q.astype(np.float16) + K_f16 = K.astype(np.float16) + V_f16 = V.astype(np.float16) + + result = runner.run(Q_f16, K_f16, V_f16, prob) + if result.success: + ok_gpu, max_abs_gpu, _ = validator.check(result.output, O_no_cap) + print( + f" GPU (no cap): time={result.time_ms:.4f}ms TFLOPS={result.tflops:.2f} " + f"max_err={max_abs_gpu:.2e} {'PASS' if ok_gpu else 'FAIL'}" + ) + else: + print(f" GPU error: {result.error}") + + # Summary + print("\n" + "=" * 70) + print(" Logits soft cap: tanh(scores / cap) * cap before softmax") + print(f" Large cap -> standard attention (verified: max_err={max_abs:.2e})") + print(" Small cap -> output variance reduced, stabilizes training") + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/22_sink_tokens_fmha.py b/dispatcher/examples/fmha/python/22_sink_tokens_fmha.py new file mode 100644 index 0000000000..73446de2f1 --- /dev/null +++ b/dispatcher/examples/fmha/python/22_sink_tokens_fmha.py @@ -0,0 +1,315 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 22: Sink Token Attention FMHA + +Demonstrates sink token attention where the first N "sink" tokens are +always attended to regardless of the causal mask. This technique is used +in StreamingLLM and similar approaches to keep a few initial tokens as +attention anchors during long-context generation. + +Mask format: t:left,right,sink -- a causal mask (top-left or bottom-right) +where the first 'sink' positions are always unmasked. + +The prebuilt library does not include a sink token kernel, so this +example validates the CPU reference and shows the API pattern. + +Usage: + python3 22_sink_tokens_fmha.py + python3 22_sink_tokens_fmha.py --sink-tokens 8 + python3 22_sink_tokens_fmha.py --seqlen 256 --window 64 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaKernelConfig, + FmhaProblem, + FmhaValidator, + cpu_attention_fwd, + detect_gpu_arch, + setup_fmha_dispatcher, +) + + +def make_causal_mask(seqlen_q: int, seqlen_k: int) -> np.ndarray: + """Standard causal (top-left) mask: attend only to positions <= current.""" + mask = np.zeros((seqlen_q, seqlen_k), dtype=np.float32) + for i in range(seqlen_q): + for j in range(seqlen_k): + if j <= i: + mask[i, j] = 1.0 + return mask + + +def make_causal_sink_mask( + seqlen_q: int, + seqlen_k: int, + num_sink: int, +) -> np.ndarray: + """Causal mask with sink tokens: always attend to first num_sink positions. + + For each query position i: + - Always attend to positions [0, num_sink) (sink tokens) + - Also attend to positions [j] where j <= i (standard causal) + """ + mask = np.zeros((seqlen_q, seqlen_k), dtype=np.float32) + for i in range(seqlen_q): + for j in range(seqlen_k): + if j < num_sink or j <= i: + mask[i, j] = 1.0 + return mask + + +def make_sliding_window_sink_mask( + seqlen_q: int, + seqlen_k: int, + window: int, + num_sink: int, +) -> np.ndarray: + """Sliding window mask with sink tokens. + + For each query position i: + - Always attend to positions [0, num_sink) (sink tokens) + - Attend to positions in [i - window + 1, i] (sliding window) + """ + mask = np.zeros((seqlen_q, seqlen_k), dtype=np.float32) + for i in range(seqlen_q): + for j in range(seqlen_k): + if j < num_sink or (i - window + 1 <= j <= i): + mask[i, j] = 1.0 + return mask + + +def cpu_attention_fwd_masked( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + scale: float, + mask: np.ndarray, +) -> np.ndarray: + """CPU reference: attention with explicit mask. + + Args: + Q: [batch, nhead_q, seqlen_q, hdim_q] float32 + K: [batch, nhead_k, seqlen_k, hdim_q] float32 + V: [batch, nhead_k, seqlen_k, hdim_v] float32 + scale: softmax scale + mask: [seqlen_q, seqlen_k] binary mask (1=attend, 0=ignore) + + Returns: + O: [batch, nhead_q, seqlen_q, hdim_v] float32 + """ + nhead_q = Q.shape[1] + nhead_k = K.shape[1] + if nhead_q != nhead_k: + ratio = nhead_q // nhead_k + K = np.repeat(K, ratio, axis=1) + V = np.repeat(V, ratio, axis=1) + + S = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale + neg_inf = np.finfo(np.float32).min + S = np.where(mask[np.newaxis, np.newaxis, :, :] > 0, S, neg_inf) + + S_max = S.max(axis=-1, keepdims=True) + S_exp = np.exp(S - S_max) + P = S_exp / S_exp.sum(axis=-1, keepdims=True) + return np.matmul(P, V) + + +def print_mask(mask: np.ndarray, name: str, max_display: int = 16): + """Print a small portion of a mask for visualization.""" + rows, cols = mask.shape + rows_show = min(rows, max_display) + cols_show = min(cols, max_display) + print(f"\n {name} ({rows}x{cols}, showing {rows_show}x{cols_show}):") + for i in range(rows_show): + row_str = "".join("1" if mask[i, j] > 0 else "." for j in range(cols_show)) + print(f" q{i:02d}: {row_str}") + + +def main(): + parser = argparse.ArgumentParser( + description="Sink Token Attention FMHA Example", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--nhead", type=int, default=8) + parser.add_argument("--seqlen", type=int, default=128) + parser.add_argument("--hdim", type=int, default=128) + parser.add_argument( + "--sink-tokens", type=int, default=4, help="Number of sink tokens" + ) + parser.add_argument("--window", type=int, default=32, help="Sliding window size") + args = parser.parse_args() + + print("=" * 70) + print("Example 22: Sink Token Attention FMHA") + print("=" * 70) + + sq = sk = args.seqlen + prob = FmhaProblem( + batch=args.batch, + nhead_q=args.nhead, + nhead_k=args.nhead, + seqlen_q=sq, + seqlen_k=sk, + hdim_q=args.hdim, + hdim_v=args.hdim, + ) + + # Step 1: Visualize mask patterns + print("\nStep 1: Mask Patterns") + + causal = make_causal_mask(sq, sk) + causal_sink = make_causal_sink_mask(sq, sk, args.sink_tokens) + window_sink = make_sliding_window_sink_mask(sq, sk, args.window, args.sink_tokens) + + vis_size = min(16, sq) + print_mask(causal[:vis_size, :vis_size], "Causal (standard)", vis_size) + print_mask( + causal_sink[:vis_size, :vis_size], + f"Causal + {args.sink_tokens} sink tokens", + vis_size, + ) + print_mask( + window_sink[:vis_size, :vis_size], + f"Window({args.window}) + {args.sink_tokens} sink tokens", + vis_size, + ) + + # Step 2: CPU reference for each mask type + print("\n\nStep 2: CPU Reference Comparison") + + np.random.seed(42) + Q = (np.random.randn(*prob.q_shape()) * 0.3).astype(np.float32) + K = (np.random.randn(*prob.k_shape()) * 0.3).astype(np.float32) + V = (np.random.randn(*prob.v_shape()) * 0.3).astype(np.float32) + + O_no_mask = cpu_attention_fwd(Q, K, V, prob.scale) + O_causal = cpu_attention_fwd_masked(Q, K, V, prob.scale, causal) + O_causal_sink = cpu_attention_fwd_masked(Q, K, V, prob.scale, causal_sink) + O_window_sink = cpu_attention_fwd_masked(Q, K, V, prob.scale, window_sink) + + masks_and_outputs = [ + ("No mask", O_no_mask), + ("Causal", O_causal), + (f"Causal+sink({args.sink_tokens})", O_causal_sink), + (f"Window({args.window})+sink({args.sink_tokens})", O_window_sink), + ] + + print(f"\n {'Mask Type':<30} {'Output Range':>20} {'vs NoMask MaxDiff':>18}") + print(" " + "-" * 70) + for name, out in masks_and_outputs: + d = np.abs(out - O_no_mask).max() + out_range = f"[{out.min():.4f}, {out.max():.4f}]" + print(f" {name:<30} {out_range:>20} {d:>18.6e}") + + # Step 3: Verify sink tokens effect + print("\nStep 3: Sink Token Effect Analysis") + + diff_causal_vs_sink = np.abs(O_causal - O_causal_sink) + print(" Causal vs Causal+Sink:") + print(f" Max diff: {diff_causal_vs_sink.max():.6e}") + print(f" Mean diff: {diff_causal_vs_sink.mean():.6e}") + + n_attend_causal = causal.sum() + n_attend_sink = causal_sink.sum() + n_attend_window = window_sink.sum() + print("\n Attention density:") + print( + f" Causal: {n_attend_causal:>8.0f} / {sq * sk} ({100 * n_attend_causal / (sq * sk):.1f}%)" + ) + print( + f" Causal+sink: {n_attend_sink:>8.0f} / {sq * sk} ({100 * n_attend_sink / (sq * sk):.1f}%)" + ) + print( + f" Window+sink: {n_attend_window:>8.0f} / {sq * sk} ({100 * n_attend_window / (sq * sk):.1f}%)" + ) + + # Step 4: Sweep sink token count + print("\nStep 4: Sink Token Sweep") + + sink_counts = [0, 1, 2, 4, 8, 16] + validator = FmhaValidator(rtol=1e-4, atol=1e-4) + + print( + f"\n {'Sinks':>6} {'Density':>10} {'vs Causal MaxDiff':>20} {'vs NoMask MaxDiff':>20}" + ) + print(" " + "-" * 60) + + for ns in sink_counts: + if ns > sk: + continue + m = make_causal_sink_mask(sq, sk, ns) + O_s = cpu_attention_fwd_masked(Q, K, V, prob.scale, m) + d_causal = np.abs(O_s - O_causal).max() + d_nomask = np.abs(O_s - O_no_mask).max() + density = 100 * m.sum() / (sq * sk) + print(f" {ns:>6} {density:>9.1f}% {d_causal:>20.6e} {d_nomask:>20.6e}") + + # Step 5: GPU API pattern + print("\nStep 5: GPU Kernel Pattern") + print(" NOTE: The prebuilt library does not include a sink token kernel.") + print(" To compile a sink-enabled kernel, use:") + print() + print(" FmhaSignature()") + print(" .mask('top_left') // causal mask required with sink") + print(" .sink(true) // enable sink tokens") + print() + print(" At runtime, pass sink count via the mask spec: 't:left,right,sink'") + print( + f" Example: 't:0,0,{args.sink_tokens}' for causal + {args.sink_tokens} sink tokens" + ) + + # Step 6: GPU baseline (no mask, no sink) + print("\nStep 6: GPU Baseline (standard kernel, no mask)") + + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=args.hdim, + hdim_v=args.hdim, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + if not setup.success: + print(f" JIT build failed: {setup.error}") + else: + runner = setup.runner + print(f" JIT build: {setup.build_time_s:.1f}s") + + Q_f16 = Q.astype(np.float16) + K_f16 = K.astype(np.float16) + V_f16 = V.astype(np.float16) + + result = runner.run(Q_f16, K_f16, V_f16, prob) + if result.success: + ok, max_abs, _ = validator.check(result.output, O_no_mask) + print( + f" GPU (no mask): time={result.time_ms:.4f}ms TFLOPS={result.tflops:.2f} " + f"max_err={max_abs:.2e} {'PASS' if ok else 'FAIL'}" + ) + else: + print(f" GPU error: {result.error}") + + # Summary + print("\n" + "=" * 70) + print(" Sink token attention: first N tokens always attended regardless of mask") + print(" Use case: StreamingLLM, long-context generation with attention anchors") + print(" Sink tokens preserve global context that causal masking would discard") + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/23_batch_prefill_fmha.py b/dispatcher/examples/fmha/python/23_batch_prefill_fmha.py new file mode 100644 index 0000000000..dc9b54a4c5 --- /dev/null +++ b/dispatcher/examples/fmha/python/23_batch_prefill_fmha.py @@ -0,0 +1,406 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 23: Batch Prefill FMHA for SGLang/vLLM + +Demonstrates batch prefill with paged KV-cache, as used in serving +frameworks like SGLang and vLLM. Shows the KV page table configuration +(kv_indptr, kv_page_indices, kv_last_page_lens) for both: + - SGLang: 1D page table with indirect page lookup + - vLLM: 2D block table with per-sequence page arrays + +This example builds the page table metadata on CPU and validates the +attention computation. The prebuilt library only supports the basic +forward kernel, so the page table logic is demonstrated via CPU reference. + +Usage: + python3 23_batch_prefill_fmha.py + python3 23_batch_prefill_fmha.py --page-size 64 + python3 23_batch_prefill_fmha.py --num-seqs 8 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaKernelConfig, + FmhaProblem, + FmhaValidator, + cpu_attention_fwd, + detect_gpu_arch, + setup_fmha_dispatcher, +) + + +def build_sglang_page_table( + seq_lens_k: list, + page_size: int, + nhead_k: int, + hdim: int, +) -> dict: + """Build SGLang-style 1D page table for paged KV-cache. + + SGLang uses a flat 1D array of page indices. Each sequence's pages are + stored contiguously in the page_indices array, with indptr marking + boundaries. + + Returns dict with: + kv_indptr: [num_seqs + 1] cumulative page counts + kv_page_indices: [total_pages] global page IDs + kv_last_page_lens: [num_seqs] tokens in last page of each seq + num_total_pages: total pages allocated + kv_data_shape: shape of the paged KV pool + """ + num_seqs = len(seq_lens_k) + kv_indptr = np.zeros(num_seqs + 1, dtype=np.int32) + page_indices_list = [] + last_page_lens = np.zeros(num_seqs, dtype=np.int32) + + page_counter = 0 + for i, seqlen in enumerate(seq_lens_k): + num_pages = (seqlen + page_size - 1) // page_size + kv_indptr[i + 1] = kv_indptr[i] + num_pages + page_indices_list.extend(range(page_counter, page_counter + num_pages)) + last_page_lens[i] = seqlen - (num_pages - 1) * page_size + page_counter += num_pages + + kv_page_indices = np.array(page_indices_list, dtype=np.int32) + total_pages = page_counter + + return { + "kv_indptr": kv_indptr, + "kv_page_indices": kv_page_indices, + "kv_last_page_lens": last_page_lens, + "num_total_pages": total_pages, + "kv_data_shape": (total_pages, 2, nhead_k, page_size, hdim), + "layout": "sglang_1d", + } + + +def build_vllm_block_table( + seq_lens_k: list, + page_size: int, + nhead_k: int, + hdim: int, +) -> dict: + """Build vLLM-style 2D block table for paged KV-cache. + + vLLM uses a 2D array [num_seqs, max_blocks_per_seq] where each entry + is a block (page) index into the global KV pool. + + Returns dict with: + block_table: [num_seqs, max_blocks] page IDs (-1 = unused) + kv_last_page_lens: [num_seqs] tokens in last page of each seq + num_total_pages: total pages allocated + kv_data_shape: shape of the paged KV pool + """ + num_seqs = len(seq_lens_k) + pages_per_seq = [(s + page_size - 1) // page_size for s in seq_lens_k] + max_blocks = max(pages_per_seq) + + block_table = np.full((num_seqs, max_blocks), -1, dtype=np.int32) + last_page_lens = np.zeros(num_seqs, dtype=np.int32) + + page_counter = 0 + for i, (seqlen, num_pages) in enumerate(zip(seq_lens_k, pages_per_seq)): + for p in range(num_pages): + block_table[i, p] = page_counter + page_counter += 1 + last_page_lens[i] = seqlen - (num_pages - 1) * page_size + + return { + "block_table": block_table, + "kv_last_page_lens": last_page_lens, + "num_total_pages": page_counter, + "kv_data_shape": (page_counter, 2, nhead_k, page_size, hdim), + "layout": "vllm_2d", + } + + +def scatter_kv_to_pages( + K: np.ndarray, + V: np.ndarray, + page_table: dict, + page_size: int, +) -> np.ndarray: + """Scatter contiguous K,V into paged KV pool using page table. + + Args: + K: [nhead_k, seqlen_k, hdim] float32 (single sequence) + V: [nhead_k, seqlen_k, hdim] float32 + page_table: page indices for this sequence + page_size: tokens per page + """ + nhead_k, seqlen_k, hdim = K.shape + num_pages = (seqlen_k + page_size - 1) // page_size + + pages = np.zeros((num_pages, 2, nhead_k, page_size, hdim), dtype=np.float32) + for p in range(num_pages): + start = p * page_size + end = min(start + page_size, seqlen_k) + length = end - start + pages[p, 0, :, :length, :] = K[:, start:end, :] + pages[p, 1, :, :length, :] = V[:, start:end, :] + + return pages + + +def gather_kv_from_pages( + kv_pool: np.ndarray, + page_indices: np.ndarray, + seqlen_k: int, + page_size: int, +) -> tuple: + """Gather K,V from paged KV pool back to contiguous arrays. + + Returns: + K: [nhead_k, seqlen_k, hdim] + V: [nhead_k, seqlen_k, hdim] + """ + nhead_k = kv_pool.shape[2] + hdim = kv_pool.shape[4] + K = np.zeros((nhead_k, seqlen_k, hdim), dtype=np.float32) + V = np.zeros((nhead_k, seqlen_k, hdim), dtype=np.float32) + + for p, page_idx in enumerate(page_indices): + start = p * page_size + end = min(start + page_size, seqlen_k) + length = end - start + K[:, start:end, :] = kv_pool[page_idx, 0, :, :length, :] + V[:, start:end, :] = kv_pool[page_idx, 1, :, :length, :] + + return K, V + + +def main(): + parser = argparse.ArgumentParser( + description="Batch Prefill FMHA for SGLang/vLLM", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--nhead-q", type=int, default=16) + parser.add_argument("--nhead-k", type=int, default=4, help="KV heads (GQA)") + parser.add_argument("--hdim", type=int, default=128) + parser.add_argument("--page-size", type=int, default=16) + parser.add_argument("--num-seqs", type=int, default=4, help="Sequences in batch") + args = parser.parse_args() + + print("=" * 70) + print("Example 23: Batch Prefill FMHA (SGLang/vLLM)") + print("=" * 70) + + seq_lens_q = [32, 64, 16, 48][: args.num_seqs] + seq_lens_k = [256, 512, 128, 384][: args.num_seqs] + + # Step 1: SGLang page table + print("\nStep 1: SGLang 1D Page Table") + + sglang_pt = build_sglang_page_table( + seq_lens_k, + args.page_size, + args.nhead_k, + args.hdim, + ) + + print(f" Page size: {args.page_size}") + print(f" Total pages: {sglang_pt['num_total_pages']}") + print(f" KV pool shape: {sglang_pt['kv_data_shape']}") + print(f" kv_indptr: {sglang_pt['kv_indptr']}") + print( + f" kv_page_indices: {sglang_pt['kv_page_indices'][:20]}{'...' if len(sglang_pt['kv_page_indices']) > 20 else ''}" + ) + print(f" last_page_lens: {sglang_pt['kv_last_page_lens']}") + + print("\n Per-sequence breakdown:") + print(f" {'Seq':>5} {'SeqQ':>6} {'SeqK':>6} {'Pages':>6} {'LastLen':>8}") + print(" " + "-" * 35) + for i in range(args.num_seqs): + n_pages = sglang_pt["kv_indptr"][i + 1] - sglang_pt["kv_indptr"][i] + print( + f" {i:>5} {seq_lens_q[i]:>6} {seq_lens_k[i]:>6} {n_pages:>6} {sglang_pt['kv_last_page_lens'][i]:>8}" + ) + + # Step 2: vLLM block table + print("\nStep 2: vLLM 2D Block Table") + + vllm_pt = build_vllm_block_table( + seq_lens_k, + args.page_size, + args.nhead_k, + args.hdim, + ) + + print(f" Block table shape: {vllm_pt['block_table'].shape}") + print(f" Total pages: {vllm_pt['num_total_pages']}") + for i in range(args.num_seqs): + row = vllm_pt["block_table"][i] + valid = row[row >= 0] + print(f" Seq {i}: pages={valid.tolist()}") + + # Step 3: Validate scatter/gather round-trip + print("\nStep 3: KV Page Scatter/Gather Validation") + + np.random.seed(42) + validator = FmhaValidator(rtol=1e-5, atol=1e-5) + + total_pages = sglang_pt["num_total_pages"] + kv_pool = np.zeros( + (total_pages, 2, args.nhead_k, args.page_size, args.hdim), + dtype=np.float32, + ) + + all_Q, all_K, all_V, all_O_ref = [], [], [], [] + + for i in range(args.num_seqs): + sq, sk = seq_lens_q[i], seq_lens_k[i] + Q_i = np.random.randn(args.nhead_q, sq, args.hdim).astype(np.float32) * 0.3 + K_i = np.random.randn(args.nhead_k, sk, args.hdim).astype(np.float32) * 0.3 + V_i = np.random.randn(args.nhead_k, sk, args.hdim).astype(np.float32) * 0.3 + + start_page = sglang_pt["kv_indptr"][i] + end_page = sglang_pt["kv_indptr"][i + 1] + page_indices = sglang_pt["kv_page_indices"][start_page:end_page] + + pages = scatter_kv_to_pages(K_i, V_i, page_indices, args.page_size) + for p_local, p_global in enumerate(page_indices): + kv_pool[p_global] = pages[p_local] + + K_rt, V_rt = gather_kv_from_pages(kv_pool, page_indices, sk, args.page_size) + + k_ok = np.allclose(K_i, K_rt, atol=1e-7) + v_ok = np.allclose(V_i, V_rt, atol=1e-7) + print( + f" Seq {i}: K round-trip={'OK' if k_ok else 'FAIL'} " + f"V round-trip={'OK' if v_ok else 'FAIL'}" + ) + + all_Q.append(Q_i) + all_K.append(K_i) + all_V.append(V_i) + + # Step 4: CPU attention per-sequence + print("\nStep 4: CPU Attention per Sequence (from Paged KV)") + + print(f"\n {'Seq':>5} {'SeqQ':>6} {'SeqK':>6} {'OutRange':>22} {'Scale':>10}") + print(" " + "-" * 50) + + for i in range(args.num_seqs): + sq, sk = seq_lens_q[i], seq_lens_k[i] + Q_i = all_Q[i][np.newaxis] # [1, nhead_q, sq, hdim] + K_i = all_K[i][np.newaxis] # [1, nhead_k, sk, hdim] + V_i = all_V[i][np.newaxis] # [1, nhead_k, sk, hdim] + + if args.nhead_q != args.nhead_k: + ratio = args.nhead_q // args.nhead_k + K_i_exp = np.repeat(K_i, ratio, axis=1) + V_i_exp = np.repeat(V_i, ratio, axis=1) + else: + K_i_exp, V_i_exp = K_i, V_i + + scale = 1.0 / (args.hdim**0.5) + O_i = cpu_attention_fwd(Q_i, K_i_exp, V_i_exp, scale) + all_O_ref.append(O_i) + + out_range = f"[{O_i.min():.4f}, {O_i.max():.4f}]" + print(f" {i:>5} {sq:>6} {sk:>6} {out_range:>22} {scale:>10.4f}") + + # Step 5: Memory layout comparison + print("\nStep 5: Memory Layout Analysis") + + contiguous_bytes = sum(2 * args.nhead_k * sk * args.hdim * 4 for sk in seq_lens_k) + paged_bytes = total_pages * 2 * args.nhead_k * args.page_size * args.hdim * 4 + overhead = (paged_bytes - contiguous_bytes) / contiguous_bytes * 100 + + print(f" Contiguous KV: {contiguous_bytes / 1024:.1f} KB") + print(f" Paged KV pool: {paged_bytes / 1024:.1f} KB") + print(f" Overhead: {overhead:.1f}% (due to page padding)") + print(f" Pages used: {total_pages}") + print(f" Avg tokens/seq: {sum(seq_lens_k) / args.num_seqs:.0f}") + + # Step 6: GPU API pattern + print("\nStep 6: GPU Kernel Configuration") + print(" NOTE: The prebuilt library uses basic forward kernels.") + print(" For batch prefill, compile a kernel with:") + print() + print(" FmhaSignature()") + print(" .family('batch_prefill')") + print(" .mode('group')") + print(" .paged_kv(true)") + print(" .kv_cache('vectorized', 'sglang', page_size)") + print(" .lse(true)") + print() + print(" FmhaKernelConfig codegen JSON:") + print(" 'family': 'batch_prefill',") + print(" 'mode': 'group',") + print(" 'paged_kv': true,") + print(" 'kv_memory_layout': 'vectorized',") + print(" 'kv_lookup_table': 'sglang' or 'vllm',") + print(f" 'page_size': {args.page_size}") + + # Step 7: GPU baseline (contiguous, no paging) + print("\nStep 7: GPU Baseline (contiguous KV, single sequence)") + + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=args.hdim, + hdim_v=args.hdim, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + if not setup.success: + print(f" JIT build failed: {setup.error}") + else: + runner = setup.runner + print(f" JIT build: {setup.build_time_s:.1f}s") + + prob = FmhaProblem( + batch=1, + nhead_q=args.nhead_q, + nhead_k=args.nhead_k, + seqlen_q=64, + seqlen_k=256, + hdim_q=args.hdim, + hdim_v=args.hdim, + ) + Q_gpu = (np.random.randn(*prob.q_shape()) * 0.3).astype(np.float16) + K_gpu = (np.random.randn(*prob.k_shape()) * 0.3).astype(np.float16) + V_gpu = (np.random.randn(*prob.v_shape()) * 0.3).astype(np.float16) + + result = runner.run(Q_gpu, K_gpu, V_gpu, prob) + if result.success: + O_ref = cpu_attention_fwd( + Q_gpu.astype(np.float32), + K_gpu.astype(np.float32), + V_gpu.astype(np.float32), + prob.scale, + ) + ok, max_abs, _ = validator.check(result.output, O_ref) + print( + f" GPU baseline: time={result.time_ms:.4f}ms TFLOPS={result.tflops:.2f} " + f"max_err={max_abs:.2e} {'PASS' if ok else 'FAIL'}" + ) + else: + print(f" GPU error: {result.error}") + + # Summary + print("\n" + "=" * 70) + print(" Batch prefill: serves multiple prefill requests in a single kernel launch") + print(" SGLang: 1D page table (kv_indptr + kv_page_indices)") + print(" vLLM: 2D block table [num_seqs, max_blocks]") + print( + f" Page size {args.page_size} -> {overhead:.1f}% memory overhead vs contiguous" + ) + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/24_vlayout_col_fmha.py b/dispatcher/examples/fmha/python/24_vlayout_col_fmha.py new file mode 100644 index 0000000000..28fc0814ad --- /dev/null +++ b/dispatcher/examples/fmha/python/24_vlayout_col_fmha.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 24: Column-Major V Layout FMHA + +Demonstrates column-major (vlayout="c") vs row-major (vlayout="r") for +the V tensor. In row-major, V is [batch, nhead, seqlen_k, hdim_v]; in +column-major, V is [batch, nhead, hdim_v, seqlen_k]. + +Column-major V can improve performance when hdim_v access patterns +benefit from the transposed layout (e.g., certain tile sizes or memory +coalescing characteristics on specific GPU architectures). + +The prebuilt library uses row-major V. This example shows both layouts +with CPU reference and validates correctness. + +Usage: + python3 24_vlayout_col_fmha.py + python3 24_vlayout_col_fmha.py --seqlen 512 + python3 24_vlayout_col_fmha.py --batch 4 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaKernelConfig, + FmhaProblem, + FmhaValidator, + cpu_attention_fwd, + detect_gpu_arch, + setup_fmha_dispatcher, +) + + +def cpu_attention_fwd_vlayout_col( + Q: np.ndarray, + K: np.ndarray, + V_col: np.ndarray, + scale: float, +) -> np.ndarray: + """CPU reference: attention with column-major V. + + Args: + Q: [batch, nhead_q, seqlen_q, hdim_q] float32 (row-major) + K: [batch, nhead_k, seqlen_k, hdim_q] float32 (row-major) + V_col: [batch, nhead_k, hdim_v, seqlen_k] float32 (column-major) + scale: softmax scale + + Returns: + O: [batch, nhead_q, seqlen_q, hdim_v] float32 + """ + V_row = V_col.transpose(0, 1, 3, 2) + return cpu_attention_fwd(Q, K, V_row, scale) + + +def analyze_strides(name: str, arr: np.ndarray, dim_names: list): + """Print stride information for a tensor.""" + strides_bytes = arr.strides + itemsize = arr.itemsize + strides_elems = tuple(s // itemsize for s in strides_bytes) + print(f" {name}:") + print(f" Shape: {arr.shape}") + print(f" Strides: {strides_elems} (elements)") + for i, (dname, s) in enumerate(zip(dim_names, strides_elems)): + contiguous = "(contiguous)" if i == len(dim_names) - 1 and s == 1 else "" + print(f" {dname}: stride={s} {contiguous}") + + +def main(): + parser = argparse.ArgumentParser( + description="Column-Major V Layout FMHA Example", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--nhead", type=int, default=8) + parser.add_argument("--seqlen", type=int, default=128) + parser.add_argument("--hdim", type=int, default=128) + args = parser.parse_args() + + print("=" * 70) + print("Example 24: Column-Major V Layout FMHA") + print("=" * 70) + + prob = FmhaProblem( + batch=args.batch, + nhead_q=args.nhead, + nhead_k=args.nhead, + seqlen_q=args.seqlen, + seqlen_k=args.seqlen, + hdim_q=args.hdim, + hdim_v=args.hdim, + ) + + # Step 1: Layout comparison + print("\nStep 1: V Tensor Layouts") + + np.random.seed(42) + V_row = np.ascontiguousarray( + (np.random.randn(*prob.v_shape()) * 0.3).astype(np.float32) + ) + V_col = np.ascontiguousarray(V_row.transpose(0, 1, 3, 2)) + + analyze_strides( + "V row-major [B, H, SeqK, Hdim]", + V_row, + ["batch", "nhead", "seqlen_k", "hdim_v"], + ) + analyze_strides( + "V col-major [B, H, Hdim, SeqK]", + V_col, + ["batch", "nhead", "hdim_v", "seqlen_k"], + ) + + print("\n Row-major: last dim is hdim_v -> sequential hdim access per token") + print(" Col-major: last dim is seqlen_k -> sequential token access per hdim") + + # Step 2: CPU reference for both layouts + print("\nStep 2: CPU Reference (both layouts)") + + Q = (np.random.randn(*prob.q_shape()) * 0.3).astype(np.float32) + K = (np.random.randn(*prob.k_shape()) * 0.3).astype(np.float32) + + O_from_row = cpu_attention_fwd(Q, K, V_row, prob.scale) + O_from_col = cpu_attention_fwd_vlayout_col(Q, K, V_col, prob.scale) + + validator = FmhaValidator(rtol=1e-5, atol=1e-5) + ok, max_abs, max_rel = validator.check(O_from_row, O_from_col) + + print( + f" O from row-major V: shape={O_from_row.shape} " + f"range=[{O_from_row.min():.4f}, {O_from_row.max():.4f}]" + ) + print( + f" O from col-major V: shape={O_from_col.shape} " + f"range=[{O_from_col.min():.4f}, {O_from_col.max():.4f}]" + ) + print(f" Max abs error: {max_abs:.2e}") + print(f" Match: {'PASS' if ok else 'FAIL'}") + + # Step 3: Memory access pattern analysis + print("\nStep 3: Memory Access Pattern Analysis") + + tile_sizes = [(128, 128), (64, 128), (128, 64)] + print("\n For P @ V matmul (P: [sq, sk] x V: [sk, hdim_v]):") + print(f" {'Tile(M,N)':>12} {'V Row Accesses':>18} {'V Col Accesses':>18}") + print(" " + "-" * 52) + + for tm, tn in tile_sizes: + row_access = f"sk_stride={args.hdim}" + col_access = "sk_stride=1" + print(f" {f'{tm}x{tn}':>12} {row_access:>18} {col_access:>18}") + + print("\n Row-major V: coalesced reads when accessing hdim_v (inner loop)") + print(" Col-major V: coalesced reads when accessing seqlen_k (inner loop)") + print(" Optimal layout depends on tile shape and GPU memory subsystem") + + # Step 4: Shape sweep with both layouts + print("\nStep 4: Correctness Sweep") + + shapes = [ + (1, 4, 64, 64, 64), + (2, 8, 128, 128, 128), + (1, 8, 256, 256, 128), + (2, 4, 128, 128, 64), + (1, 16, 64, 64, 128), + ] + + print(f"\n {'Shape':<32} {'MaxErr':>12} {'Status':>8}") + print(" " + "-" * 55) + + all_ok = True + for b, h, sq, sk, d in shapes: + Q_t = (np.random.randn(b, h, sq, d) * 0.3).astype(np.float32) + K_t = (np.random.randn(b, h, sk, d) * 0.3).astype(np.float32) + V_r = (np.random.randn(b, h, sk, d) * 0.3).astype(np.float32) + V_c = np.ascontiguousarray(V_r.transpose(0, 1, 3, 2)) + + scale = 1.0 / (d**0.5) + O_r = cpu_attention_fwd(Q_t, K_t, V_r, scale) + O_c = cpu_attention_fwd_vlayout_col(Q_t, K_t, V_c, scale) + + ok_t, max_abs_t, _ = validator.check(O_r, O_c) + all_ok = all_ok and ok_t + shape_str = f"B{b}_H{h}_S{sq}x{sk}_D{d}" + print(f" {shape_str:<32} {max_abs_t:>12.2e} {'PASS' if ok_t else 'FAIL':>8}") + + # Step 5: GPU API pattern + print("\nStep 5: GPU Kernel Configuration") + print(" NOTE: The prebuilt library uses row-major V (vlayout='r').") + print(" For column-major V, compile a kernel with vlayout='c':") + print() + print(" FmhaSignature()") + print(" .vlayout('c') // column-major V: [B, H, Hdim, SeqK]") + print() + print(" FmhaKernelConfig(vlayout='c', ...)") + + # Step 6: GPU baseline (row-major) + print("\nStep 6: GPU Baseline (row-major V)") + + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=args.hdim, + hdim_v=args.hdim, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + if not setup.success: + print(f" JIT build failed: {setup.error}") + else: + runner = setup.runner + print(f" JIT build: {setup.build_time_s:.1f}s") + + Q_f16 = Q.astype(np.float16) + K_f16 = K.astype(np.float16) + V_f16 = V_row.astype(np.float16) + + result = runner.run(Q_f16, K_f16, V_f16, prob) + if result.success: + ok_gpu, max_abs_gpu, _ = validator.check(result.output, O_from_row) + print( + f" GPU (row-major V): time={result.time_ms:.4f}ms TFLOPS={result.tflops:.2f} " + f"max_err={max_abs_gpu:.2e} {'PASS' if ok_gpu else 'FAIL'}" + ) + else: + print(f" GPU error: {result.error}") + + # Summary + print("\n" + "=" * 70) + print(" vlayout='r': V is [B, H, SeqK, Hdim] (default, row-major)") + print(" vlayout='c': V is [B, H, Hdim, SeqK] (column-major)") + print( + f" Both layouts produce identical results (verified: {'PASS' if all_ok else 'FAIL'})" + ) + print(" Choice depends on upstream memory layout and GPU tile access patterns") + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/25_permutation_fmha.py b/dispatcher/examples/fmha/python/25_permutation_fmha.py new file mode 100644 index 0000000000..900cc802c1 --- /dev/null +++ b/dispatcher/examples/fmha/python/25_permutation_fmha.py @@ -0,0 +1,262 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 25: Input/Output Permutation FMHA + +Demonstrates different memory layouts for Q/K/V/O tensors via +input permutation (iperm) and output permutation (operm): + + iperm=0 (bshd): [batch, seqlen, nhead, hdim] -- used by some frameworks + iperm=1 (bhsd): [batch, nhead, seqlen, hdim] -- standard/default + + operm=0 (bshd): O is [batch, seqlen, nhead, hdim] + operm=1 (bhsd): O is [batch, nhead, seqlen, hdim] + +The prebuilt library uses bhsd layout (iperm=1, operm=1). This example +shows how to convert between layouts and validates correctness. + +Usage: + python3 25_permutation_fmha.py + python3 25_permutation_fmha.py --seqlen 256 + python3 25_permutation_fmha.py --batch 4 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaKernelConfig, + FmhaProblem, + FmhaValidator, + cpu_attention_fwd, + detect_gpu_arch, + setup_fmha_dispatcher, +) + + +def bhsd_to_bshd(x: np.ndarray) -> np.ndarray: + """Convert [batch, nhead, seqlen, hdim] -> [batch, seqlen, nhead, hdim].""" + return x.transpose(0, 2, 1, 3) + + +def bshd_to_bhsd(x: np.ndarray) -> np.ndarray: + """Convert [batch, seqlen, nhead, hdim] -> [batch, nhead, seqlen, hdim].""" + return x.transpose(0, 2, 1, 3) + + +def cpu_attention_fwd_bshd( + Q_bshd: np.ndarray, + K_bshd: np.ndarray, + V_bshd: np.ndarray, + scale: float, + operm: int = 0, +) -> np.ndarray: + """CPU reference with bshd input, configurable output layout. + + Args: + Q_bshd: [batch, seqlen_q, nhead_q, hdim_q] float32 + K_bshd: [batch, seqlen_k, nhead_k, hdim_q] float32 + V_bshd: [batch, seqlen_k, nhead_k, hdim_v] float32 + scale: softmax scale + operm: 0 -> output bshd, 1 -> output bhsd + + Returns: + O: float32 in requested layout + """ + Q_bhsd = bshd_to_bhsd(Q_bshd) + K_bhsd = bshd_to_bhsd(K_bshd) + V_bhsd = bshd_to_bhsd(V_bshd) + + O_bhsd = cpu_attention_fwd(Q_bhsd, K_bhsd, V_bhsd, scale) + + if operm == 0: + return bhsd_to_bshd(O_bhsd) + return O_bhsd + + +def describe_layout(arr: np.ndarray, layout_name: str, dim_names: list): + """Print layout details including strides.""" + itemsize = arr.itemsize + strides_elems = tuple(s // itemsize for s in arr.strides) + is_contiguous = arr.flags["C_CONTIGUOUS"] + print(f" {layout_name}:") + print(f" Shape: {arr.shape}") + print(f" Strides: {strides_elems} (elements)") + print(f" Contiguous: {is_contiguous}") + for dname, s in zip(dim_names, strides_elems): + print(f" {dname:>8}: stride={s}") + + +def main(): + parser = argparse.ArgumentParser( + description="Input/Output Permutation FMHA Example", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--nhead", type=int, default=8) + parser.add_argument("--seqlen", type=int, default=128) + parser.add_argument("--hdim", type=int, default=128) + args = parser.parse_args() + + print("=" * 70) + print("Example 25: Input/Output Permutation FMHA") + print("=" * 70) + + B, H, S, D = args.batch, args.nhead, args.seqlen, args.hdim + prob = FmhaProblem( + batch=B, + nhead_q=H, + nhead_k=H, + seqlen_q=S, + seqlen_k=S, + hdim_q=D, + hdim_v=D, + ) + + # Step 1: Layout definitions + print("\nStep 1: Layout Definitions") + + np.random.seed(42) + Q_bhsd = np.ascontiguousarray( + (np.random.randn(B, H, S, D) * 0.3).astype(np.float32) + ) + Q_bshd = np.ascontiguousarray(bhsd_to_bshd(Q_bhsd)) + + describe_layout(Q_bhsd, "bhsd (iperm=1)", ["batch", "nhead", "seqlen", "hdim"]) + describe_layout(Q_bshd, "bshd (iperm=0)", ["batch", "seqlen", "nhead", "hdim"]) + + print("\n Key difference:") + print(" bhsd: heads are contiguous -> good for per-head parallelism") + print(" bshd: tokens are contiguous -> good for sequence parallelism") + + # Step 2: All permutation combinations + print("\nStep 2: All Permutation Combinations (CPU Reference)") + + K_bhsd = (np.random.randn(B, H, S, D) * 0.3).astype(np.float32) + V_bhsd = (np.random.randn(B, H, S, D) * 0.3).astype(np.float32) + K_bshd = np.ascontiguousarray(bhsd_to_bshd(K_bhsd)) + V_bshd = np.ascontiguousarray(bhsd_to_bshd(V_bhsd)) + + O_ref_bhsd = cpu_attention_fwd(Q_bhsd, K_bhsd, V_bhsd, prob.scale) + O_ref_bshd = bhsd_to_bshd(O_ref_bhsd) + + validator = FmhaValidator(rtol=1e-5, atol=1e-5) + + combos = [ + ("iperm=1 operm=1", "bhsd->bhsd", Q_bhsd, K_bhsd, V_bhsd, 1, O_ref_bhsd), + ("iperm=1 operm=0", "bhsd->bshd", Q_bhsd, K_bhsd, V_bhsd, 0, O_ref_bshd), + ("iperm=0 operm=1", "bshd->bhsd", Q_bshd, K_bshd, V_bshd, 1, O_ref_bhsd), + ("iperm=0 operm=0", "bshd->bshd", Q_bshd, K_bshd, V_bshd, 0, O_ref_bshd), + ] + + print( + f"\n {'Config':<18} {'Transform':<14} {'OutShape':>24} {'MaxErr':>12} {'Status':>8}" + ) + print(" " + "-" * 80) + + all_ok = True + for name, transform, Q_in, K_in, V_in, operm, O_expected in combos: + if Q_in.shape[1] == H: + O_out = cpu_attention_fwd(Q_in, K_in, V_in, prob.scale) + if operm == 0: + O_out = bhsd_to_bshd(O_out) + else: + O_out = cpu_attention_fwd_bshd(Q_in, K_in, V_in, prob.scale, operm) + + ok, max_abs, _ = validator.check(O_out, O_expected) + all_ok = all_ok and ok + print( + f" {name:<18} {transform:<14} {str(O_out.shape):>24} {max_abs:>12.2e} {'PASS' if ok else 'FAIL':>8}" + ) + + # Step 3: Stride comparison table + print("\nStep 3: Stride Comparison") + + print(f"\n For B={B}, H={H}, S={S}, D={D}:") + print(f" {'Layout':>8} {'Dim Order':>16} {'Strides':>28} {'hdim contiguous':>18}") + print(" " + "-" * 74) + + bhsd_strides = (H * S * D, S * D, D, 1) + bshd_strides = (S * H * D, H * D, D, 1) + + print(f" {'bhsd':>8} {'B,H,S,D':>16} {str(bhsd_strides):>28} {'Yes':>18}") + print(f" {'bshd':>8} {'B,S,H,D':>16} {str(bshd_strides):>28} {'Yes':>18}") + + print("\n Stride analysis:") + print(f" bhsd: advancing 1 token = skip {D} elements (hdim)") + print(f" bshd: advancing 1 token = skip {H * D} elements (nhead * hdim)") + print(f" bhsd: advancing 1 head = skip {S * D} elements (seqlen * hdim)") + print(f" bshd: advancing 1 head = skip {D} elements (hdim)") + + # Step 4: Conversion cost + print("\nStep 4: Layout Conversion Cost") + + tensor_bytes = B * H * S * D * 4 + print(f" Tensor size: {tensor_bytes / 1024:.1f} KB (float32)") + print(" bhsd <-> bshd conversion: transpose(0,2,1,3) + contiguous copy") + print( + " If upstream provides bshd and kernel wants bhsd, conversion costs ~2x memory bandwidth" + ) + print(" Using iperm parameter avoids this copy by adjusting kernel strides") + + # Step 5: GPU run (bhsd, default layout) + print("\nStep 5: GPU Run (bhsd layout, iperm=1)") + + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=args.hdim, + hdim_v=args.hdim, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + if not setup.success: + print(f" JIT build failed: {setup.error}") + else: + runner = setup.runner + print(f" JIT build: {setup.build_time_s:.1f}s") + + Q_f16 = Q_bhsd.astype(np.float16) + K_f16 = K_bhsd.astype(np.float16) + V_f16 = V_bhsd.astype(np.float16) + + result = runner.run(Q_f16, K_f16, V_f16, prob) + if result.success: + ok_gpu, max_abs_gpu, _ = validator.check(result.output, O_ref_bhsd) + print( + f" GPU (bhsd): time={result.time_ms:.4f}ms TFLOPS={result.tflops:.2f} " + f"max_err={max_abs_gpu:.2e} {'PASS' if ok_gpu else 'FAIL'}" + ) + else: + print(f" GPU error: {result.error}") + + # Step 6: Kernel configuration for bshd + print("\nStep 6: GPU Kernel Configuration for bshd") + print(" The prebuilt library uses bhsd (iperm=1, operm=1).") + print(" For bshd input/output, the kernel adjusts internal strides:") + print() + print(" iperm=0: kernel reads Q,K,V as [B, S, H, D] with stride_head=D") + print(" iperm=1: kernel reads Q,K,V as [B, H, S, D] with stride_seq=D") + print(" operm=0: kernel writes O as [B, S, H, D]") + print(" operm=1: kernel writes O as [B, H, S, D]") + + # Summary + print("\n" + "=" * 70) + print(" iperm=0 (bshd): [B, S, H, D] -- sequence-first layout") + print(" iperm=1 (bhsd): [B, H, S, D] -- head-first layout (default)") + print(f" All 4 combinations validated: {'PASS' if all_ok else 'FAIL'}") + print(" Use iperm/operm to match upstream/downstream layout without copies") + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/26_hdim_variety_fmha.py b/dispatcher/examples/fmha/python/26_hdim_variety_fmha.py new file mode 100644 index 0000000000..e24e0d0bdb --- /dev/null +++ b/dispatcher/examples/fmha/python/26_hdim_variety_fmha.py @@ -0,0 +1,268 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 26: Head Dimension Variety FMHA + +Demonstrates FMHA with multiple head dimensions (32, 64, 128, 256) and +asymmetric hdim (hdim_q != hdim_v). Different head dimensions require +different tile sizes and kernel configurations for optimal performance. + +The prebuilt library supports hdim=128 only. This example validates all +head dimensions via CPU reference and runs GPU for hdim=128. + +Usage: + python3 26_hdim_variety_fmha.py + python3 26_hdim_variety_fmha.py --seqlen 256 + python3 26_hdim_variety_fmha.py --batch 4 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaKernelConfig, + FmhaProblem, + FmhaValidator, + cpu_attention_fwd, + detect_gpu_arch, + setup_fmha_dispatcher, +) + + +def recommended_tile(hdim: int) -> str: + """Suggest tile configuration for a given head dimension.""" + tiles = { + 32: "128x128x32x32x32x32", + 64: "128x64x32x64x32x64", + 128: "128x128x32x128x32x128", + 256: "128x128x32x256x32x256", + } + return tiles.get(hdim, f"auto (hdim={hdim})") + + +def compute_flops( + batch: int, nhead_q: int, sq: int, sk: int, hdim_q: int, hdim_v: int +) -> int: + """Compute FMHA FLOPs accounting for asymmetric hdim.""" + return 2 * batch * nhead_q * sq * sk * (hdim_q + hdim_v) + + +def main(): + parser = argparse.ArgumentParser( + description="Head Dimension Variety FMHA Example", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--nhead", type=int, default=8) + parser.add_argument("--seqlen", type=int, default=128) + args = parser.parse_args() + + print("=" * 70) + print("Example 26: Head Dimension Variety FMHA") + print("=" * 70) + + validator = FmhaValidator(rtol=1e-2, atol=1e-2) + + # Step 1: Symmetric head dimensions + print("\nStep 1: Symmetric Head Dimensions (hdim_q == hdim_v)") + + hdims = [32, 64, 128, 256] + + print(f"\n {'hdim':>6} {'Shape':>30} {'Tile Config':>30} {'FLOPs':>14}") + print(" " + "-" * 84) + + for hdim in hdims: + shape = f"B{args.batch}_H{args.nhead}_S{args.seqlen}_D{hdim}" + tile = recommended_tile(hdim) + flops = compute_flops( + args.batch, args.nhead, args.seqlen, args.seqlen, hdim, hdim + ) + print(f" {hdim:>6} {shape:>30} {tile:>30} {flops:>14,}") + + # Step 2: CPU validation for each hdim + print("\nStep 2: CPU Validation") + + np.random.seed(42) + + print( + f"\n {'hdim_q':>7} {'hdim_v':>7} {'Scale':>10} {'OutRange':>22} {'SelfCheck':>10}" + ) + print(" " + "-" * 60) + + cpu_results = {} + for hdim in hdims: + prob = FmhaProblem( + batch=args.batch, + nhead_q=args.nhead, + nhead_k=args.nhead, + seqlen_q=args.seqlen, + seqlen_k=args.seqlen, + hdim_q=hdim, + hdim_v=hdim, + ) + Q = (np.random.randn(*prob.q_shape()) * 0.3).astype(np.float32) + K = (np.random.randn(*prob.k_shape()) * 0.3).astype(np.float32) + V = (np.random.randn(*prob.v_shape()) * 0.3).astype(np.float32) + + O_ref = cpu_attention_fwd(Q, K, V, prob.scale) + + self_ok = np.all(np.isfinite(O_ref)) + out_range = f"[{O_ref.min():.4f}, {O_ref.max():.4f}]" + print( + f" {hdim:>7} {hdim:>7} {prob.scale:>10.4f} {out_range:>22} {'OK' if self_ok else 'NaN!':>10}" + ) + + cpu_results[hdim] = (Q, K, V, O_ref, prob) + + # Step 3: Asymmetric head dimensions + print("\nStep 3: Asymmetric Head Dimensions (hdim_q != hdim_v)") + + asymmetric_configs = [ + (128, 64, "Large Q, small V: more attention capacity, compact output"), + (64, 128, "Small Q, large V: compact attention, rich output"), + (128, 256, "Standard Q, very large V: high-capacity value projection"), + (256, 128, "Large Q, standard V: wide attention field"), + (32, 128, "Tiny Q, standard V: minimal attention compute"), + ] + + print( + f"\n {'hdim_q':>7} {'hdim_v':>7} {'Q Shape':>22} {'O Shape':>22} {'MaxErr vs self':>16}" + ) + print(" " + "-" * 78) + + for hdim_q, hdim_v, desc in asymmetric_configs: + prob = FmhaProblem( + batch=args.batch, + nhead_q=args.nhead, + nhead_k=args.nhead, + seqlen_q=args.seqlen, + seqlen_k=args.seqlen, + hdim_q=hdim_q, + hdim_v=hdim_v, + ) + Q = (np.random.randn(*prob.q_shape()) * 0.3).astype(np.float32) + K = (np.random.randn(*prob.k_shape()) * 0.3).astype(np.float32) + V = (np.random.randn(*prob.v_shape()) * 0.3).astype(np.float32) + + out = cpu_attention_fwd(Q, K, V, prob.scale) + + O2 = cpu_attention_fwd(Q, K, V, prob.scale) + max_err = float(np.abs(out - O2).max()) + + print( + f" {hdim_q:>7} {hdim_v:>7} {str(prob.q_shape()):>22} {str(prob.o_shape()):>22} {max_err:>16.2e}" + ) + + print("\n Asymmetric hdim notes:") + for hdim_q, hdim_v, desc in asymmetric_configs: + print(f" hdim_q={hdim_q}, hdim_v={hdim_v}: {desc}") + + # Step 4: GPU validation (hdim=128) + print("\nStep 4: GPU Validation (hdim=128)") + + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=128, + hdim_v=128, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + gpu_tflops = 0.0 + gpu_time = 0.0 + if not setup.success: + print(f" JIT build failed: {setup.error}") + else: + runner = setup.runner + print(f" JIT build: {setup.build_time_s:.1f}s") + + Q, K, V, O_ref, prob = cpu_results[128] + Q_f16 = Q.astype(np.float16) + K_f16 = K.astype(np.float16) + V_f16 = V.astype(np.float16) + + result = runner.run(Q_f16, K_f16, V_f16, prob) + if result.success: + ok, max_abs, _ = validator.check(result.output, O_ref) + print( + f" GPU hdim=128: time={result.time_ms:.4f}ms TFLOPS={result.tflops:.2f} " + f"max_err={max_abs:.2e} {'PASS' if ok else 'FAIL'}" + ) + + gpu_tflops = result.tflops + gpu_time = result.time_ms + else: + print(f" GPU error: {result.error}") + + # Step 5: Performance projection table + print("\nStep 5: Performance Summary Table") + + print( + f"\n {'hdim_q':>7} | {'hdim_v':>7} | {'FLOPs':>14} | {'Tile':>24} | {'GPU Support':>12}" + ) + print(" " + "-" * 78) + + for hdim in hdims: + flops = compute_flops( + args.batch, args.nhead, args.seqlen, args.seqlen, hdim, hdim + ) + tile = recommended_tile(hdim) + gpu_ok = "prebuilt" if hdim == 128 else "needs JIT" + print(f" {hdim:>7} | {hdim:>7} | {flops:>14,} | {tile:>24} | {gpu_ok:>12}") + + print(" " + "-" * 78) + + for hdim_q, hdim_v, _ in asymmetric_configs[:3]: + flops = compute_flops( + args.batch, args.nhead, args.seqlen, args.seqlen, hdim_q, hdim_v + ) + gpu_ok = "needs JIT" + print( + f" {hdim_q:>7} | {hdim_v:>7} | {flops:>14,} | {'asymmetric':>24} | {gpu_ok:>12}" + ) + + # Step 6: Kernel configuration per hdim + print("\nStep 6: Kernel Configuration Per Head Dimension") + print(" Each hdim requires a dedicated compiled kernel:") + print() + print( + " hdim=32: FmhaKernelConfig(hdim_q=32, hdim_v=32, " + "tile_m0=128, tile_n0=128, tile_k0=32, tile_n1=32, tile_k1=32, tile_k0max=32)" + ) + print( + " hdim=64: FmhaKernelConfig(hdim_q=64, hdim_v=64, " + "tile_m0=128, tile_n0=64, tile_k0=32, tile_n1=64, tile_k1=32, tile_k0max=64)" + ) + print( + " hdim=128: FmhaKernelConfig(hdim_q=128, hdim_v=128, " + "tile_m0=128, tile_n0=128, tile_k0=32, tile_n1=128, tile_k1=32, tile_k0max=128)" + ) + print( + " hdim=256: FmhaKernelConfig(hdim_q=256, hdim_v=256, " + "tile_m0=128, tile_n0=128, tile_k0=32, tile_n1=256, tile_k1=32, tile_k0max=256)" + ) + print() + print(" Asymmetric: FmhaKernelConfig(hdim_q=128, hdim_v=64, ...)") + print(" tile_n1 tracks hdim_v; tile_k0max tracks hdim_q") + + # Summary + print("\n" + "=" * 70) + print(f" Supported symmetric hdims: {hdims}") + print(" Asymmetric hdim (hdim_q != hdim_v): fully supported") + print(" Tile sizes scale with hdim; larger hdim needs wider tiles") + if gpu_tflops > 0: + print(f" GPU baseline (hdim=128): {gpu_tflops:.2f} TFLOPS @ {gpu_time:.4f} ms") + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/27_backward_dropout_fmha.py b/dispatcher/examples/fmha/python/27_backward_dropout_fmha.py new file mode 100644 index 0000000000..cc18b34c4b --- /dev/null +++ b/dispatcher/examples/fmha/python/27_backward_dropout_fmha.py @@ -0,0 +1,373 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 27: Backward Pass with Dropout FMHA + +Demonstrates the FMHA backward pass with dropout. The backward pass +computes dQ, dK, dV given dO (gradient of the output). When dropout is +applied during forward, the same dropout mask must be replayed during +backward for correctness. + +Key concepts: + - Deterministic mode (no atomics): reproducible gradients, may be slower + - Non-deterministic mode: uses atomicAdd for dQ, faster but non-reproducible + - store_randval: optionally store the dropout random values for debugging + +The prebuilt library only has a forward kernel. This example validates +the backward CPU reference and shows the API pattern. + +Usage: + python3 27_backward_dropout_fmha.py + python3 27_backward_dropout_fmha.py --dropout 0.2 + python3 27_backward_dropout_fmha.py --seqlen 128 --deterministic +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaProblem, + FmhaValidator, + cpu_attention_fwd, + detect_gpu_arch, +) + + +def cpu_attention_fwd_dropout( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + scale: float, + dropout_p: float, + seed: int = 42, +) -> tuple: + """CPU reference: forward with dropout, returning intermediates for backward. + + Returns: + O: [B, H, Sq, Dv] output + P_drop: [B, H, Sq, Sk] attention weights after dropout + lse: [B, H, Sq] log-sum-exp for numerical stability + drop_mask: [B, H, Sq, Sk] binary dropout mask + """ + nhead_q = Q.shape[1] + nhead_k = K.shape[1] + if nhead_q != nhead_k: + ratio = nhead_q // nhead_k + K = np.repeat(K, ratio, axis=1) + V = np.repeat(V, ratio, axis=1) + + S = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale + + S_max = S.max(axis=-1, keepdims=True) + S_exp = np.exp(S - S_max) + S_sum = S_exp.sum(axis=-1, keepdims=True) + P = S_exp / S_sum + + lse = np.log(S_sum.squeeze(-1)) + S_max.squeeze(-1) + + rng = np.random.RandomState(seed) + drop_mask = (rng.rand(*P.shape) >= dropout_p).astype(np.float32) + drop_scale = 1.0 / (1.0 - dropout_p) if dropout_p < 1.0 else 0.0 + P_drop = P * drop_mask * drop_scale + + out = np.matmul(P_drop, V) + return out, P_drop, lse, drop_mask + + +def cpu_attention_bwd_dropout( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + out: np.ndarray, + dO: np.ndarray, + lse: np.ndarray, + scale: float, + dropout_p: float, + drop_mask: np.ndarray, + deterministic: bool = False, +) -> tuple: + """CPU reference: backward with dropout. + + Args: + Q: [B, H, Sq, Dq] float32 + K: [B, H, Sk, Dq] float32 (already GQA-expanded if needed) + V: [B, H, Sk, Dv] float32 + out: [B, H, Sq, Dv] float32 (forward output) + dO: [B, H, Sq, Dv] float32 (output gradient) + lse: [B, H, Sq] float32 (log-sum-exp from forward) + scale: softmax scale + dropout_p: dropout probability + drop_mask: [B, H, Sq, Sk] binary mask from forward + deterministic: if True, avoid any non-deterministic accumulation + + Returns: + dQ: [B, H, Sq, Dq] + dK: [B, H, Sk, Dq] + dV: [B, H, Sk, Dv] + """ + drop_scale = 1.0 / (1.0 - dropout_p) if dropout_p < 1.0 else 0.0 + + S = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale + S_max = S.max(axis=-1, keepdims=True) + P = np.exp(S - S_max) / np.exp(S - S_max).sum(axis=-1, keepdims=True) + + P_drop = P * drop_mask * drop_scale + + dV = np.matmul(P_drop.transpose(0, 1, 3, 2), dO) + + dP_drop = np.matmul(dO, V.transpose(0, 1, 3, 2)) + + dP = dP_drop * drop_mask * drop_scale + + D = (dO * out).sum(axis=-1, keepdims=True) + dS = P * (dP - D) * scale + + dQ = np.matmul(dS, K) + dK = np.matmul(dS.transpose(0, 1, 3, 2), Q) + + return dQ, dK, dV + + +def main(): + parser = argparse.ArgumentParser( + description="Backward Pass with Dropout FMHA Example", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--nhead", type=int, default=8) + parser.add_argument("--seqlen", type=int, default=64) + parser.add_argument("--hdim", type=int, default=128) + parser.add_argument( + "--dropout", type=float, default=0.1, help="Dropout probability" + ) + parser.add_argument( + "--deterministic", action="store_true", help="Use deterministic mode" + ) + args = parser.parse_args() + + print("=" * 70) + print("Example 27: Backward Pass with Dropout FMHA") + print("=" * 70) + + prob = FmhaProblem( + batch=args.batch, + nhead_q=args.nhead, + nhead_k=args.nhead, + seqlen_q=args.seqlen, + seqlen_k=args.seqlen, + hdim_q=args.hdim, + hdim_v=args.hdim, + ) + + # Step 1: Forward with dropout + print("\nStep 1: Forward Pass with Dropout") + + np.random.seed(42) + Q = (np.random.randn(*prob.q_shape()) * 0.3).astype(np.float32) + K = (np.random.randn(*prob.k_shape()) * 0.3).astype(np.float32) + V = (np.random.randn(*prob.v_shape()) * 0.3).astype(np.float32) + + O_nodrop = cpu_attention_fwd(Q, K, V, prob.scale) + O_drop, P_drop, lse, drop_mask = cpu_attention_fwd_dropout( + Q, + K, + V, + prob.scale, + args.dropout, + seed=42, + ) + + print(f" Shape: {prob.q_shape()}") + print(f" Dropout: p={args.dropout}") + print( + f" Drop mask: {drop_mask.sum():.0f}/{drop_mask.size} kept " + f"({100 * drop_mask.mean():.1f}%, expected {100 * (1 - args.dropout):.1f}%)" + ) + print(f" O (no drop): range=[{O_nodrop.min():.4f}, {O_nodrop.max():.4f}]") + print(f" O (dropout): range=[{O_drop.min():.4f}, {O_drop.max():.4f}]") + print(f" LSE shape: {lse.shape}") + + # Step 2: Backward pass + print("\nStep 2: Backward Pass") + + np.random.seed(123) + dO = (np.random.randn(*prob.o_shape()) * 0.1).astype(np.float32) + + dQ, dK, dV = cpu_attention_bwd_dropout( + Q, + K, + V, + O_drop, + dO, + lse, + prob.scale, + args.dropout, + drop_mask, + deterministic=args.deterministic, + ) + + print(f" dQ shape: {dQ.shape} range=[{dQ.min():.6f}, {dQ.max():.6f}]") + print(f" dK shape: {dK.shape} range=[{dK.min():.6f}, {dK.max():.6f}]") + print(f" dV shape: {dV.shape} range=[{dV.min():.6f}, {dV.max():.6f}]") + print(f" Deterministic: {args.deterministic}") + + # Step 3: Verify gradient correctness via finite differences + print("\nStep 3: Gradient Verification (Finite Differences)") + + eps = 1e-3 + num_checks = 5 + rng = np.random.RandomState(99) + + print(f"\n Checking {num_checks} random elements per tensor:") + print( + f" {'Tensor':>8} {'Index':>24} {'Analytic':>14} {'Numerical':>14} {'RelErr':>12}" + ) + print(" " + "-" * 76) + + for tensor_name, param, grad in [("dQ", Q, dQ), ("dK", K, dK), ("dV", V, dV)]: + for _ in range(num_checks): + idx = tuple(rng.randint(0, s) for s in param.shape) + + param_plus = param.copy() + param_plus[idx] += eps + param_minus = param.copy() + param_minus[idx] -= eps + + if tensor_name == "dQ": + O_p, _, _, _ = cpu_attention_fwd_dropout( + param_plus, K, V, prob.scale, args.dropout, seed=42 + ) + O_m, _, _, _ = cpu_attention_fwd_dropout( + param_minus, K, V, prob.scale, args.dropout, seed=42 + ) + elif tensor_name == "dK": + O_p, _, _, _ = cpu_attention_fwd_dropout( + Q, param_plus, V, prob.scale, args.dropout, seed=42 + ) + O_m, _, _, _ = cpu_attention_fwd_dropout( + Q, param_minus, V, prob.scale, args.dropout, seed=42 + ) + else: + O_p, _, _, _ = cpu_attention_fwd_dropout( + Q, K, param_plus, prob.scale, args.dropout, seed=42 + ) + O_m, _, _, _ = cpu_attention_fwd_dropout( + Q, K, param_minus, prob.scale, args.dropout, seed=42 + ) + + numerical = (O_p * dO).sum() - (O_m * dO).sum() + numerical /= 2 * eps + analytic = grad[idx] + + rel_err = abs(analytic - numerical) / (abs(numerical) + 1e-8) + idx_str = str(idx) + print( + f" {tensor_name:>8} {idx_str:>24} {analytic:>14.6f} {numerical:>14.6f} {rel_err:>12.2e}" + ) + + # Step 4: Deterministic vs non-deterministic comparison + print("\nStep 4: Deterministic vs Non-Deterministic") + + dQ_det, dK_det, dV_det = cpu_attention_bwd_dropout( + Q, + K, + V, + O_drop, + dO, + lse, + prob.scale, + args.dropout, + drop_mask, + deterministic=True, + ) + dQ_ndet, dK_ndet, dV_ndet = cpu_attention_bwd_dropout( + Q, + K, + V, + O_drop, + dO, + lse, + prob.scale, + args.dropout, + drop_mask, + deterministic=False, + ) + + validator = FmhaValidator(rtol=1e-5, atol=1e-5) + + for name, g_det, g_ndet in [ + ("dQ", dQ_det, dQ_ndet), + ("dK", dK_det, dK_ndet), + ("dV", dV_det, dV_ndet), + ]: + ok, max_abs, _ = validator.check(g_det, g_ndet) + print( + f" {name}: det vs non-det max_err={max_abs:.2e} {'MATCH' if ok else 'DIFFER'}" + ) + + print("\n NOTE: In CPU reference both modes are identical.") + print(" On GPU, non-deterministic mode uses atomicAdd for dQ accumulation,") + print(" which can cause tiny floating-point differences across runs.") + + # Step 5: Dropout probability sweep + print("\nStep 5: Dropout Probability Sweep") + + probs = [0.0, 0.1, 0.2, 0.3, 0.5] + print( + f"\n {'p':>6} {'|dQ| mean':>12} {'|dK| mean':>12} {'|dV| mean':>12} {'Kept%':>8}" + ) + print(" " + "-" * 54) + + for p in probs: + O_p, _, _, dm = cpu_attention_fwd_dropout(Q, K, V, prob.scale, p, seed=42) + dQ_p, dK_p, dV_p = cpu_attention_bwd_dropout( + Q, + K, + V, + O_p, + dO, + lse, + prob.scale, + p, + dm, + ) + kept = 100 * dm.mean() + print( + f" {p:>6.2f} {np.abs(dQ_p).mean():>12.6f} {np.abs(dK_p).mean():>12.6f} " + f"{np.abs(dV_p).mean():>12.6f} {kept:>7.1f}%" + ) + + # Step 6: GPU API pattern + print("\nStep 6: GPU Backward Kernel Configuration") + print(" NOTE: The prebuilt library only has a forward kernel.") + print(" FMHA backward requires 3 kernel stages:") + print() + print(" Stage 1: bwd_dot_do_o -- compute D = rowsum(dO * O)") + print(" Stage 2: bwd_dq_dk_dv -- compute dQ, dK, dV") + print(" Stage 3: bwd_convert_dq -- convert accumulated dQ") + print() + print(" With dropout, the signature requires:") + print(" .dropout(true)") + print(" .store_randval(false) // or true to save random values") + print(f" .deterministic({'true' if args.deterministic else 'false'})") + + # Summary + print("\n" + "=" * 70) + print(" Backward with dropout: replays same mask from forward pass") + print(" Deterministic mode: reproducible but potentially slower on GPU") + print(" 3-stage backward: dot_do_o -> dq_dk_dv -> convert_dq") + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/28_backward_dbias_fmha.py b/dispatcher/examples/fmha/python/28_backward_dbias_fmha.py new file mode 100644 index 0000000000..df614a7ede --- /dev/null +++ b/dispatcher/examples/fmha/python/28_backward_dbias_fmha.py @@ -0,0 +1,360 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 28: Backward Bias Gradient (dbias) FMHA + +Demonstrates computing the gradient of the elementwise attention bias +during the backward pass. When forward attention uses: + S = Q @ K^T * scale + bias +the backward pass must compute: + dbias = sum over batch of (dP) +where dP is the gradient of the attention probabilities. + +This is useful for learnable relative position biases (e.g., ALiBi +training, T5-style relative position embeddings). + +The prebuilt library only has a forward kernel. This example validates +the dbias CPU reference and shows the API pattern. + +Usage: + python3 28_backward_dbias_fmha.py + python3 28_backward_dbias_fmha.py --seqlen 128 + python3 28_backward_dbias_fmha.py --bias-type alibi +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaProblem, + cpu_attention_fwd, + detect_gpu_arch, +) + + +def make_elementwise_bias(nhead: int, seqlen_q: int, seqlen_k: int) -> np.ndarray: + """Create a simple elementwise attention bias [nhead, seqlen_q, seqlen_k].""" + bias = np.zeros((nhead, seqlen_q, seqlen_k), dtype=np.float32) + for h in range(nhead): + for i in range(seqlen_q): + for j in range(seqlen_k): + bias[h, i, j] = -0.1 * abs(i - j) * (h + 1) / nhead + return bias + + +def make_alibi_bias(nhead: int, seqlen_q: int, seqlen_k: int) -> np.ndarray: + """Create ALiBi-style attention bias [nhead, seqlen_q, seqlen_k]. + + ALiBi adds a linear penalty proportional to distance: + bias[h, i, j] = -slope_h * |i - j| + where slope_h decreases geometrically across heads. + """ + slopes = np.array([2 ** (-(8 * (h + 1) / nhead)) for h in range(nhead)]) + bias = np.zeros((nhead, seqlen_q, seqlen_k), dtype=np.float32) + for h in range(nhead): + for i in range(seqlen_q): + for j in range(seqlen_k): + bias[h, i, j] = -slopes[h] * abs(i - j) + return bias + + +def cpu_attention_fwd_bias( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + scale: float, + bias: np.ndarray, +) -> tuple: + """CPU forward with elementwise bias, returning intermediates. + + Args: + Q: [B, H, Sq, Dq] + K: [B, H, Sk, Dq] + V: [B, H, Sk, Dv] + bias: [H, Sq, Sk] broadcast over batch + + Returns: + O: [B, H, Sq, Dv] + P: [B, H, Sq, Sk] attention probabilities + lse: [B, H, Sq] log-sum-exp + """ + nhead_q = Q.shape[1] + nhead_k = K.shape[1] + if nhead_q != nhead_k: + ratio = nhead_q // nhead_k + K = np.repeat(K, ratio, axis=1) + V = np.repeat(V, ratio, axis=1) + + S = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale + S = S + bias[np.newaxis, :, :, :] + + S_max = S.max(axis=-1, keepdims=True) + S_exp = np.exp(S - S_max) + S_sum = S_exp.sum(axis=-1, keepdims=True) + P = S_exp / S_sum + + lse = np.log(S_sum.squeeze(-1)) + S_max.squeeze(-1) + out = np.matmul(P, V) + return out, P, lse + + +def cpu_attention_bwd_dbias( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + out: np.ndarray, + dO: np.ndarray, + P: np.ndarray, + scale: float, + bias: np.ndarray, +) -> tuple: + """CPU backward computing dQ, dK, dV, and dbias. + + Args: + Q, K, V: forward inputs [B, H, Sq/Sk, D] + out: forward output [B, H, Sq, Dv] + dO: output gradient [B, H, Sq, Dv] + P: attention probabilities [B, H, Sq, Sk] + scale: softmax scale + bias: [H, Sq, Sk] attention bias + + Returns: + dQ: [B, H, Sq, Dq] + dK: [B, H, Sk, Dq] + dV: [B, H, Sk, Dv] + dbias: [H, Sq, Sk] summed over batch dimension + """ + nhead_q = Q.shape[1] + nhead_k = K.shape[1] + if nhead_q != nhead_k: + ratio = nhead_q // nhead_k + K = np.repeat(K, ratio, axis=1) + V = np.repeat(V, ratio, axis=1) + + dV = np.matmul(P.transpose(0, 1, 3, 2), dO) + + dP = np.matmul(dO, V.transpose(0, 1, 3, 2)) + + D = (dO * out).sum(axis=-1, keepdims=True) + dS = P * (dP - D) * scale + + dQ = np.matmul(dS, K) + dK = np.matmul(dS.transpose(0, 1, 3, 2), Q) + + dbias = dS.sum(axis=0) / scale + + return dQ, dK, dV, dbias + + +def main(): + parser = argparse.ArgumentParser( + description="Backward Bias Gradient (dbias) FMHA Example", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--batch", type=int, default=4) + parser.add_argument("--nhead", type=int, default=8) + parser.add_argument("--seqlen", type=int, default=64) + parser.add_argument("--hdim", type=int, default=128) + parser.add_argument( + "--bias-type", choices=["elementwise", "alibi"], default="elementwise" + ) + args = parser.parse_args() + + print("=" * 70) + print("Example 28: Backward Bias Gradient (dbias) FMHA") + print("=" * 70) + + prob = FmhaProblem( + batch=args.batch, + nhead_q=args.nhead, + nhead_k=args.nhead, + seqlen_q=args.seqlen, + seqlen_k=args.seqlen, + hdim_q=args.hdim, + hdim_v=args.hdim, + ) + + # Step 1: Create bias + print(f"\nStep 1: Create {args.bias_type.title()} Bias") + + if args.bias_type == "alibi": + bias = make_alibi_bias(args.nhead, args.seqlen, args.seqlen) + else: + bias = make_elementwise_bias(args.nhead, args.seqlen, args.seqlen) + + print(f" Bias shape: {bias.shape}") + print(f" Bias range: [{bias.min():.4f}, {bias.max():.4f}]") + print(f" Bias type: {args.bias_type}") + + for h in range(min(4, args.nhead)): + print( + f" Head {h}: range=[{bias[h].min():.4f}, {bias[h].max():.4f}] " + f"mean={bias[h].mean():.4f}" + ) + + # Step 2: Forward pass with bias + print("\nStep 2: Forward Pass with Bias") + + np.random.seed(42) + Q = (np.random.randn(*prob.q_shape()) * 0.3).astype(np.float32) + K = (np.random.randn(*prob.k_shape()) * 0.3).astype(np.float32) + V = (np.random.randn(*prob.v_shape()) * 0.3).astype(np.float32) + + O_nobias = cpu_attention_fwd(Q, K, V, prob.scale) + O_bias, P, lse = cpu_attention_fwd_bias(Q, K, V, prob.scale, bias) + + diff = np.abs(O_nobias - O_bias) + print(f" O (no bias): range=[{O_nobias.min():.4f}, {O_nobias.max():.4f}]") + print(f" O (biased): range=[{O_bias.min():.4f}, {O_bias.max():.4f}]") + print(f" Bias effect: max_diff={diff.max():.6e} mean_diff={diff.mean():.6e}") + + # Step 3: Backward pass with dbias + print("\nStep 3: Backward Pass (dQ, dK, dV, dbias)") + + np.random.seed(123) + dO = (np.random.randn(*prob.o_shape()) * 0.1).astype(np.float32) + + dQ, dK, dV, dbias = cpu_attention_bwd_dbias( + Q, + K, + V, + O_bias, + dO, + P, + prob.scale, + bias, + ) + + print(f" dQ shape: {dQ.shape} range=[{dQ.min():.6f}, {dQ.max():.6f}]") + print(f" dK shape: {dK.shape} range=[{dK.min():.6f}, {dK.max():.6f}]") + print(f" dV shape: {dV.shape} range=[{dV.min():.6f}, {dV.max():.6f}]") + print(f" dbias shape: {dbias.shape} range=[{dbias.min():.6f}, {dbias.max():.6f}]") + + # Step 4: Verify dbias via finite differences + print("\nStep 4: dbias Gradient Verification (Finite Differences)") + + eps = 1e-3 + num_checks = 8 + rng = np.random.RandomState(99) + + print( + f"\n {'Index':>20} {'Analytic':>14} {'Numerical':>14} {'RelErr':>12} {'Status':>8}" + ) + print(" " + "-" * 72) + + all_grad_ok = True + for _ in range(num_checks): + h = rng.randint(0, args.nhead) + i = rng.randint(0, args.seqlen) + j = rng.randint(0, args.seqlen) + + bias_plus = bias.copy() + bias_plus[h, i, j] += eps + bias_minus = bias.copy() + bias_minus[h, i, j] -= eps + + O_p, _, _ = cpu_attention_fwd_bias(Q, K, V, prob.scale, bias_plus) + O_m, _, _ = cpu_attention_fwd_bias(Q, K, V, prob.scale, bias_minus) + + numerical = ((O_p * dO).sum() - (O_m * dO).sum()) / (2 * eps) + analytic = dbias[h, i, j] + + rel_err = abs(analytic - numerical) / (abs(numerical) + 1e-8) + ok = rel_err < 1e-2 + all_grad_ok = all_grad_ok and ok + idx_str = f"({h},{i},{j})" + print( + f" {idx_str:>20} {analytic:>14.6f} {numerical:>14.6f} {rel_err:>12.2e} {'OK' if ok else 'FAIL':>8}" + ) + + # Step 5: dbias structure analysis + print("\nStep 5: dbias Structure Analysis") + + print("\n Per-head dbias statistics:") + print(f" {'Head':>6} {'Mean':>12} {'Std':>12} {'Min':>12} {'Max':>12}") + print(" " + "-" * 56) + + for h in range(min(8, args.nhead)): + db_h = dbias[h] + print( + f" {h:>6} {db_h.mean():>12.6f} {db_h.std():>12.6f} " + f"{db_h.min():>12.6f} {db_h.max():>12.6f}" + ) + + # Step 6: Batch size effect on dbias + print("\nStep 6: Batch Size Effect on dbias") + print(" dbias = sum of per-sample dS / scale over batch dimension") + print(" Larger batch -> dbias aggregates more gradient signal") + + batch_sizes = [1, 2, 4, 8] + print( + f"\n {'Batch':>6} {'|dbias| mean':>14} {'|dbias| max':>14} {'dbias std':>14}" + ) + print(" " + "-" * 52) + + for b in batch_sizes: + Q_b = (np.random.randn(b, args.nhead, args.seqlen, args.hdim) * 0.3).astype( + np.float32 + ) + K_b = (np.random.randn(b, args.nhead, args.seqlen, args.hdim) * 0.3).astype( + np.float32 + ) + V_b = (np.random.randn(b, args.nhead, args.seqlen, args.hdim) * 0.3).astype( + np.float32 + ) + dO_b = (np.random.randn(b, args.nhead, args.seqlen, args.hdim) * 0.1).astype( + np.float32 + ) + + O_b, P_b, lse_b = cpu_attention_fwd_bias(Q_b, K_b, V_b, prob.scale, bias) + _, _, _, dbias_b = cpu_attention_bwd_dbias( + Q_b, + K_b, + V_b, + O_b, + dO_b, + P_b, + prob.scale, + bias, + ) + print( + f" {b:>6} {np.abs(dbias_b).mean():>14.6f} {np.abs(dbias_b).max():>14.6f} " + f"{dbias_b.std():>14.6f}" + ) + + # Step 7: GPU API pattern + print("\nStep 7: GPU Kernel Configuration") + print(" NOTE: The prebuilt library only has a forward kernel without bias.") + print(" For backward with dbias, compile kernels with:") + print() + print(" Forward: FmhaSignature().bias('bias') // elementwise bias") + print(" Backward: FmhaSignature()") + print(" .family('bwd_dq_dk_dv')") + print(" .bias('bias')") + print(" .dbias(true) // enable dbias computation") + print() + print(" In codegen JSON:") + print(" 'bias': 'bias', // forward: elementwise bias") + print(" 'dbias': true, // backward: compute bias gradient") + + # Summary + print("\n" + "=" * 70) + print(" dbias = sum_batch(P * (dP - D)) (gradient of elementwise bias)") + print(f" Shape: [{args.nhead}, {args.seqlen}, {args.seqlen}] (same as bias)") + print(f" Gradient check: {'PASS' if all_grad_ok else 'FAIL'}") + print(" Use case: learnable relative position biases (ALiBi, T5, etc.)") + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/29_sweep_seqlen.py b/dispatcher/examples/fmha/python/29_sweep_seqlen.py new file mode 100644 index 0000000000..49a030e750 --- /dev/null +++ b/dispatcher/examples/fmha/python/29_sweep_seqlen.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 29: Sweep Sequence Length + +Demonstrates how FMHA performance scales with sequence length. +FMHA has O(n^2) compute in seqlen (Q*K^T), so TFLOPS should increase +with longer sequences as the GPU becomes better utilized. + +Fixed: batch=2, nhead=8, hdim=128 +Sweep: seqlen in [32, 64, 128, 256, 512, 1024, 2048] + +Usage: + python3 29_sweep_seqlen.py + python3 29_sweep_seqlen.py --arch gfx942 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaKernelConfig, + FmhaProblem, + FmhaValidator, + cpu_attention_fwd, + detect_gpu_arch, + setup_fmha_dispatcher, +) + +BATCH = 2 +NHEAD = 8 +HDIM = 128 +SEQLENS = [32, 64, 128, 256, 512, 1024, 2048] + + +def main(): + parser = argparse.ArgumentParser(description="Sweep Sequence Length FMHA") + parser.add_argument("--arch", default=detect_gpu_arch()) + args = parser.parse_args() + + print("=" * 70) + print("Example 29: Sweep Sequence Length") + print("=" * 70) + + print(f"\n Fixed: batch={BATCH}, nhead={NHEAD}, hdim={HDIM}") + print(f" Sweep: seqlen in {SEQLENS}") + print(f" Arch: {args.arch}") + + validator = FmhaValidator(rtol=1e-2, atol=1e-2) + + # Step 1: JIT-compile FMHA kernel + print("\nStep 1: JIT-Compile FMHA Kernel") + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=HDIM, + hdim_v=HDIM, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + if not setup.success: + print(f" JIT build failed: {setup.error}") + return 1 + runner = setup.runner + print(f" JIT build: {setup.build_time_s:.1f}s") + + # Step 2: Sweep + print("\nStep 2: Sequence Length Sweep") + + hdr = f" {'SeqLen':>8} | {'Time(ms)':>10} | {'TFLOPS':>10} | {'MaxErr':>10} | {'Status':<6}" + print(f"\n{hdr}") + print(" " + "-" * 60) + + np.random.seed(42) + results = [] + + for seqlen in SEQLENS: + prob = FmhaProblem( + batch=BATCH, + nhead_q=NHEAD, + nhead_k=NHEAD, + seqlen_q=seqlen, + seqlen_k=seqlen, + hdim_q=HDIM, + hdim_v=HDIM, + ) + + Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float16) + K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float16) + V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float16) + + O_ref = cpu_attention_fwd( + Q.astype(np.float32), + K.astype(np.float32), + V.astype(np.float32), + prob.scale, + ) + + res = runner.run(Q, K, V, prob) + if not res.success: + print( + f" {seqlen:>8} | {'---':>10} | {'---':>10} | {'---':>10} | {'FAIL':<6}" + ) + results.append((seqlen, False, 0.0, 0.0, 0.0)) + continue + + max_err = float(np.abs(res.output.astype(np.float32) - O_ref).max()) + ok, _, _ = validator.check(res.output, O_ref) + tag = "PASS" if ok else "FAIL" + + print( + f" {seqlen:>8} | {res.time_ms:>10.4f} | {res.tflops:>10.2f} | {max_err:>10.2e} | {tag:<6}" + ) + results.append((seqlen, ok, res.time_ms, res.tflops, max_err)) + + # Step 3: Scaling analysis + print("\nStep 3: Scaling Analysis") + valid = [(s, t, tf) for s, ok, t, tf, _ in results if ok and tf > 0] + if len(valid) >= 2: + s0, _, tf0 = valid[0] + s_last, _, tf_last = valid[-1] + print(f" Shortest (seqlen={s0}): {tf0:.2f} TFLOPS") + print(f" Longest (seqlen={s_last}): {tf_last:.2f} TFLOPS") + print(f" Speedup: {tf_last / tf0:.1f}x TFLOPS improvement") + print(" Note: Longer sequences expose more parallelism to the GPU") + + # Summary + passed = sum(1 for _, ok, *_ in results if ok) + print("\n" + "=" * 70) + print(f" Results: {passed}/{len(results)} passed") + print(f" Fixed: B={BATCH} H={NHEAD} D={HDIM}") + print(f" Sweep: seqlen={SEQLENS}") + print(f" Status: {'PASS' if passed == len(results) else 'FAIL'}") + print("=" * 70) + + return 0 if passed == len(results) else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/30_sweep_batch.py b/dispatcher/examples/fmha/python/30_sweep_batch.py new file mode 100644 index 0000000000..f7ba82a2c4 --- /dev/null +++ b/dispatcher/examples/fmha/python/30_sweep_batch.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 30: Sweep Batch Size + +Demonstrates how FMHA performance scales with batch size. +FMHA compute scales linearly with batch, so time should increase +linearly while TFLOPS remains roughly constant once the GPU is saturated. + +Fixed: seqlen=128, nhead=8, hdim=128 +Sweep: batch in [1, 2, 4, 8, 16, 32] + +Usage: + python3 30_sweep_batch.py + python3 30_sweep_batch.py --arch gfx942 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaKernelConfig, + FmhaProblem, + FmhaValidator, + cpu_attention_fwd, + detect_gpu_arch, + setup_fmha_dispatcher, +) + +SEQLEN = 128 +NHEAD = 8 +HDIM = 128 +BATCHES = [1, 2, 4, 8, 16, 32] + + +def main(): + parser = argparse.ArgumentParser(description="Sweep Batch Size FMHA") + parser.add_argument("--arch", default=detect_gpu_arch()) + args = parser.parse_args() + + print("=" * 70) + print("Example 30: Sweep Batch Size") + print("=" * 70) + + print(f"\n Fixed: seqlen={SEQLEN}, nhead={NHEAD}, hdim={HDIM}") + print(f" Sweep: batch in {BATCHES}") + print(f" Arch: {args.arch}") + + validator = FmhaValidator(rtol=1e-2, atol=1e-2) + + # Step 1: JIT-compile FMHA kernel + print("\nStep 1: JIT-Compile FMHA Kernel") + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=HDIM, + hdim_v=HDIM, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + if not setup.success: + print(f" JIT build failed: {setup.error}") + return 1 + runner = setup.runner + print(f" JIT build: {setup.build_time_s:.1f}s") + + # Step 2: Sweep + print("\nStep 2: Batch Size Sweep") + + hdr = f" {'Batch':>8} | {'Time(ms)':>10} | {'TFLOPS':>10} | {'MaxErr':>10} | {'Status':<6}" + print(f"\n{hdr}") + print(" " + "-" * 60) + + np.random.seed(42) + results = [] + + for batch in BATCHES: + prob = FmhaProblem( + batch=batch, + nhead_q=NHEAD, + nhead_k=NHEAD, + seqlen_q=SEQLEN, + seqlen_k=SEQLEN, + hdim_q=HDIM, + hdim_v=HDIM, + ) + + Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float16) + K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float16) + V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float16) + + O_ref = cpu_attention_fwd( + Q.astype(np.float32), + K.astype(np.float32), + V.astype(np.float32), + prob.scale, + ) + + res = runner.run(Q, K, V, prob) + if not res.success: + print( + f" {batch:>8} | {'---':>10} | {'---':>10} | {'---':>10} | {'FAIL':<6}" + ) + results.append((batch, False, 0.0, 0.0, 0.0)) + continue + + max_err = float(np.abs(res.output.astype(np.float32) - O_ref).max()) + ok, _, _ = validator.check(res.output, O_ref) + tag = "PASS" if ok else "FAIL" + + print( + f" {batch:>8} | {res.time_ms:>10.4f} | {res.tflops:>10.2f} | {max_err:>10.2e} | {tag:<6}" + ) + results.append((batch, ok, res.time_ms, res.tflops, max_err)) + + # Step 3: Linearity analysis + print("\nStep 3: Linear Scaling Analysis") + valid = [(b, t, tf) for b, ok, t, tf, _ in results if ok and t > 0] + if len(valid) >= 2: + b0, t0, tf0 = valid[0] + b_last, t_last, tf_last = valid[-1] + batch_ratio = b_last / b0 + time_ratio = t_last / t0 + linearity = time_ratio / batch_ratio + + print( + f" Batch {b0} -> {b_last}: {batch_ratio:.0f}x batch, {time_ratio:.1f}x time" + ) + print(f" Linearity factor: {linearity:.2f} (1.0 = perfect linear scaling)") + print(f" TFLOPS range: {tf0:.2f} - {tf_last:.2f}") + + # Summary + passed = sum(1 for _, ok, *_ in results if ok) + print("\n" + "=" * 70) + print(f" Results: {passed}/{len(results)} passed") + print(f" Fixed: S={SEQLEN} H={NHEAD} D={HDIM}") + print(f" Sweep: batch={BATCHES}") + print(f" Status: {'PASS' if passed == len(results) else 'FAIL'}") + print("=" * 70) + + return 0 if passed == len(results) else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/31_sweep_nhead.py b/dispatcher/examples/fmha/python/31_sweep_nhead.py new file mode 100644 index 0000000000..bd3374eaf7 --- /dev/null +++ b/dispatcher/examples/fmha/python/31_sweep_nhead.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 31: Sweep Number of Heads (MHA + GQA) + +Demonstrates FMHA performance across different head counts, including +Grouped Query Attention (GQA) where nhead_q > nhead_k. + +Part 1 - MHA sweep: nhead_q == nhead_k +Part 2 - GQA variants: nhead_q != nhead_k (multiple Q heads share K/V) + +Fixed: batch=2, seqlen=128, hdim=128 + +Usage: + python3 31_sweep_nhead.py + python3 31_sweep_nhead.py --arch gfx942 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaKernelConfig, + FmhaProblem, + FmhaValidator, + cpu_attention_fwd, + detect_gpu_arch, + setup_fmha_dispatcher, +) + +BATCH = 2 +SEQLEN = 128 +HDIM = 128 + +MHA_NHEADS = [1, 2, 4, 8, 16, 32] +GQA_CONFIGS = [ + (8, 1, "GQA 8:1"), + (16, 4, "GQA 4:1"), + (32, 8, "GQA 4:1"), +] + + +def run_sweep(runner, validator, configs, label): + """Run a sweep over (nhead_q, nhead_k) configurations.""" + hdr = f" {'nhead_q':>8} | {'nhead_k':>8} | {'Time(ms)':>10} | {'TFLOPS':>10} | {'MaxErr':>10} | {'Status':<6}" + print(f"\n{hdr}") + print(" " + "-" * 70) + + np.random.seed(42) + results = [] + + for nhead_q, nhead_k in configs: + prob = FmhaProblem( + batch=BATCH, + nhead_q=nhead_q, + nhead_k=nhead_k, + seqlen_q=SEQLEN, + seqlen_k=SEQLEN, + hdim_q=HDIM, + hdim_v=HDIM, + ) + + Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float16) + K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float16) + V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float16) + + O_ref = cpu_attention_fwd( + Q.astype(np.float32), + K.astype(np.float32), + V.astype(np.float32), + prob.scale, + ) + + res = runner.run(Q, K, V, prob) + if not res.success: + print( + f" {nhead_q:>8} | {nhead_k:>8} | {'---':>10} | {'---':>10} | {'---':>10} | {'FAIL':<6}" + ) + results.append((nhead_q, nhead_k, False, 0.0, 0.0, 0.0)) + continue + + max_err = float(np.abs(res.output.astype(np.float32) - O_ref).max()) + ok, _, _ = validator.check(res.output, O_ref) + tag = "PASS" if ok else "FAIL" + + print( + f" {nhead_q:>8} | {nhead_k:>8} | {res.time_ms:>10.4f} | {res.tflops:>10.2f} | {max_err:>10.2e} | {tag:<6}" + ) + results.append((nhead_q, nhead_k, ok, res.time_ms, res.tflops, max_err)) + + return results + + +def main(): + parser = argparse.ArgumentParser(description="Sweep Number of Heads FMHA") + parser.add_argument("--arch", default=detect_gpu_arch()) + args = parser.parse_args() + + print("=" * 70) + print("Example 31: Sweep Number of Heads (MHA + GQA)") + print("=" * 70) + + print(f"\n Fixed: batch={BATCH}, seqlen={SEQLEN}, hdim={HDIM}") + print(f" Arch: {args.arch}") + + validator = FmhaValidator(rtol=1e-2, atol=1e-2) + + # Step 1: JIT-compile FMHA kernel + print("\nStep 1: JIT-Compile FMHA Kernel") + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=HDIM, + hdim_v=HDIM, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + if not setup.success: + print(f" JIT build failed: {setup.error}") + return 1 + runner = setup.runner + print(f" JIT build: {setup.build_time_s:.1f}s") + + # Step 2: MHA sweep (nhead_q == nhead_k) + print("\nStep 2: MHA Sweep (nhead_q == nhead_k)") + mha_configs = [(n, n) for n in MHA_NHEADS] + mha_results = run_sweep(runner, validator, mha_configs, "MHA") + + # Step 3: GQA sweep (nhead_q > nhead_k) + print("\nStep 3: GQA Sweep (nhead_q > nhead_k)") + print(" GQA: multiple Q heads share fewer K/V heads") + gqa_configs = [(nq, nk) for nq, nk, _ in GQA_CONFIGS] + gqa_results = run_sweep(runner, validator, gqa_configs, "GQA") + + # Step 4: Comparison + print("\nStep 4: MHA vs GQA Comparison") + all_results = mha_results + gqa_results + valid_mha = [(nq, nk, tf) for nq, nk, ok, _, tf, _ in mha_results if ok and tf > 0] + valid_gqa = [(nq, nk, tf) for nq, nk, ok, _, tf, _ in gqa_results if ok and tf > 0] + + if valid_mha: + best_mha = max(valid_mha, key=lambda x: x[2]) + print(f" Best MHA: nhead={best_mha[0]}, {best_mha[2]:.2f} TFLOPS") + if valid_gqa: + best_gqa = max(valid_gqa, key=lambda x: x[2]) + print( + f" Best GQA: nhead_q={best_gqa[0]}, nhead_k={best_gqa[1]}, {best_gqa[2]:.2f} TFLOPS" + ) + print(f" GQA saves K/V memory: {best_gqa[0]}:{best_gqa[1]} ratio") + + # Summary + passed = sum(1 for *_, ok, _, _, _ in all_results if ok) + total = len(all_results) + print("\n" + "=" * 70) + print(f" Results: {passed}/{total} passed") + print(f" Fixed: B={BATCH} S={SEQLEN} D={HDIM}") + print(f" MHA: nhead={MHA_NHEADS}") + print(f" GQA: {[(nq, nk) for nq, nk, _ in GQA_CONFIGS]}") + print(f" Status: {'PASS' if passed == total else 'FAIL'}") + print("=" * 70) + + return 0 if passed == total else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/32_sweep_hdim.py b/dispatcher/examples/fmha/python/32_sweep_hdim.py new file mode 100644 index 0000000000..d6fc095681 --- /dev/null +++ b/dispatcher/examples/fmha/python/32_sweep_hdim.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 32: Sweep Head Dimension + +Demonstrates FMHA across different head dimensions (32, 64, 128, 256). +The prebuilt library only supports hdim=128; other head dimensions are +validated via CPU reference only. + +Fixed: batch=2, nhead=8, seqlen=128 +Sweep: hdim in [32, 64, 128, 256] + +Usage: + python3 32_sweep_hdim.py + python3 32_sweep_hdim.py --arch gfx942 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaKernelConfig, + FmhaProblem, + FmhaValidator, + cpu_attention_fwd, + detect_gpu_arch, + setup_fmha_dispatcher, +) + +BATCH = 2 +NHEAD = 8 +SEQLEN = 128 +HDIMS = [32, 64, 128, 256] +GPU_SUPPORTED_HDIM = 128 + + +def main(): + parser = argparse.ArgumentParser(description="Sweep Head Dimension FMHA") + parser.add_argument("--arch", default=detect_gpu_arch()) + args = parser.parse_args() + + print("=" * 70) + print("Example 32: Sweep Head Dimension") + print("=" * 70) + + print(f"\n Fixed: batch={BATCH}, nhead={NHEAD}, seqlen={SEQLEN}") + print(f" Sweep: hdim in {HDIMS}") + print(f" Arch: {args.arch}") + print(f" Note: Only hdim={GPU_SUPPORTED_HDIM} runs on GPU (prebuilt lib)") + + validator = FmhaValidator(rtol=1e-2, atol=1e-2) + + # Step 1: JIT-compile FMHA kernel (hdim=128) + print("\nStep 1: JIT-Compile FMHA Kernel") + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=GPU_SUPPORTED_HDIM, + hdim_v=GPU_SUPPORTED_HDIM, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + runner = None + if not setup.success: + print(f" JIT build failed: {setup.error}") + print(" Will run CPU reference only") + else: + runner = setup.runner + print(f" JIT build: {setup.build_time_s:.1f}s") + + # Step 2: CPU reference for all hdims + print("\nStep 2: CPU Reference for All Head Dimensions") + + np.random.seed(42) + cpu_data = {} + + print( + f"\n {'hdim':>6} | {'Scale':>8} | {'FLOPs':>14} | {'O Range':>22} | {'Finite':<6}" + ) + print(" " + "-" * 66) + + for hdim in HDIMS: + prob = FmhaProblem( + batch=BATCH, + nhead_q=NHEAD, + nhead_k=NHEAD, + seqlen_q=SEQLEN, + seqlen_k=SEQLEN, + hdim_q=hdim, + hdim_v=hdim, + ) + + Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32) + K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32) + V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32) + + O_ref = cpu_attention_fwd(Q, K, V, prob.scale) + is_finite = bool(np.all(np.isfinite(O_ref))) + o_range = f"[{O_ref.min():.4f}, {O_ref.max():.4f}]" + + print( + f" {hdim:>6} | {prob.scale:>8.4f} | {prob.num_ops:>14,} | {o_range:>22} | {'OK' if is_finite else 'NaN!':<6}" + ) + cpu_data[hdim] = (Q, K, V, O_ref, prob) + + # Step 3: GPU sweep (only hdim=128 supported) + print("\nStep 3: GPU Sweep") + + hdr = f" {'hdim':>6} | {'Time(ms)':>10} | {'TFLOPS':>10} | {'MaxErr':>10} | {'Status':<10}" + print(f"\n{hdr}") + print(" " + "-" * 60) + + results = [] + + for hdim in HDIMS: + Q, K, V, O_ref, prob = cpu_data[hdim] + + if hdim != GPU_SUPPORTED_HDIM or runner is None: + print( + f" {hdim:>6} | {'---':>10} | {'---':>10} | {'---':>10} | {'CPU only':<10}" + ) + results.append((hdim, True, 0.0, 0.0, 0.0)) + continue + + Q_f16 = Q.astype(np.float16) + K_f16 = K.astype(np.float16) + V_f16 = V.astype(np.float16) + + res = runner.run(Q_f16, K_f16, V_f16, prob) + if not res.success: + print( + f" {hdim:>6} | {'---':>10} | {'---':>10} | {'---':>10} | {'FAIL':<10}" + ) + results.append((hdim, False, 0.0, 0.0, 0.0)) + continue + + max_err = float(np.abs(res.output.astype(np.float32) - O_ref).max()) + ok, _, _ = validator.check(res.output, O_ref) + tag = "PASS" if ok else "FAIL" + + print( + f" {hdim:>6} | {res.time_ms:>10.4f} | {res.tflops:>10.2f} | {max_err:>10.2e} | {tag:<10}" + ) + results.append((hdim, ok, res.time_ms, res.tflops, max_err)) + + # Step 4: hdim analysis + print("\nStep 4: Head Dimension Analysis") + print(" Each hdim requires a dedicated compiled kernel:") + for hdim in HDIMS: + gpu_status = "prebuilt" if hdim == GPU_SUPPORTED_HDIM else "needs JIT" + tile_hint = f"tile_k0max={hdim}" + print(f" hdim={hdim:>3}: {gpu_status:<10} ({tile_hint})") + + print("\n Compute scales linearly with hdim (via Q*K^T and attn*V).") + print(" Larger hdim = more work per token, fewer tokens processed per CU.") + + # Summary + passed = sum(1 for _, ok, *_ in results if ok) + total = len(results) + print("\n" + "=" * 70) + print(f" Results: {passed}/{total} passed") + print(f" Fixed: B={BATCH} H={NHEAD} S={SEQLEN}") + print(f" Sweep: hdim={HDIMS}") + print(f" GPU: hdim={GPU_SUPPORTED_HDIM} only (prebuilt)") + print(f" Status: {'PASS' if passed == total else 'FAIL'}") + print("=" * 70) + + return 0 if passed == total else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/33_bwd_masks_fmha.py b/dispatcher/examples/fmha/python/33_bwd_masks_fmha.py new file mode 100644 index 0000000000..b5da6a2adc --- /dev/null +++ b/dispatcher/examples/fmha/python/33_bwd_masks_fmha.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 33: Backward Pass with Causal Masks + +Demonstrates the FMHA backward pass with causal mask variants: +1. no_mask -- Full attention (baseline) +2. top_left -- Causal mask aligned to top-left corner +3. bottom_right -- Causal mask aligned to bottom-right corner + +For each mask type: +- Forward: out = softmax(mask(Q @ K^T * scale)) @ V +- Backward: dQ, dK, dV via analytical gradients through the masked softmax + +CPU backward reference: + dP = dO @ V^T + D = rowsum(dO * out) (per-query-position scalar) + dS = P * (dP - D) + dQ = scale * dS @ K + dK = scale * dS^T @ Q + dV = P^T @ dO + +Usage: + python3 33_bwd_masks_fmha.py + python3 33_bwd_masks_fmha.py --seqlen-q 128 --seqlen-k 192 + python3 33_bwd_masks_fmha.py --arch gfx942 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaKernelConfig, + FmhaProblem, + setup_fmha_dispatcher, + detect_gpu_arch, +) + + +def make_causal_mask_top_left(seqlen_q: int, seqlen_k: int) -> np.ndarray: + """Causal mask aligned to top-left: position i attends to positions <= i.""" + row = np.arange(seqlen_q).reshape(-1, 1) + col = np.arange(seqlen_k).reshape(1, -1) + return (col <= row).astype(np.float32) + + +def make_causal_mask_bottom_right(seqlen_q: int, seqlen_k: int) -> np.ndarray: + """Causal mask aligned to bottom-right: accounts for kv longer than q.""" + offset = seqlen_k - seqlen_q + row = np.arange(seqlen_q).reshape(-1, 1) + col = np.arange(seqlen_k).reshape(1, -1) + return (col <= row + offset).astype(np.float32) + + +def cpu_masked_fwd_with_intermediates( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + scale: float, + mask: np.ndarray, +) -> tuple: + """Forward pass with mask, returning out, P, and LSE for backward. + + Args: + Q: [B, H, Sq, D] K: [B, H, Sk, D] V: [B, H, Sk, Dv] + mask: [Sq, Sk] broadcast over batch and head + + Returns: (out, P, lse) + """ + S = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale + mask_broad = mask[np.newaxis, np.newaxis, :, :] + S = np.where(mask_broad > 0, S, -1e9) + S_max = S.max(axis=-1, keepdims=True) + S_exp = np.exp(S - S_max) + S_sum = S_exp.sum(axis=-1, keepdims=True) + P = S_exp / S_sum + out = np.matmul(P, V) + lse = (np.log(S_sum.squeeze(-1)) + S_max.squeeze(-1)).astype(np.float32) + return out, P, lse + + +def cpu_masked_bwd( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + out: np.ndarray, + dO: np.ndarray, + P: np.ndarray, + scale: float, +) -> tuple: + """CPU backward through masked softmax attention. + + P already incorporates the mask (zeroed-out positions have P=0). + + Returns: (dQ, dK, dV, D) + """ + D = (dO * out).sum(axis=-1, keepdims=True) + dP = np.matmul(dO, V.transpose(0, 1, 3, 2)) + dS = P * (dP - D) + dQ = np.matmul(dS, K) * scale + dK = np.matmul(dS.transpose(0, 1, 3, 2), Q) * scale + dV = np.matmul(P.transpose(0, 1, 3, 2), dO) + return dQ, dK, dV, D.squeeze(-1) + + +def main(): + parser = argparse.ArgumentParser(description="Backward Pass with Causal Masks") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--nhead", type=int, default=8) + parser.add_argument("--seqlen-q", type=int, default=64) + parser.add_argument("--seqlen-k", type=int, default=64) + parser.add_argument("--hdim", type=int, default=128) + args = parser.parse_args() + + print("=" * 70) + print("Example 33: Backward Pass with Causal Masks") + print("=" * 70) + + sq, sk = args.seqlen_q, args.seqlen_k + prob = FmhaProblem( + batch=args.batch, + nhead_q=args.nhead, + nhead_k=args.nhead, + seqlen_q=sq, + seqlen_k=sk, + hdim_q=args.hdim, + hdim_v=args.hdim, + ) + + print(f"\n Problem: B={prob.batch} H={prob.nhead_q} Sq={sq} Sk={sk} D={args.hdim}") + print(f" Scale: {prob.scale:.6f}") + print(f" Arch: {args.arch}") + + # --- JIT compile a basic fp16 h128 fwd kernel --- + print("\n--- JIT Compilation ---") + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=args.hdim, + hdim_v=args.hdim, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + if setup.success: + print(f" Fwd kernel compiled: {setup.build_time_s:.1f}s") + print(f" Library: {setup.library_path}") + print(" Note: Backward requires family='bwd' kernel (separate JIT)") + else: + print(f" JIT build: {setup.error}") + print(" Continuing with CPU reference only") + + # --- Generate data --- + np.random.seed(42) + Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32) + K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32) + V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32) + dO = (np.random.randn(*prob.o_shape()) * 0.1).astype(np.float32) + + # --- Build masks --- + masks = { + "no_mask": np.ones((sq, sk), dtype=np.float32), + "top_left": make_causal_mask_top_left(sq, sk), + "bottom_right": make_causal_mask_bottom_right(sq, sk), + } + + # --- Per-mask forward + backward --- + print( + f"\n {'Mask':<16} {'Density':>8} | {'|dQ|':>10} {'|dK|':>10} {'|dV|':>10}" + f" | {'dQ vs base':>10} {'dK vs base':>10} {'dV vs base':>10}" + ) + print(" " + "-" * 98) + + base_grads = None + all_grads = {} + + for name, mask in masks.items(): + density = mask.sum() / mask.size * 100 + + out, P, lse = cpu_masked_fwd_with_intermediates(Q, K, V, prob.scale, mask) + dQ, dK, dV, D = cpu_masked_bwd(Q, K, V, out, dO, P, prob.scale) + + dq_norm = float(np.abs(dQ).mean()) + dk_norm = float(np.abs(dK).mean()) + dv_norm = float(np.abs(dV).mean()) + + if base_grads is None: + base_grads = (dQ, dK, dV) + diff_str = f"{'---':>10} {'---':>10} {'---':>10}" + else: + dq_diff = float(np.abs(dQ - base_grads[0]).max()) + dk_diff = float(np.abs(dK - base_grads[1]).max()) + dv_diff = float(np.abs(dV - base_grads[2]).max()) + diff_str = f"{dq_diff:>10.2e} {dk_diff:>10.2e} {dv_diff:>10.2e}" + + print( + f" {name:<16} {density:>7.1f}% | {dq_norm:>10.4e} {dk_norm:>10.4e} {dv_norm:>10.4e}" + f" | {diff_str}" + ) + all_grads[name] = (dQ, dK, dV, D) + + # --- Detailed backward breakdown for each mask --- + print("\n--- Backward Stage Details ---") + + for name, mask in masks.items(): + dQ, dK, dV, D = all_grads[name] + out, P, lse = cpu_masked_fwd_with_intermediates(Q, K, V, prob.scale, mask) + + print(f"\n [{name}]") + print(" Stage 1 (dot_do_o): D = rowsum(dO * out)") + print(f" D shape: {D.shape}, range: [{D.min():.6f}, {D.max():.6f}]") + print(" Stage 2 (dq_dk_dv):") + print(f" dQ range: [{dQ.min():.4e}, {dQ.max():.4e}]") + print(f" dK range: [{dK.min():.4e}, {dK.max():.4e}]") + print(f" dV range: [{dV.min():.4e}, {dV.max():.4e}]") + + p_sparsity = (P < 1e-9).sum() / P.size * 100 + print(f" P sparsity (< 1e-9): {p_sparsity:.1f}%") + + # --- Gradient norm comparison across masks --- + print("\n--- Gradient L2 Norms ---") + print(f"\n {'Mask':<16} {'||dQ||_2':>12} {'||dK||_2':>12} {'||dV||_2':>12}") + print(" " + "-" * 54) + + for name in masks: + dQ, dK, dV, _ = all_grads[name] + l2_dq = float(np.sqrt((dQ**2).sum())) + l2_dk = float(np.sqrt((dK**2).sum())) + l2_dv = float(np.sqrt((dV**2).sum())) + print(f" {name:<16} {l2_dq:>12.4e} {l2_dk:>12.4e} {l2_dv:>12.4e}") + + # --- Mask pattern visualization --- + print("\n--- Mask Patterns (first 8x8 corner) ---") + view = min(8, sq, sk) + for name, mask in masks.items(): + corner = mask[:view, :view] + print(f"\n {name}:") + for r in range(view): + row_str = " ".join("█" if corner[r, c] > 0 else "·" for c in range(view)) + print(f" {row_str}") + + # --- Backward API pattern --- + print("\n--- Backward GPU API Pattern ---") + print(" The GPU backward for masked attention would use:") + print(" FmhaKernelConfig(family='bwd', mask='top_left', ...)") + print(" 3-stage backward plan:") + print(" Stage 1: bwd_dot_do_o -- D = rowsum(dO * out)") + print(" Stage 2: bwd_dq_dk_dv -- compute dQ, dK, dV with mask") + print(" Stage 3: bwd_convert_dq -- optional dtype conversion") + + # --- Summary --- + print("\n" + "=" * 70) + print(" Mask variants: no_mask, top_left, bottom_right") + print(" Backward math: dP = dO @ V^T, dS = P*(dP - D)") + print(" dQ = scale*dS@K, dK = scale*dS^T@Q, dV = P^T@dO") + print(" Causal effect: Masked positions get P=0, zeroing their gradient flow") + print(" GPU: Requires bwd-family JIT kernel with mask support") + print(" Status: DEMO") + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/34_bwd_gqa_fmha.py b/dispatcher/examples/fmha/python/34_bwd_gqa_fmha.py new file mode 100644 index 0000000000..7bfdcc1788 --- /dev/null +++ b/dispatcher/examples/fmha/python/34_bwd_gqa_fmha.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 34: Backward Pass with GQA (Grouped-Query Attention) + +Demonstrates the FMHA backward pass when nhead_q != nhead_k. +GQA groups multiple query heads per KV head. The backward pass +must account for this by: + - Expanding K/V heads via np.repeat for dQ computation + - Summing dK/dV over query head groups back to KV head count + +Tested GQA ratios: 1:1 (MHA), 2:1, 4:1, 8:1 + +CPU backward reference: + K_exp = repeat(K, ratio) # [B, Hq, Sk, D] + V_exp = repeat(V, ratio) # [B, Hq, Sk, Dv] + dQ = scale * (P * (dO@V_exp^T - D)) @ K_exp + dK_exp = scale * (P * (dO@V_exp^T - D))^T @ Q + dV_exp = P^T @ dO + dK = sum_over_groups(dK_exp) # [B, Hk, Sk, D] + dV = sum_over_groups(dV_exp) # [B, Hk, Sk, Dv] + +Usage: + python3 34_bwd_gqa_fmha.py + python3 34_bwd_gqa_fmha.py --nhead-q 32 + python3 34_bwd_gqa_fmha.py --arch gfx942 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaKernelConfig, + FmhaProblem, + setup_fmha_dispatcher, + detect_gpu_arch, +) + + +def cpu_fwd_with_intermediates( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + scale: float, +) -> tuple: + """Forward pass returning out, P, LSE (handles GQA via repeat).""" + nhead_q, nhead_k = Q.shape[1], K.shape[1] + if nhead_q != nhead_k: + ratio = nhead_q // nhead_k + K = np.repeat(K, ratio, axis=1) + V = np.repeat(V, ratio, axis=1) + S = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale + S_max = S.max(axis=-1, keepdims=True) + S_exp = np.exp(S - S_max) + S_sum = S_exp.sum(axis=-1, keepdims=True) + P = S_exp / S_sum + out = np.matmul(P, V) + lse = (np.log(S_sum.squeeze(-1)) + S_max.squeeze(-1)).astype(np.float32) + return out, P, lse + + +def cpu_bwd_gqa( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + out: np.ndarray, + dO: np.ndarray, + P: np.ndarray, + scale: float, + nhead_q: int, + nhead_k: int, +) -> tuple: + """CPU backward with GQA head grouping. + + P is already computed on expanded heads [B, Hq, Sq, Sk]. + K, V are original (unexpanded) [B, Hk, Sk, D]. + + Returns: (dQ, dK, dV) where dK/dV have shape [B, Hk, Sk, ...] + """ + ratio = nhead_q // nhead_k + K_exp = np.repeat(K, ratio, axis=1) + V_exp = np.repeat(V, ratio, axis=1) + + D = (dO * out).sum(axis=-1, keepdims=True) + dP = np.matmul(dO, V_exp.transpose(0, 1, 3, 2)) + dS = P * (dP - D) + + dQ = np.matmul(dS, K_exp) * scale + + dK_exp = np.matmul(dS.transpose(0, 1, 3, 2), Q) * scale + dV_exp = np.matmul(P.transpose(0, 1, 3, 2), dO) + + B = Q.shape[0] + Sk, Dq = K.shape[2], K.shape[3] + Dv = V.shape[3] + + dK = dK_exp.reshape(B, nhead_k, ratio, Sk, Dq).sum(axis=2) + dV = dV_exp.reshape(B, nhead_k, ratio, Sk, Dv).sum(axis=2) + + return dQ, dK, dV + + +def main(): + parser = argparse.ArgumentParser(description="Backward Pass with GQA") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--nhead-q", type=int, default=16) + parser.add_argument("--seqlen", type=int, default=64) + parser.add_argument("--hdim", type=int, default=128) + args = parser.parse_args() + + print("=" * 70) + print("Example 34: Backward Pass with GQA") + print("=" * 70) + + hq = args.nhead_q + + gqa_ratios = [] + for ratio in [1, 2, 4, 8]: + if hq % ratio == 0 and hq // ratio >= 1: + gqa_ratios.append(ratio) + + print(f"\n nhead_q: {hq}") + print(f" Ratios: {', '.join(f'{r}:1' for r in gqa_ratios)}") + print(f" Problem: B={args.batch} S={args.seqlen} D={args.hdim}") + + # --- JIT compile a basic fp16 h128 fwd kernel --- + print("\n--- JIT Compilation ---") + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=args.hdim, + hdim_v=args.hdim, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + if setup.success: + print(f" Fwd kernel compiled: {setup.build_time_s:.1f}s") + print(" Note: Backward GQA requires bwd-family kernel (separate JIT)") + else: + print(f" JIT build: {setup.error}") + print(" Continuing with CPU reference only") + + # --- Sweep GQA ratios --- + print("\n--- Backward Gradients per GQA Ratio ---") + print( + f"\n {'#':<3} {'Ratio':<8} {'Hq':>4} {'Hk':>4} " + f"| {'|dQ| mean':>10} {'|dK| mean':>10} {'|dV| mean':>10} " + f"| {'dK shape':>18} {'dV shape':>18}" + ) + print(" " + "-" * 104) + + all_results = {} + + for i, ratio in enumerate(gqa_ratios, 1): + hk = hq // ratio + prob = FmhaProblem( + batch=args.batch, + nhead_q=hq, + nhead_k=hk, + seqlen_q=args.seqlen, + seqlen_k=args.seqlen, + hdim_q=args.hdim, + hdim_v=args.hdim, + ) + + np.random.seed(42 + i) + Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32) + K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32) + V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32) + dO = (np.random.randn(*prob.o_shape()) * 0.1).astype(np.float32) + + out, P, lse = cpu_fwd_with_intermediates(Q, K, V, prob.scale) + dQ, dK, dV = cpu_bwd_gqa(Q, K, V, out, dO, P, prob.scale, hq, hk) + + dq_mean = float(np.abs(dQ).mean()) + dk_mean = float(np.abs(dK).mean()) + dv_mean = float(np.abs(dV).mean()) + + label = f"{ratio}:1" + if ratio == 1: + label += " MHA" + elif hk == 1: + label += " MQA" + + print( + f" {i:<3} {label:<8} {hq:>4} {hk:>4} " + f"| {dq_mean:>10.4e} {dk_mean:>10.4e} {dv_mean:>10.4e} " + f"| {str(dK.shape):>18} {str(dV.shape):>18}" + ) + all_results[ratio] = (dQ, dK, dV, Q, K, V, out, dO, P, prob) + + # --- Verify GQA backward via expanded MHA --- + print("\n--- GQA Backward Equivalence Check ---") + print(" Verifying: GQA bwd == MHA bwd with expanded K/V, then summed") + + for ratio in gqa_ratios: + if ratio == 1: + continue + + dQ_gqa, dK_gqa, dV_gqa, Q, K, V, out, dO, P, prob = all_results[ratio] + hk = hq // ratio + + K_exp = np.repeat(K, ratio, axis=1) + V_exp = np.repeat(V, ratio, axis=1) + + O_mha, P_mha, _ = cpu_fwd_with_intermediates(Q, K_exp, V_exp, prob.scale) + dQ_mha, dK_mha, dV_mha = cpu_bwd_gqa( + Q, + K_exp, + V_exp, + O_mha, + dO, + P_mha, + prob.scale, + hq, + hq, + ) + + B = Q.shape[0] + Sk = K.shape[2] + dK_mha_grouped = dK_mha.reshape(B, hk, ratio, Sk, K.shape[3]).sum(axis=2) + dV_mha_grouped = dV_mha.reshape(B, hk, ratio, Sk, V.shape[3]).sum(axis=2) + + dq_err = float(np.abs(dQ_gqa - dQ_mha).max()) + dk_err = float(np.abs(dK_gqa - dK_mha_grouped).max()) + dv_err = float(np.abs(dV_gqa - dV_mha_grouped).max()) + + tag = "PASS" if max(dq_err, dk_err, dv_err) < 1e-5 else "FAIL" + print( + f" Ratio {ratio}:1 -- dQ err={dq_err:.2e} dK err={dk_err:.2e} " + f"dV err={dv_err:.2e} {tag}" + ) + + # --- Gradient accumulation analysis --- + print("\n--- Head-Group Gradient Accumulation ---") + print(" When ratio > 1, dK/dV are summed over query heads in each group.") + print(" Higher ratio -> more terms summed -> larger gradient magnitudes.\n") + + print(f" {'Ratio':<8} {'||dK||_2':>12} {'||dV||_2':>12} {'dK/dV ratio':>12}") + print(" " + "-" * 48) + + for ratio in gqa_ratios: + dQ, dK, dV, *_ = all_results[ratio] + l2_dk = float(np.sqrt((dK**2).sum())) + l2_dv = float(np.sqrt((dV**2).sum())) + dk_dv_ratio = l2_dk / (l2_dv + 1e-12) + print(f" {ratio}:1{'':<4} {l2_dk:>12.4e} {l2_dv:>12.4e} {dk_dv_ratio:>12.2f}") + + # --- Backward GPU API pattern --- + print("\n--- Backward GPU API Pattern ---") + print(" GPU backward with GQA dispatches with nhead_q != nhead_k.") + print(" The dq_dk_dv kernel handles head grouping internally:") + print(" - dQ: computed per query head (no grouping needed)") + print(" - dK, dV: accumulated across head groups via atomicAdd") + print(" or multi-buffer reduction (deterministic mode)") + + # --- Summary --- + print("\n" + "=" * 70) + print(f" GQA ratios tested: {len(gqa_ratios)}") + print(" Backward math: expand K/V -> compute grads -> sum dK/dV") + print(" Equivalence: GQA bwd == MHA(expanded) bwd + group sum") + print(" GPU: Requires bwd-family JIT kernel") + print(" Status: DEMO") + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/35_bwd_bf16_fmha.py b/dispatcher/examples/fmha/python/35_bwd_bf16_fmha.py new file mode 100644 index 0000000000..2021ca22cc --- /dev/null +++ b/dispatcher/examples/fmha/python/35_bwd_bf16_fmha.py @@ -0,0 +1,270 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 35: Backward Pass with BF16 Data Type + +Demonstrates the FMHA backward pass with bfloat16 precision. + +BF16 differences from FP16: + - 8-bit exponent (same as fp32) vs fp16's 5-bit + - 7-bit mantissa vs fp16's 10-bit + - Larger dynamic range but lower precision + +Tolerance guidance for backward: + - fp16 bwd: rtol=1.6e-2 typically sufficient + - bf16 bwd: rtol=3.2e-2 for hdim > 128 (less mantissa precision) + - bf16 bwd: rtol=2.0e-2 for hdim <= 128 + +CPU backward reference is computed in float32, then compared against +bf16-quantized inputs to measure the precision impact. + +Usage: + python3 35_bwd_bf16_fmha.py + python3 35_bwd_bf16_fmha.py --hdim 256 + python3 35_bwd_bf16_fmha.py --arch gfx942 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaKernelConfig, + FmhaProblem, + setup_fmha_dispatcher, + detect_gpu_arch, + cpu_attention_bwd, +) + + +def to_bf16(arr: np.ndarray) -> np.ndarray: + """Convert float32 -> bfloat16 (stored as uint16 with bf16 bit pattern).""" + f32 = arr.astype(np.float32) + u32 = f32.view(np.uint32) + return (u32 >> 16).astype(np.uint16) + + +def bf16_to_f32(arr_u16: np.ndarray) -> np.ndarray: + """Convert bfloat16 (uint16) -> float32.""" + u32 = arr_u16.astype(np.uint32) << 16 + return u32.view(np.float32) + + +def cpu_fwd_with_intermediates( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + scale: float, +) -> tuple: + """Forward pass returning out, P, LSE.""" + S = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale + S_max = S.max(axis=-1, keepdims=True) + S_exp = np.exp(S - S_max) + S_sum = S_exp.sum(axis=-1, keepdims=True) + P = S_exp / S_sum + out = np.matmul(P, V) + lse = (np.log(S_sum.squeeze(-1)) + S_max.squeeze(-1)).astype(np.float32) + return out, P, lse + + +def get_bwd_tolerance(dtype: str, hdim: int) -> tuple: + """Recommended tolerances for backward pass validation.""" + if dtype == "bf16": + if hdim > 128: + return 3.2e-2, 3.2e-2 + return 2.0e-2, 2.0e-2 + return 1.6e-2, 1.6e-2 + + +def main(): + parser = argparse.ArgumentParser(description="Backward Pass with BF16") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--nhead", type=int, default=8) + parser.add_argument("--seqlen", type=int, default=64) + parser.add_argument("--hdim", type=int, default=128) + args = parser.parse_args() + + print("=" * 70) + print("Example 35: Backward Pass with BF16") + print("=" * 70) + + prob = FmhaProblem( + batch=args.batch, + nhead_q=args.nhead, + nhead_k=args.nhead, + seqlen_q=args.seqlen, + seqlen_k=args.seqlen, + hdim_q=args.hdim, + hdim_v=args.hdim, + ) + + print(f"\n Problem: B={prob.batch} H={prob.nhead_q} S={args.seqlen} D={args.hdim}") + print(f" Scale: {prob.scale:.6f}") + print(f" Arch: {args.arch}") + + # --- JIT compile a basic fp16 h128 fwd kernel --- + print("\n--- JIT Compilation ---") + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=args.hdim, + hdim_v=args.hdim, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + if setup.success: + print(f" Fwd kernel compiled: {setup.build_time_s:.1f}s") + print( + " Note: Native bf16 bwd kernel requires separate JIT with data_type='bf16'" + ) + else: + print(f" JIT build: {setup.error}") + print(" Continuing with CPU reference only") + + # --- Generate data in both dtypes --- + np.random.seed(42) + Q_f32 = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32) + K_f32 = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32) + V_f32 = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32) + dO_f32 = (np.random.randn(*prob.o_shape()) * 0.1).astype(np.float32) + + Q_fp16 = Q_f32.astype(np.float16).astype(np.float32) + K_fp16 = K_f32.astype(np.float16).astype(np.float32) + V_fp16 = V_f32.astype(np.float16).astype(np.float32) + dO_fp16 = dO_f32.astype(np.float16).astype(np.float32) + + Q_bf16 = bf16_to_f32(to_bf16(Q_f32)) + K_bf16 = bf16_to_f32(to_bf16(K_f32)) + V_bf16 = bf16_to_f32(to_bf16(V_f32)) + dO_bf16 = bf16_to_f32(to_bf16(dO_f32)) + + # --- Quantization error comparison --- + print("\n--- Quantization Error ---") + print( + f"\n {'Tensor':<6} {'FP16 quant err':>16} {'BF16 quant err':>16} {'BF16/FP16':>10}" + ) + print(" " + "-" * 52) + + for name, orig, fp16, bf16 in [ + ("Q", Q_f32, Q_fp16, Q_bf16), + ("K", K_f32, K_fp16, K_bf16), + ("V", V_f32, V_fp16, V_bf16), + ("dO", dO_f32, dO_fp16, dO_bf16), + ]: + fp16_err = float(np.abs(orig - fp16).max()) + bf16_err = float(np.abs(orig - bf16).max()) + ratio = bf16_err / (fp16_err + 1e-15) + print(f" {name:<6} {fp16_err:>16.2e} {bf16_err:>16.2e} {ratio:>10.1f}x") + + # --- Backward with both dtypes --- + print("\n--- Backward Gradients: FP16 vs BF16 Inputs ---") + + dtype_configs = [ + ("fp16", Q_fp16, K_fp16, V_fp16, dO_fp16), + ("bf16", Q_bf16, K_bf16, V_bf16, dO_bf16), + ] + + grad_results = {} + for dtype_name, Q_d, K_d, V_d, dO_d in dtype_configs: + out, P, lse = cpu_fwd_with_intermediates(Q_d, K_d, V_d, prob.scale) + dQ, dK, dV = cpu_attention_bwd(Q_d, K_d, V_d, out, dO_d, P, prob.scale) + grad_results[dtype_name] = (dQ, dK, dV) + + print(f"\n {'Dtype':<6} {'|dQ| mean':>12} {'|dK| mean':>12} {'|dV| mean':>12}") + print(" " + "-" * 48) + for dtype_name in ["fp16", "bf16"]: + dQ, dK, dV = grad_results[dtype_name] + print( + f" {dtype_name:<6} {np.abs(dQ).mean():>12.4e} " + f"{np.abs(dK).mean():>12.4e} {np.abs(dV).mean():>12.4e}" + ) + + # --- Cross-dtype gradient difference --- + print("\n--- FP16 vs BF16 Backward Difference ---") + dQ_fp, dK_fp, dV_fp = grad_results["fp16"] + dQ_bf, dK_bf, dV_bf = grad_results["bf16"] + + print( + f"\n {'Grad':<6} {'Max abs diff':>14} {'Mean abs diff':>14} {'Max rel diff':>14}" + ) + print(" " + "-" * 52) + for name, g_fp, g_bf in [ + ("dQ", dQ_fp, dQ_bf), + ("dK", dK_fp, dK_bf), + ("dV", dV_fp, dV_bf), + ]: + abs_diff = np.abs(g_fp - g_bf) + max_abs = float(abs_diff.max()) + mean_abs = float(abs_diff.mean()) + max_rel = float((abs_diff / (np.abs(g_fp) + 1e-8)).max()) + print(f" {name:<6} {max_abs:>14.4e} {mean_abs:>14.4e} {max_rel:>14.4e}") + + # --- Tolerance analysis for different hdims --- + print("\n--- Recommended Backward Tolerances ---") + print(f"\n {'Dtype':<6} {'hdim':>6} {'rtol':>10} {'atol':>10} {'Note'}") + print(" " + "-" * 54) + for dtype in ["fp16", "bf16"]: + for hdim in [64, 128, 256]: + rtol, atol = get_bwd_tolerance(dtype, hdim) + note = "" + if dtype == "bf16" and hdim > 128: + note = "<-- relaxed for large hdim" + print(f" {dtype:<6} {hdim:>6} {rtol:>10.1e} {atol:>10.1e} {note}") + + # --- Validate backward with appropriate tolerances --- + print("\n--- Validation Against F32 Reference ---") + out_f32, P_f32, _ = cpu_fwd_with_intermediates(Q_f32, K_f32, V_f32, prob.scale) + dQ_ref, dK_ref, dV_ref = cpu_attention_bwd( + Q_f32, + K_f32, + V_f32, + out_f32, + dO_f32, + P_f32, + prob.scale, + ) + + for dtype_name in ["fp16", "bf16"]: + rtol, atol = get_bwd_tolerance(dtype_name, args.hdim) + dQ, dK, dV = grad_results[dtype_name] + + print(f"\n [{dtype_name}] rtol={rtol:.1e}, atol={atol:.1e}") + for gname, g, g_ref in [ + ("dQ", dQ, dQ_ref), + ("dK", dK, dK_ref), + ("dV", dV, dV_ref), + ]: + max_err = float(np.abs(g - g_ref).max()) + ok = bool(np.allclose(g, g_ref, rtol=rtol, atol=atol)) + print(f" {gname}: max_err={max_err:.4e} {'PASS' if ok else 'FAIL'}") + + # --- BF16 backward GPU API pattern --- + print("\n--- BF16 Backward GPU API Pattern ---") + print(" Native bf16 backward kernel:") + print(" FmhaKernelConfig(family='bwd', data_type='bf16', ...)") + print(" Internal accumulation stays in fp32 for numerical stability.") + print(" Stage 3 (convert_dq) converts fp32 accumulator -> bf16 output.") + print(" BF16 advantage: wider dynamic range prevents overflow in") + print(" intermediate products (S = Q @ K^T) for large sequences.") + + # --- Summary --- + print("\n" + "=" * 70) + print(" Data types: fp16 (10-bit mantissa) vs bf16 (7-bit mantissa)") + print(" Tolerances: bf16 bwd needs ~2x relaxed rtol vs fp16") + rtol_used, _ = get_bwd_tolerance("bf16", args.hdim) + print(f" Current hdim: {args.hdim} -> bf16 rtol={rtol_used:.1e}") + print(" GPU: Requires bwd-family JIT kernel with data_type='bf16'") + print(" Status: DEMO") + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/36_bwd_benchmark_fmha.py b/dispatcher/examples/fmha/python/36_bwd_benchmark_fmha.py new file mode 100644 index 0000000000..1a40533881 --- /dev/null +++ b/dispatcher/examples/fmha/python/36_bwd_benchmark_fmha.py @@ -0,0 +1,234 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 36: Backward Pass Benchmark + +Benchmarks the FMHA backward pass across problem sizes. The backward +pass is approximately 4x the forward FLOPS because it computes dQ, dK, +and dV through two matrix multiplications each (plus the dot_do_o stage). + +Backward FLOPS estimate: + FWD: 2 * B * H * Sq * Sk * (Dq + Dv) + BWD: ~4 * FWD_FLOPS + = 2 * B * H * Sq * Sk * Dq (dP = dO @ V^T, part of dS computation) + + 2 * B * H * Sq * Sk * Dq (dQ = dS @ K) + + 2 * B * H * Sq * Sk * Dq (dK = dS^T @ Q) + + 2 * B * H * Sq * Sk * Dv (dV = P^T @ dO) + +When GPU JIT is unavailable, benchmarks CPU reference instead. + +Usage: + python3 36_bwd_benchmark_fmha.py + python3 36_bwd_benchmark_fmha.py --repeat 5 + python3 36_bwd_benchmark_fmha.py --arch gfx942 +""" + +import sys +import time +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaKernelConfig, + FmhaProblem, + setup_fmha_dispatcher, + detect_gpu_arch, + cpu_attention_fwd_with_intermediates, + cpu_attention_bwd, +) + + +cpu_fwd_with_intermediates = cpu_attention_fwd_with_intermediates + + +def bwd_flops(prob: FmhaProblem) -> int: + """Estimate backward FLOPS (~4x forward).""" + B, Hq, Sq, Sk = prob.batch, prob.nhead_q, prob.seqlen_q, prob.seqlen_k + Dq, Dv = prob.hdim_q, prob.hdim_v + fwd = 2 * B * Hq * Sq * Sk * (Dq + Dv) + return 4 * fwd + + +def main(): + parser = argparse.ArgumentParser(description="Backward Pass Benchmark") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--repeat", type=int, default=3, help="Benchmark iterations") + parser.add_argument("--nhead", type=int, default=8) + parser.add_argument("--hdim", type=int, default=128) + args = parser.parse_args() + + print("=" * 70) + print("Example 36: Backward Pass Benchmark") + print("=" * 70) + + print(f"\n Arch: {args.arch}") + print(f" nhead: {args.nhead}") + print(f" hdim: {args.hdim}") + print(f" Repeat: {args.repeat}") + + # --- JIT compile a basic fp16 h128 fwd kernel --- + print("\n--- JIT Compilation ---") + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=args.hdim, + hdim_v=args.hdim, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + if setup.success: + print(f" Fwd kernel compiled: {setup.build_time_s:.1f}s") + print(" Backward GPU kernel: Not available (bwd JIT tile structure issue)") + print(" Benchmarking CPU backward reference instead") + else: + print(f" JIT build: {setup.error}") + print(" Benchmarking CPU backward reference") + + # --- Benchmark configs --- + bench_configs = [ + (1, 64), + (1, 128), + (1, 256), + (1, 512), + (1, 1024), + (2, 64), + (2, 128), + (2, 256), + (2, 512), + (4, 64), + (4, 128), + (4, 256), + (8, 64), + (8, 128), + ] + + # --- FLOPS estimate table --- + print("\n--- FLOPS Estimates (BWD ~4x FWD) ---") + print( + f"\n {'Batch':>5} {'SeqLen':>7} | {'FWD FLOPS':>14} {'BWD FLOPS':>14} {'Ratio':>6}" + ) + print(" " + "-" * 52) + + for batch, seqlen in [(1, 128), (1, 1024), (4, 256), (8, 128)]: + prob = FmhaProblem( + batch=batch, + nhead_q=args.nhead, + nhead_k=args.nhead, + seqlen_q=seqlen, + seqlen_k=seqlen, + hdim_q=args.hdim, + hdim_v=args.hdim, + ) + fwd_ops = prob.num_ops + bwd_ops = bwd_flops(prob) + print( + f" {batch:>5} {seqlen:>7} | {fwd_ops:>14,} {bwd_ops:>14,} {bwd_ops / fwd_ops:>5.1f}x" + ) + + # --- CPU backward benchmark --- + print("\n--- CPU Backward Benchmark ---") + print( + f"\n {'Batch':>5} {'SeqLen':>7} | {'Time(ms)':>10} {'TFLOPS':>10}" + f" | {'dQ range':>22} {'Finite':>6}" + ) + print(" " + "-" * 76) + + all_tflops = [] + + for batch, seqlen in bench_configs: + prob = FmhaProblem( + batch=batch, + nhead_q=args.nhead, + nhead_k=args.nhead, + seqlen_q=seqlen, + seqlen_k=seqlen, + hdim_q=args.hdim, + hdim_v=args.hdim, + ) + ops = bwd_flops(prob) + + np.random.seed(42) + Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32) + K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32) + V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32) + dO = (np.random.randn(*prob.o_shape()) * 0.1).astype(np.float32) + + out, P = cpu_fwd_with_intermediates(Q, K, V, prob.scale) + + times = [] + dQ = dK = dV = None + for _ in range(args.repeat): + t0 = time.perf_counter() + dQ, dK, dV = cpu_attention_bwd(Q, K, V, out, dO, P, prob.scale) + t1 = time.perf_counter() + times.append((t1 - t0) * 1000.0) + + avg_ms = sum(times) / len(times) + tflops = ops / (avg_ms * 1e-3) / 1e12 if avg_ms > 0 else 0.0 + all_tflops.append(tflops) + + is_finite = bool(np.all(np.isfinite(dQ))) + dq_range = f"[{dQ.min():.4e}, {dQ.max():.4e}]" + + print( + f" {batch:>5} {seqlen:>7} | {avg_ms:>10.4f} {tflops:>10.4f}" + f" | {dq_range:>22} {'OK' if is_finite else 'NaN!':>6}" + ) + + # --- Scaling analysis --- + print("\n--- Scaling Analysis ---") + print(" Backward time should scale as O(B * H * Sq * Sk * D).") + print(" Doubling seqlen -> ~4x time (quadratic in sequence length).\n") + + ref_configs = [(1, 128), (1, 256), (1, 512)] + ref_times = {} + for batch, seqlen in ref_configs: + prob = FmhaProblem( + batch=batch, + nhead_q=args.nhead, + nhead_k=args.nhead, + seqlen_q=seqlen, + seqlen_k=seqlen, + hdim_q=args.hdim, + hdim_v=args.hdim, + ) + np.random.seed(42) + Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32) + K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32) + V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32) + dO = (np.random.randn(*prob.o_shape()) * 0.1).astype(np.float32) + out, P = cpu_fwd_with_intermediates(Q, K, V, prob.scale) + + t0 = time.perf_counter() + cpu_attention_bwd(Q, K, V, out, dO, P, prob.scale) + ref_times[seqlen] = (time.perf_counter() - t0) * 1000.0 + + if 128 in ref_times and ref_times[128] > 0: + base = ref_times[128] + print(f" {'SeqLen':>7} {'Time(ms)':>10} {'vs S=128':>10}") + print(" " + "-" * 30) + for sl in sorted(ref_times): + ratio = ref_times[sl] / base + print(f" {sl:>7} {ref_times[sl]:>10.4f} {ratio:>9.1f}x") + + # --- Summary --- + print("\n" + "=" * 70) + print(f" Configs tested: {len(bench_configs)}") + print(" BWD FLOPS: ~4x forward FLOPS") + if all_tflops: + print(f" CPU avg: {sum(all_tflops) / len(all_tflops):.4f} TFLOPS") + print(f" CPU peak: {max(all_tflops):.4f} TFLOPS") + print(" GPU: Requires bwd-family JIT kernel") + print(" Status: DEMO") + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/37_bwd_deterministic_fmha.py b/dispatcher/examples/fmha/python/37_bwd_deterministic_fmha.py new file mode 100644 index 0000000000..a9188e33c6 --- /dev/null +++ b/dispatcher/examples/fmha/python/37_bwd_deterministic_fmha.py @@ -0,0 +1,316 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 37: Backward Pass Deterministic Mode + +Demonstrates deterministic vs non-deterministic backward computation. + +Non-deterministic mode (default): + - dQ is accumulated via atomicAdd across seqlen_k tiles + - Faster but produces slightly different results each run + - Acceptable for training where stochastic noise is tolerable + +Deterministic mode: + - Uses multi-buffer reduction instead of atomics + - Each tile writes to a separate buffer, then a final reduction sums them + - Bit-exact reproducible gradients across runs + - Slower due to extra memory and reduction pass + +CPU reference simulates both modes. On CPU, both modes are numerically +identical (no atomics), but this example demonstrates the API pattern +and compares GPU-style multi-buffer reduction semantics. + +Usage: + python3 37_bwd_deterministic_fmha.py + python3 37_bwd_deterministic_fmha.py --seqlen 128 + python3 37_bwd_deterministic_fmha.py --num-tiles 4 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaKernelConfig, + FmhaProblem, + setup_fmha_dispatcher, + detect_gpu_arch, +) + + +def cpu_fwd_with_intermediates( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + scale: float, +) -> tuple: + """Forward returning out, P, LSE.""" + S = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale + S_max = S.max(axis=-1, keepdims=True) + S_exp = np.exp(S - S_max) + S_sum = S_exp.sum(axis=-1, keepdims=True) + P = S_exp / S_sum + out = np.matmul(P, V) + lse = (np.log(S_sum.squeeze(-1)) + S_max.squeeze(-1)).astype(np.float32) + return out, P, lse + + +def cpu_bwd_nondeterministic( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + out: np.ndarray, + dO: np.ndarray, + P: np.ndarray, + scale: float, +) -> tuple: + """Standard backward (single accumulation). Returns (dQ, dK, dV).""" + D = (dO * out).sum(axis=-1, keepdims=True) + dP = np.matmul(dO, V.transpose(0, 1, 3, 2)) + dS = P * (dP - D) + dQ = np.matmul(dS, K) * scale + dK = np.matmul(dS.transpose(0, 1, 3, 2), Q) * scale + dV = np.matmul(P.transpose(0, 1, 3, 2), dO) + return dQ, dK, dV + + +def cpu_bwd_deterministic( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + out: np.ndarray, + dO: np.ndarray, + P: np.ndarray, + scale: float, + num_tiles_k: int = 4, +) -> tuple: + """Deterministic backward with explicit multi-buffer reduction for dQ. + + Simulates the GPU pattern where seqlen_k is split into tiles, + each tile writes dQ to a separate buffer, then buffers are summed. + + Returns: (dQ, dK, dV, dQ_buffers) + """ + B, Hq, Sq, Dq = Q.shape + Sk = K.shape[2] + + D = (dO * out).sum(axis=-1, keepdims=True) + + tile_sk = max(1, Sk // num_tiles_k) + actual_tiles = (Sk + tile_sk - 1) // tile_sk + + dQ_buffers = np.zeros((actual_tiles, B, Hq, Sq, Dq), dtype=np.float32) + dK = np.zeros_like(K) + dV = np.zeros_like(V) + + for t in range(actual_tiles): + sk_start = t * tile_sk + sk_end = min(sk_start + tile_sk, Sk) + + K_tile = K[:, :, sk_start:sk_end, :] + V_tile = V[:, :, sk_start:sk_end, :] + P_tile = P[:, :, :, sk_start:sk_end] + + dP_tile = np.matmul(dO, V_tile.transpose(0, 1, 3, 2)) + dS_tile = P_tile * (dP_tile - D) + + dQ_buffers[t] = np.matmul(dS_tile, K_tile) * scale + dK[:, :, sk_start:sk_end, :] = ( + np.matmul(dS_tile.transpose(0, 1, 3, 2), Q) * scale + ) + dV[:, :, sk_start:sk_end, :] = np.matmul(P_tile.transpose(0, 1, 3, 2), dO) + + dQ = dQ_buffers.sum(axis=0) + return dQ, dK, dV, dQ_buffers + + +def main(): + parser = argparse.ArgumentParser(description="Backward Deterministic Mode") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--nhead", type=int, default=8) + parser.add_argument("--seqlen", type=int, default=64) + parser.add_argument("--hdim", type=int, default=128) + parser.add_argument( + "--num-tiles", + type=int, + default=4, + help="Number of seqlen_k tiles for deterministic mode", + ) + args = parser.parse_args() + + print("=" * 70) + print("Example 37: Backward Pass Deterministic Mode") + print("=" * 70) + + prob = FmhaProblem( + batch=args.batch, + nhead_q=args.nhead, + nhead_k=args.nhead, + seqlen_q=args.seqlen, + seqlen_k=args.seqlen, + hdim_q=args.hdim, + hdim_v=args.hdim, + ) + + print( + f"\n Problem: B={prob.batch} H={prob.nhead_q} S={args.seqlen} D={args.hdim}" + ) + print(f" Tiles: {args.num_tiles} (seqlen_k split)") + print(f" Tile size: {max(1, args.seqlen // args.num_tiles)}") + + # --- JIT compile a basic fp16 h128 fwd kernel --- + print("\n--- JIT Compilation ---") + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=args.hdim, + hdim_v=args.hdim, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + if setup.success: + print(f" Fwd kernel compiled: {setup.build_time_s:.1f}s") + print(" Backward deterministic kernel: separate JIT with deterministic=True") + else: + print(f" JIT build: {setup.error}") + print(" Continuing with CPU reference only") + + # --- Generate data --- + np.random.seed(42) + Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32) + K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32) + V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32) + dO = (np.random.randn(*prob.o_shape()) * 0.1).astype(np.float32) + + out, P, lse = cpu_fwd_with_intermediates(Q, K, V, prob.scale) + + # --- Non-deterministic backward --- + print("\n--- Non-Deterministic Backward ---") + dQ_nd, dK_nd, dV_nd = cpu_bwd_nondeterministic(Q, K, V, out, dO, P, prob.scale) + + print(f" dQ range: [{dQ_nd.min():.4e}, {dQ_nd.max():.4e}]") + print(f" dK range: [{dK_nd.min():.4e}, {dK_nd.max():.4e}]") + print(f" dV range: [{dV_nd.min():.4e}, {dV_nd.max():.4e}]") + + # --- Deterministic backward --- + print(f"\n--- Deterministic Backward ({args.num_tiles} tiles) ---") + dQ_det, dK_det, dV_det, dQ_bufs = cpu_bwd_deterministic( + Q, + K, + V, + out, + dO, + P, + prob.scale, + num_tiles_k=args.num_tiles, + ) + + print(f" dQ range: [{dQ_det.min():.4e}, {dQ_det.max():.4e}]") + print(f" dK range: [{dK_det.min():.4e}, {dK_det.max():.4e}]") + print(f" dV range: [{dV_det.min():.4e}, {dV_det.max():.4e}]") + print(f" dQ buffers: {dQ_bufs.shape[0]} x {dQ_bufs.shape[1:]}") + + # --- Per-buffer analysis --- + print("\n--- Per-Tile dQ Buffer Analysis ---") + print(f"\n {'Tile':>6} {'|buf| mean':>12} {'|buf| max':>12} {'% of total':>12}") + print(" " + "-" * 46) + + total_l1 = float(np.abs(dQ_det).sum()) + for t in range(dQ_bufs.shape[0]): + buf = dQ_bufs[t] + buf_mean = float(np.abs(buf).mean()) + buf_max = float(np.abs(buf).max()) + buf_pct = float(np.abs(buf).sum()) / (total_l1 + 1e-15) * 100 + print(f" {t:>6} {buf_mean:>12.4e} {buf_max:>12.4e} {buf_pct:>11.1f}%") + + # --- Compare deterministic vs non-deterministic --- + print("\n--- Deterministic vs Non-Deterministic Comparison ---") + print(f"\n {'Grad':<6} {'Max abs diff':>14} {'Mean abs diff':>14} {'Match':>8}") + print(" " + "-" * 46) + + for name, g_det, g_nd in [ + ("dQ", dQ_det, dQ_nd), + ("dK", dK_det, dK_nd), + ("dV", dV_det, dV_nd), + ]: + abs_diff = np.abs(g_det - g_nd) + max_abs = float(abs_diff.max()) + mean_abs = float(abs_diff.mean()) + match = max_abs < 1e-6 + print( + f" {name:<6} {max_abs:>14.2e} {mean_abs:>14.2e} {'YES' if match else 'NO':>8}" + ) + + print("\n NOTE: On CPU, both modes produce identical results.") + print(" On GPU, non-deterministic mode uses atomicAdd for dQ,") + print(" causing order-dependent floating-point rounding differences.") + + # --- Reproducibility test --- + print("\n--- Reproducibility Test (Deterministic Mode) ---") + num_runs = 5 + dQ_runs = [] + for run in range(num_runs): + dQ_r, _, _, _ = cpu_bwd_deterministic( + Q, + K, + V, + out, + dO, + P, + prob.scale, + num_tiles_k=args.num_tiles, + ) + dQ_runs.append(dQ_r) + + max_variation = 0.0 + for i in range(1, num_runs): + diff = float(np.abs(dQ_runs[i] - dQ_runs[0]).max()) + max_variation = max(max_variation, diff) + + print(f" Runs: {num_runs}") + print(f" Max dQ variation across runs: {max_variation:.2e}") + print(f" Bit-exact reproducible: {'YES' if max_variation == 0.0 else 'NO'}") + + # --- Memory overhead analysis --- + print("\n--- Deterministic Mode Memory Overhead ---") + dq_size = Q.nbytes + buf_size = dQ_bufs.nbytes + overhead = buf_size / dq_size + + print(f" dQ single buffer: {dq_size:>10,} bytes") + print(f" dQ multi-buffer: {buf_size:>10,} bytes ({args.num_tiles} tiles)") + print(f" Memory overhead: {overhead:.1f}x") + print(f" Extra memory: {buf_size - dq_size:>10,} bytes") + + # --- GPU API pattern --- + print("\n--- GPU Deterministic API Pattern ---") + print(" Non-deterministic (default):") + print(" FmhaKernelConfig(family='bwd', deterministic=False)") + print(" dQ accumulated via atomicAdd (fast, non-reproducible)") + print() + print(" Deterministic:") + print(" FmhaKernelConfig(family='bwd', deterministic=True)") + print(" dQ via multi-buffer + final reduction (reproducible)") + print(" Requires extra workspace: num_tiles_k * sizeof(dQ)") + + # --- Summary --- + print("\n" + "=" * 70) + print(f" Tiles: {args.num_tiles}") + print(f" Memory overhead: {overhead:.1f}x for deterministic dQ") + print(" Reproducible: Deterministic mode guarantees bit-exact results") + print(" Performance: Deterministic ~10-20% slower on GPU (extra reduction)") + print(" GPU: Requires bwd-family JIT kernel") + print(" Status: DEMO") + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/fmha/python/38_bwd_sweep_hdim_fmha.py b/dispatcher/examples/fmha/python/38_bwd_sweep_hdim_fmha.py new file mode 100644 index 0000000000..53f7b0bf63 --- /dev/null +++ b/dispatcher/examples/fmha/python/38_bwd_sweep_hdim_fmha.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 38: Backward Pass Head Dimension Sweep + +Sweeps hdim for the backward pass: 32, 64, 128, 256. + +Each hdim requires a dedicated compiled kernel because the tile +dimensions (tile_k0max, tile_n1) must match the head dimension. +This example shows which hdims the backward kernels can support +and computes CPU reference gradients for each. + +Backward kernel tile requirements per hdim: + hdim=32: tile_k0max=32, tile_n1=32 (small, fast compile) + hdim=64: tile_k0max=64, tile_n1=64 + hdim=128: tile_k0max=128, tile_n1=128 (standard LLM config) + hdim=256: tile_k0max=256, tile_n1=256 (large, slow compile) + +Fixed: batch=2, nhead=8, seqlen=64 + +Usage: + python3 38_bwd_sweep_hdim_fmha.py + python3 38_bwd_sweep_hdim_fmha.py --arch gfx942 + python3 38_bwd_sweep_hdim_fmha.py --seqlen 128 +""" + +import sys +import time +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from fmha_utils import ( + FmhaKernelConfig, + FmhaProblem, + setup_fmha_dispatcher, + detect_gpu_arch, + cpu_attention_bwd, +) + +HDIMS = [32, 64, 128, 256] +BATCH = 2 +NHEAD = 8 + + +def cpu_fwd_with_intermediates( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + scale: float, +) -> tuple: + """Forward returning out, P, LSE.""" + S = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale + S_max = S.max(axis=-1, keepdims=True) + S_exp = np.exp(S - S_max) + S_sum = S_exp.sum(axis=-1, keepdims=True) + P = S_exp / S_sum + out = np.matmul(P, V) + lse = (np.log(S_sum.squeeze(-1)) + S_max.squeeze(-1)).astype(np.float32) + return out, P, lse + + +def bwd_flops(prob: FmhaProblem) -> int: + """Backward FLOPS (~4x forward).""" + return 4 * prob.num_ops + + +def main(): + parser = argparse.ArgumentParser(description="Backward Head Dimension Sweep") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--seqlen", type=int, default=64) + args = parser.parse_args() + + print("=" * 70) + print("Example 38: Backward Pass Head Dimension Sweep") + print("=" * 70) + + print(f"\n Fixed: batch={BATCH}, nhead={NHEAD}, seqlen={args.seqlen}") + print(f" Sweep: hdim in {HDIMS}") + print(f" Arch: {args.arch}") + + # --- JIT compile a basic fp16 h128 fwd kernel --- + print("\n--- JIT Compilation (hdim=128 fwd kernel) ---") + config = FmhaKernelConfig( + data_type="fp16", + hdim_q=128, + hdim_v=128, + gfx_arch=args.arch, + ) + setup = setup_fmha_dispatcher(config) + if setup.success: + print(f" Fwd kernel compiled: {setup.build_time_s:.1f}s") + print(" Backward kernels for each hdim need separate JIT compilation") + else: + print(f" JIT build: {setup.error}") + print(" Continuing with CPU reference only") + + # --- Kernel tile requirements per hdim --- + print("\n--- Backward Kernel Tile Requirements ---") + print( + f"\n {'hdim':>6} | {'tile_k0max':>10} {'tile_n1':>8} {'tile_k0':>8}" + f" | {'scale':>8} | {'Status'}" + ) + print(" " + "-" * 62) + + for hdim in HDIMS: + tile_k0 = min(32, hdim) + bwd_status = "needs bwd JIT" + if hdim == 128 and setup.success: + bwd_status = "fwd only (JIT)" + scale = 1.0 / (hdim**0.5) + print( + f" {hdim:>6} | {hdim:>10} {hdim:>8} {tile_k0:>8}" + f" | {scale:>8.4f} | {bwd_status}" + ) + + # --- CPU backward for each hdim --- + print("\n--- CPU Backward Reference per Head Dimension ---") + print( + f"\n {'hdim':>6} | {'FWD ops':>12} {'BWD ops':>12}" + f" | {'|dQ| mean':>10} {'|dK| mean':>10} {'|dV| mean':>10}" + f" | {'Time(ms)':>10} {'Finite':>6}" + ) + print(" " + "-" * 96) + + all_results = {} + + for hdim in HDIMS: + prob = FmhaProblem( + batch=BATCH, + nhead_q=NHEAD, + nhead_k=NHEAD, + seqlen_q=args.seqlen, + seqlen_k=args.seqlen, + hdim_q=hdim, + hdim_v=hdim, + ) + + np.random.seed(42) + Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32) + K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32) + V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32) + dO = (np.random.randn(*prob.o_shape()) * 0.1).astype(np.float32) + + out, P, lse = cpu_fwd_with_intermediates(Q, K, V, prob.scale) + + t0 = time.perf_counter() + dQ, dK, dV = cpu_attention_bwd(Q, K, V, out, dO, P, prob.scale) + elapsed_ms = (time.perf_counter() - t0) * 1000.0 + + is_finite = bool( + np.all(np.isfinite(dQ)) + and np.all(np.isfinite(dK)) + and np.all(np.isfinite(dV)) + ) + fwd_ops = prob.num_ops + bwd_ops = bwd_flops(prob) + + print( + f" {hdim:>6} | {fwd_ops:>12,} {bwd_ops:>12,}" + f" | {np.abs(dQ).mean():>10.4e} {np.abs(dK).mean():>10.4e}" + f" {np.abs(dV).mean():>10.4e}" + f" | {elapsed_ms:>10.4f} {'OK' if is_finite else 'NaN!':>6}" + ) + all_results[hdim] = (dQ, dK, dV, out, P, Q, K, V, dO, prob) + + # --- Gradient norms vs hdim --- + print("\n--- Gradient L2 Norms vs Head Dimension ---") + print( + f"\n {'hdim':>6} | {'||dQ||_2':>12} {'||dK||_2':>12} {'||dV||_2':>12} | {'ratio dQ/dK':>12}" + ) + print(" " + "-" * 62) + + for hdim in HDIMS: + dQ, dK, dV, *_ = all_results[hdim] + l2_dq = float(np.sqrt((dQ**2).sum())) + l2_dk = float(np.sqrt((dK**2).sum())) + l2_dv = float(np.sqrt((dV**2).sum())) + ratio = l2_dq / (l2_dk + 1e-12) + print( + f" {hdim:>6} | {l2_dq:>12.4e} {l2_dk:>12.4e} {l2_dv:>12.4e} | {ratio:>12.2f}" + ) + + # --- Scale effect analysis --- + print("\n--- Scale Effect on Gradients ---") + print(" scale = 1/sqrt(hdim) -> larger hdim = smaller scale") + print(" This affects gradient magnitude through the dS = P * (dP - D) term.\n") + + print(f" {'hdim':>6} {'scale':>10} {'dQ max':>12} {'dK max':>12} {'dV max':>12}") + print(" " + "-" * 52) + + for hdim in HDIMS: + dQ, dK, dV, *_ = all_results[hdim] + scale = 1.0 / (hdim**0.5) + print( + f" {hdim:>6} {scale:>10.4f} {np.abs(dQ).max():>12.4e}" + f" {np.abs(dK).max():>12.4e} {np.abs(dV).max():>12.4e}" + ) + + # --- FP16 quantization impact per hdim --- + print("\n--- FP16 Backward Quantization Impact ---") + print( + f"\n {'hdim':>6} | {'dQ fp16 err':>12} {'dK fp16 err':>12} {'dV fp16 err':>12}" + ) + print(" " + "-" * 50) + + for hdim in HDIMS: + dQ, dK, dV, *_ = all_results[hdim] + dq_err = float(np.abs(dQ - dQ.astype(np.float16).astype(np.float32)).max()) + dk_err = float(np.abs(dK - dK.astype(np.float16).astype(np.float32)).max()) + dv_err = float(np.abs(dV - dV.astype(np.float16).astype(np.float32)).max()) + print(f" {hdim:>6} | {dq_err:>12.2e} {dk_err:>12.2e} {dv_err:>12.2e}") + + # --- Backward GPU API pattern per hdim --- + print("\n--- Backward GPU Kernel Config per hdim ---") + for hdim in HDIMS: + print(f"\n hdim={hdim}:") + print(" FmhaKernelConfig(") + print(" family='bwd', data_type='fp16',") + print(f" hdim_q={hdim}, hdim_v={hdim},") + print(f" tile_k0max={hdim}, tile_n1={hdim},") + print(f" tile_k0={min(32, hdim)}, tile_k1={min(32, hdim)},") + print(" )") + + # --- Summary --- + print("\n" + "=" * 70) + print(f" Head dims swept: {HDIMS}") + print(f" Fixed: B={BATCH} H={NHEAD} S={args.seqlen}") + print(" Scale effect: 1/sqrt(hdim) -> smaller gradients for larger hdim") + print(" Tile coupling: tile_k0max and tile_n1 must equal hdim") + print(" GPU: Each hdim needs a dedicated bwd-family kernel") + print(" Status: DEMO") + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/05_numpy_integration.py b/dispatcher/examples/gemm/python/05_numpy_integration.py index b0af5fa700..1467af8f0b 100644 --- a/dispatcher/examples/gemm/python/05_numpy_integration.py +++ b/dispatcher/examples/gemm/python/05_numpy_integration.py @@ -76,8 +76,6 @@ Examples: ) args = parser.parse_args() - reset_for_example() - print("=" * 60) print("Example 05: NumPy Integration") print("=" * 60) diff --git a/dispatcher/examples/gemm/python/06_json_export.py b/dispatcher/examples/gemm/python/06_json_export.py index 780032ce06..f1de50e34e 100644 --- a/dispatcher/examples/gemm/python/06_json_export.py +++ b/dispatcher/examples/gemm/python/06_json_export.py @@ -60,8 +60,6 @@ Examples: ) args = parser.parse_args() - reset_for_example() - print("=" * 60) print("Example 06: JSON Export") print("=" * 60) diff --git a/dispatcher/examples/gemm/python/07_stress_test.py b/dispatcher/examples/gemm/python/07_stress_test.py index 620e66eeaf..6065d94b49 100644 --- a/dispatcher/examples/gemm/python/07_stress_test.py +++ b/dispatcher/examples/gemm/python/07_stress_test.py @@ -40,7 +40,6 @@ from ctypes_utils import ( KernelConfig, setup_gemm_dispatcher, cleanup_gemm, - reset_for_example, Validator, detect_gpu_arch, ) @@ -418,8 +417,6 @@ Examples: ) args = parser.parse_args() - reset_for_example() - print("=" * 80) print("Example 07: GEMM Stress Test - Multiple Kernels") print("=" * 80) diff --git a/dispatcher/examples/gemm/python/08_heuristics.py b/dispatcher/examples/gemm/python/08_heuristics.py index acbf1b3ae0..0cc50a0f23 100644 --- a/dispatcher/examples/gemm/python/08_heuristics.py +++ b/dispatcher/examples/gemm/python/08_heuristics.py @@ -566,8 +566,6 @@ Examples: ) args = parser.parse_args() - reset_for_example() - print("=" * 75) print("Example 08: Custom Heuristics") print("=" * 75) diff --git a/dispatcher/examples/gemm/python/09_multi_registry.py b/dispatcher/examples/gemm/python/09_multi_registry.py index 5d9af239d4..2daa2295c3 100644 --- a/dispatcher/examples/gemm/python/09_multi_registry.py +++ b/dispatcher/examples/gemm/python/09_multi_registry.py @@ -56,8 +56,6 @@ Examples: ) args = parser.parse_args() - reset_for_example() - print("=" * 60) print("Example 09: Multiple Registries") print("=" * 60) diff --git a/dispatcher/examples/gemm/python/10_advanced_benchmark.py b/dispatcher/examples/gemm/python/10_advanced_benchmark.py index b1462478d0..01a56fcc27 100644 --- a/dispatcher/examples/gemm/python/10_advanced_benchmark.py +++ b/dispatcher/examples/gemm/python/10_advanced_benchmark.py @@ -95,8 +95,6 @@ def initialize_matrix(shape, method, dtype): def main(): args = parse_args() - reset_for_example() - print("=" * 70) print("Example 10: Advanced GEMM Benchmarking") print("=" * 70) diff --git a/dispatcher/examples/gemm/python/11_json_import.py b/dispatcher/examples/gemm/python/11_json_import.py index d19395e553..4b4031539c 100644 --- a/dispatcher/examples/gemm/python/11_json_import.py +++ b/dispatcher/examples/gemm/python/11_json_import.py @@ -42,7 +42,6 @@ from ctypes_utils import ( # noqa: E402 KernelConfig as DispatcherKernelConfig, setup_gemm_dispatcher, cleanup_gemm, - reset_for_example, validate_kernel_config, detect_gpu_arch, ) @@ -146,8 +145,6 @@ Examples: ) args = parser.parse_args() - reset_for_example() - print_section("Example 11: JSON Kernel Configuration Import") # ========================================================================= diff --git a/dispatcher/include/ck_tile/dispatcher.hpp b/dispatcher/include/ck_tile/dispatcher.hpp index b3d8f10675..a0010b748f 100644 --- a/dispatcher/include/ck_tile/dispatcher.hpp +++ b/dispatcher/include/ck_tile/dispatcher.hpp @@ -7,6 +7,7 @@ /// For minimal includes, use the per-operation headers instead: /// ck_tile/dispatcher_gemm.hpp -- GEMM only /// ck_tile/dispatcher_conv.hpp -- Grouped Convolution only +/// ck_tile/dispatcher_fmha.hpp -- FMHA only // Core (needed by all ops) #include "ck_tile/dispatcher/base_registry.hpp" @@ -33,3 +34,13 @@ #include "ck_tile/dispatcher/grouped_conv_kernel_decl.hpp" #include "ck_tile/dispatcher/grouped_conv_registry.hpp" #include "ck_tile/dispatcher/grouped_conv_utils.hpp" + +// FMHA +#include "ck_tile/dispatcher/fmha_types.hpp" +#include "ck_tile/dispatcher/fmha_problem.hpp" +#include "ck_tile/dispatcher/fmha_kernel_key.hpp" +#include "ck_tile/dispatcher/fmha_kernel_instance.hpp" +#include "ck_tile/dispatcher/fmha_registry.hpp" +#include "ck_tile/dispatcher/fmha_dispatcher.hpp" +#include "ck_tile/dispatcher/fmha_kernel_decl.hpp" +#include "ck_tile/dispatcher/backends/generated_fmha_backend.hpp" diff --git a/dispatcher/include/ck_tile/dispatcher/backends/generated_fmha_backend.hpp b/dispatcher/include/ck_tile/dispatcher/backends/generated_fmha_backend.hpp new file mode 100644 index 0000000000..600f950d19 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/backends/generated_fmha_backend.hpp @@ -0,0 +1,266 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/dispatcher/fmha_kernel_instance.hpp" + +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { +namespace backends { + +// mask_top_left(1) and mask_bottom_right(2) share the same compiled kernel +// (both use SimplifiedGenericAttentionMask). The actual mask +// coordinates are determined at runtime from the args, not the template. +inline bool fmha_mask_compatible(int kernel_mask, int problem_mask) +{ + if(kernel_mask == problem_mask) + return true; + // Both causal variants are served by the same kernel + constexpr int kTopLeft = 1; // mask_enum::mask_top_left + constexpr int kBottomRight = 2; // mask_enum::mask_bottom_right + if((kernel_mask == kTopLeft || kernel_mask == kBottomRight) && + (problem_mask == kTopLeft || problem_mask == kBottomRight)) + return true; + return false; +} + +inline bool fmha_signature_matches(const FmhaKernelKey& key, const FmhaProblem& problem) +{ + const auto& sig = key.signature; + const bool compare_page_size = + sig.family == FmhaKernelFamily::FwdPagedKv || + problem.requested_family == FmhaKernelFamily::FwdPagedKv || + sig.family == FmhaKernelFamily::FwdAppendKv || + problem.requested_family == FmhaKernelFamily::FwdAppendKv || + sig.family == FmhaKernelFamily::FwdSplitKv || + problem.requested_family == FmhaKernelFamily::FwdSplitKv || + sig.family == FmhaKernelFamily::FwdSplitKvCombine || + problem.requested_family == FmhaKernelFamily::FwdSplitKvCombine || + sig.family == FmhaKernelFamily::BatchPrefill || + problem.requested_family == FmhaKernelFamily::BatchPrefill; + const bool compare_kv_layout_lookup = + sig.family == FmhaKernelFamily::BatchPrefill || + problem.requested_family == FmhaKernelFamily::BatchPrefill; + + if(!(sig.family == problem.requested_family && sig.data_type == problem.data_type && + sig.is_group_mode == problem.is_group_mode && sig.is_v_rowmajor == problem.is_v_rowmajor && + sig.has_logits_soft_cap == problem.has_logits_soft_cap && + fmha_mask_compatible(sig.mask_type, problem.mask_type) && + sig.bias_type == problem.bias_type && sig.has_lse == problem.has_lse && + sig.has_dropout == problem.has_dropout && sig.qscale_type == problem.qscale_type && + sig.rope_type == problem.rope_type && sig.use_paged_kv == problem.use_paged_kv && + sig.do_fp8_static_quant == problem.do_fp8_static_quant && + sig.skip_min_seqlen_q == problem.skip_min_seqlen_q && sig.has_sink == problem.has_sink && + sig.has_dbias == problem.has_dbias && sig.is_store_randval == problem.is_store_randval && + sig.is_deterministic == problem.is_deterministic && problem.hdim_q <= sig.hdim_q && + problem.hdim_v <= sig.hdim_v)) + { + return false; + } + + if(compare_kv_layout_lookup) + { + if(sig.kv_memory_layout != problem.kv_memory_layout || + sig.kv_lookup_table != problem.kv_lookup_table) + { + return false; + } + } + + if(compare_page_size && sig.page_size > 1 && sig.page_size != problem.page_size) + { + return false; + } + + return true; +} + +inline bool fmha_algorithm_supports(const FmhaKernelKey& key, const FmhaProblem& problem) +{ + const auto& alg = key.algorithm; + + if(problem.is_group_mode && problem.max_seqlen_q <= 0) + { + return false; + } + + if(!alg.pad_s && alg.tile_shape.m0 > 0 && + problem.effective_max_seqlen_q() % alg.tile_shape.m0 != 0) + { + return false; + } + + if(!alg.pad_sk) + { + if(problem.has_variable_seqlen_k()) + { + return false; + } + if(alg.tile_shape.n0 > 0 && problem.effective_max_seqlen_k() % alg.tile_shape.n0 != 0) + { + return false; + } + } + + if(!alg.pad_d && alg.hdim_q_alignment > 0 && problem.hdim_q % alg.hdim_q_alignment != 0) + { + return false; + } + + if(!alg.pad_dv && alg.hdim_v_alignment > 0 && problem.hdim_v % alg.hdim_v_alignment != 0) + { + return false; + } + + if(alg.max_seq_len_q > 0 && problem.effective_max_seqlen_q() > alg.max_seq_len_q) + { + return false; + } + + if(alg.max_splits_log2 > 0 && + problem.num_splits > (static_cast(1) << alg.max_splits_log2)) + { + return false; + } + + return true; +} + +class GeneratedFmhaKernelInstance : public FmhaKernelInstance +{ + public: + using SupportsFn = std::function; + using LaunchFn = std::function; + using RunFn = std::function; + + GeneratedFmhaKernelInstance(FmhaKernelKey key, + std::string name, + SupportsFn supports_fn, + LaunchFn launch_fn, + RunFn run_fn = {}) + : key_(std::move(key)), + name_(std::move(name)), + supports_fn_(std::move(supports_fn)), + launch_fn_(std::move(launch_fn)), + run_fn_(std::move(run_fn)) + { + } + + [[nodiscard]] const FmhaKernelKey& get_key() const override { return key_; } + + [[nodiscard]] bool supports(const FmhaProblem& problem) const override + { + return supports_fn_ ? supports_fn_(problem) : false; + } + + [[nodiscard]] std::string get_name() const override { return name_; } + + void launch(const FmhaInvocation& invocation, + const ck_tile::stream_config& stream_config) const override + { + if(!launch_fn_) + { + throw std::runtime_error("FMHA kernel launch function is not available"); + } + launch_fn_(invocation, stream_config); + } + + [[nodiscard]] float run(const FmhaInvocation& invocation, + const ck_tile::stream_config& stream_config) const override + { + if(run_fn_) + { + return run_fn_(invocation, stream_config); + } + return FmhaKernelInstance::run(invocation, stream_config); + } + + private: + FmhaKernelKey key_; + std::string name_; + SupportsFn supports_fn_; + LaunchFn launch_fn_; + RunFn run_fn_; +}; + +inline GeneratedFmhaKernelInstance::SupportsFn +make_default_supports_fn(const FmhaKernelKey& key, + GeneratedFmhaKernelInstance::SupportsFn extra = {}) +{ + return [key, extra = std::move(extra)](const FmhaProblem& problem) { + if(!fmha_signature_matches(key, problem) || !fmha_algorithm_supports(key, problem)) + { + return false; + } + return extra ? extra(problem) : true; + }; +} + +template +inline FmhaKernelInstancePtr +make_oneshot_fmha_kernel(FmhaKernelKey key, + std::string name, + LaunchCallable&& launch_callable, + GeneratedFmhaKernelInstance::SupportsFn extra_support = {}) +{ + auto launch_fn = [launch_callable = std::forward(launch_callable)]( + const FmhaInvocation& invocation, const ck_tile::stream_config& sc) { + const auto* args = std::get_if(&invocation.args); + if(!args) + { + throw std::invalid_argument("FMHA invocation args do not match generated kernel type"); + } + launch_callable(sc, *args); + }; + + auto supports_fn = make_default_supports_fn(key, std::move(extra_support)); + return std::make_shared( + std::move(key), std::move(name), std::move(supports_fn), std::move(launch_fn)); +} + +template +inline FmhaKernelInstancePtr +make_timed_fmha_kernel(FmhaKernelKey key, + std::string name, + TimedCallable&& timed_callable, + GeneratedFmhaKernelInstance::SupportsFn extra_support = {}) +{ + auto callable = std::forward(timed_callable); + + auto launch_fn = [callable](const FmhaInvocation& invocation, + const ck_tile::stream_config& sc) { + const auto* args = std::get_if(&invocation.args); + if(!args) + { + throw std::invalid_argument("FMHA invocation args do not match generated kernel type"); + } + auto untimed = sc; + untimed.time_kernel_ = false; + (void)callable(untimed, *args); + }; + + auto run_fn = [callable](const FmhaInvocation& invocation, const ck_tile::stream_config& sc) { + const auto* args = std::get_if(&invocation.args); + if(!args) + { + throw std::invalid_argument("FMHA invocation args do not match generated kernel type"); + } + return callable(sc, *args); + }; + + auto supports_fn = make_default_supports_fn(key, std::move(extra_support)); + return std::make_shared(std::move(key), + std::move(name), + std::move(supports_fn), + std::move(launch_fn), + std::move(run_fn)); +} + +} // namespace backends +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/backends/generated_kernel_backend.hpp b/dispatcher/include/ck_tile/dispatcher/backends/generated_kernel_backend.hpp index 79f8f30a9b..97734c1211 100644 --- a/dispatcher/include/ck_tile/dispatcher/backends/generated_kernel_backend.hpp +++ b/dispatcher/include/ck_tile/dispatcher/backends/generated_kernel_backend.hpp @@ -101,14 +101,14 @@ class GeneratedKernelInstance : public KernelInstance problem.N // stride_E/C (row-major C: stride = N) ); - // Create stream config for timing + const bool bench = this->benchmarking_; ck_tile::stream_config stream_cfg; stream_cfg.stream_id_ = reinterpret_cast(stream); - stream_cfg.time_kernel_ = true; + stream_cfg.time_kernel_ = bench; stream_cfg.log_level_ = 0; - stream_cfg.cold_niters_ = 5; // Warmup iterations - stream_cfg.nrepeat_ = 10; // Measurement iterations - stream_cfg.is_gpu_timer_ = true; + stream_cfg.cold_niters_ = bench ? 5 : 0; + stream_cfg.nrepeat_ = bench ? 10 : 1; + stream_cfg.is_gpu_timer_ = bench; stream_cfg.flush_cache_ = false; stream_cfg.rotating_count_ = 1; diff --git a/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp b/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp index 76565045cf..be22d94b33 100644 --- a/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp +++ b/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp @@ -101,14 +101,14 @@ class GeneratedTileKernelInstance : public KernelInstance problem.N // stride_E/C (row-major C: stride = N) ); - // Create stream config for timing + const bool bench = this->benchmarking_; ck_tile::stream_config stream_cfg; stream_cfg.stream_id_ = reinterpret_cast(stream); - stream_cfg.time_kernel_ = true; - stream_cfg.log_level_ = 0; // No logging for performance - stream_cfg.cold_niters_ = 5; // Warmup iterations - stream_cfg.nrepeat_ = 10; // Measurement iterations - stream_cfg.is_gpu_timer_ = true; + stream_cfg.time_kernel_ = bench; + stream_cfg.log_level_ = 0; + stream_cfg.cold_niters_ = bench ? 5 : 0; + stream_cfg.nrepeat_ = bench ? 10 : 1; + stream_cfg.is_gpu_timer_ = bench; stream_cfg.flush_cache_ = false; stream_cfg.rotating_count_ = 1; diff --git a/dispatcher/include/ck_tile/dispatcher/example_args.hpp b/dispatcher/include/ck_tile/dispatcher/example_args.hpp index f93a4d61f6..17d0a3c0f3 100644 --- a/dispatcher/include/ck_tile/dispatcher/example_args.hpp +++ b/dispatcher/include/ck_tile/dispatcher/example_args.hpp @@ -3,11 +3,12 @@ #pragma once +#include #include -#include -#include #include #include +#include +#include #include namespace ck_tile { diff --git a/dispatcher/include/ck_tile/dispatcher/fmha_dispatcher.hpp b/dispatcher/include/ck_tile/dispatcher/fmha_dispatcher.hpp new file mode 100644 index 0000000000..fba780159a --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/fmha_dispatcher.hpp @@ -0,0 +1,105 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/dispatcher/fmha_registry.hpp" + +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +using FmhaHeuristicFunction = std::function(const FmhaProblem&)>; + +struct FmhaExecutionStage +{ + FmhaKernelFamily family = FmhaKernelFamily::Fwd; + std::string kernel_id; +}; + +struct FmhaExecutionPlan +{ + FmhaApiFamily api_family = FmhaApiFamily::Fwd; + std::vector stages; + + [[nodiscard]] bool is_valid() const { return !stages.empty(); } +}; + +class FmhaDispatcher +{ + public: + enum class SelectionStrategy + { + FirstFit, + Heuristic + }; + + explicit FmhaDispatcher(FmhaRegistry* registry = nullptr, const std::string& gfx_arch = ""); + + void set_heuristic(FmhaHeuristicFunction heuristic); + void set_strategy(SelectionStrategy strategy); + void set_timing(int cold_niters, int nrepeat); + void set_arch(const std::string& arch) { gfx_arch_ = arch; } + [[nodiscard]] const std::string& arch() const { return gfx_arch_; } + + [[nodiscard]] FmhaKernelInstancePtr select_kernel(const FmhaProblem& problem) const; + [[nodiscard]] FmhaExecutionPlan plan(const FmhaProblem& problem) const; + + [[nodiscard]] float run(const FmhaInvocation& invocation, void* stream = nullptr) const; + + [[nodiscard]] float run_explicit(const std::string& kernel_id, + const FmhaInvocation& invocation, + void* stream = nullptr) const; + + [[nodiscard]] float + run_fwd(fmha_fwd_traits traits, fmha_fwd_args args, void* stream = nullptr) const; + [[nodiscard]] float run_fwd_pagedkv(fmha_fwd_pagedkv_traits traits, + fmha_fwd_pagedkv_args args, + void* stream = nullptr) const; + [[nodiscard]] float run_fwd_splitkv(fmha_fwd_splitkv_traits traits, + fmha_fwd_splitkv_args args, + void* stream = nullptr) const; + [[nodiscard]] float run_fwd_appendkv(fmha_fwd_appendkv_traits traits, + fmha_fwd_appendkv_args args, + void* stream = nullptr) const; + [[nodiscard]] float run_batch_prefill(fmha_batch_prefill_traits traits, + fmha_batch_prefill_args args, + void* stream = nullptr) const; + // run_bwd is available when bwd types exist (library builds, bwd kernel TUs, + // or any TU that doesn't set CK_TILE_FMHA_BWD_TYPES_FROM_EXAMPLE). + // In fwd-only TUs, bwd types come from the fallback in fmha_types.hpp. + [[nodiscard]] float + run_bwd(fmha_bwd_traits traits, fmha_bwd_args args, void* stream = nullptr) const; + + private: + [[nodiscard]] FmhaKernelInstancePtr select_first_fit(const FmhaProblem& problem) const; + [[nodiscard]] FmhaKernelInstancePtr select_heuristic(const FmhaProblem& problem) const; + + [[nodiscard]] FmhaProblem with_family(const FmhaProblem& base, FmhaKernelFamily family) const; + [[nodiscard]] FmhaExecutionPlan plan_single_stage(const FmhaProblem& problem, + FmhaKernelFamily family) const; + [[nodiscard]] float + run_plan(const FmhaExecutionPlan& plan, const FmhaInvocation& invocation, void* stream) const; + [[nodiscard]] ck_tile::stream_config make_stream_config(void* stream) const; + + FmhaRegistry* registry_; + FmhaHeuristicFunction heuristic_; + SelectionStrategy strategy_; + std::string gfx_arch_; + int cold_niters_ = 5; + int nrepeat_ = 10; + bool benchmarking_enabled_ = false; + + public: + /// Enable or disable benchmarking (GPU timing). + /// When disabled, kernels execute exactly once with no timing overhead + /// (one-shot mode for production plugins). + void set_benchmarking(bool enable) { benchmarking_enabled_ = enable; } + [[nodiscard]] bool benchmarking_enabled() const { return benchmarking_enabled_; } +}; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/fmha_kernel_decl.hpp b/dispatcher/include/ck_tile/dispatcher/fmha_kernel_decl.hpp new file mode 100644 index 0000000000..7108c47e4b --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/fmha_kernel_decl.hpp @@ -0,0 +1,646 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { +namespace fmha_decl { + +constexpr const char* ANY = "*"; +constexpr int ANY_INT = -1; + +class FmhaSignature +{ + public: + std::string family_ = "fwd"; + std::string data_type_ = "fp16"; + std::string mode_ = "batch"; + std::string vlayout_ = "r"; + int hdim_q_ = 128; + int hdim_v_ = 128; + std::string mask_ = "no_mask"; + std::string bias_ = "no_bias"; + bool lse_ = false; + bool dropout_ = false; + std::string qscale_ = "no_scale"; + std::string rope_ = "none"; + bool logits_ = false; + bool paged_kv_ = false; + bool fp8_static_quant_ = false; + bool skip_min_seqlen_q_ = false; + bool sink_ = false; + bool dbias_ = false; + bool store_randval_ = false; + bool deterministic_ = false; + std::string kv_memory_layout_ = "vectorized"; + std::string kv_lookup_table_ = "sglang"; + int page_size_ = 1; + std::string profile_; + int receipt_ = -1; + + FmhaSignature& family(const std::string& family) + { + family_ = family; + return *this; + } + + FmhaSignature& dtype(const std::string& dtype) + { + data_type_ = dtype; + return *this; + } + + FmhaSignature& mode(const std::string& mode) + { + mode_ = mode; + return *this; + } + + FmhaSignature& vlayout(const std::string& layout) + { + vlayout_ = layout; + return *this; + } + + FmhaSignature& hdim(int q, int v = -1) + { + hdim_q_ = q; + hdim_v_ = (v < 0 ? q : v); + return *this; + } + + FmhaSignature& mask(const std::string& mask) + { + mask_ = mask; + return *this; + } + + FmhaSignature& bias(const std::string& bias) + { + bias_ = bias; + return *this; + } + + FmhaSignature& lse(bool value = true) + { + lse_ = value; + return *this; + } + + FmhaSignature& dropout(bool value = true) + { + dropout_ = value; + return *this; + } + + FmhaSignature& qscale(const std::string& qscale) + { + qscale_ = qscale; + return *this; + } + + FmhaSignature& rope(const std::string& rope) + { + rope_ = rope; + return *this; + } + + FmhaSignature& logits(bool value = true) + { + logits_ = value; + return *this; + } + + FmhaSignature& paged_kv(bool value = true) + { + paged_kv_ = value; + return *this; + } + + FmhaSignature& fp8_static_quant(bool value = true) + { + fp8_static_quant_ = value; + return *this; + } + + FmhaSignature& skip(bool value = true) + { + skip_min_seqlen_q_ = value; + return *this; + } + + FmhaSignature& sink(bool value = true) + { + sink_ = value; + return *this; + } + + FmhaSignature& dbias(bool value = true) + { + dbias_ = value; + return *this; + } + + FmhaSignature& store_randval(bool value = true) + { + store_randval_ = value; + return *this; + } + + FmhaSignature& deterministic(bool value = true) + { + deterministic_ = value; + return *this; + } + + FmhaSignature& + kv_cache(const std::string& memory_layout, const std::string& lookup_table, int page_size = 1) + { + kv_memory_layout_ = memory_layout; + kv_lookup_table_ = lookup_table; + page_size_ = page_size; + return *this; + } + + FmhaSignature& profile(const std::string& profile) + { + profile_ = profile; + return *this; + } + + FmhaSignature& receipt(int receipt) + { + receipt_ = receipt; + return *this; + } +}; + +class FmhaAlgorithm +{ + public: + int tile_m0_ = 128; + int tile_n0_ = 64; + int tile_k0_ = 32; + int tile_n1_ = 128; + int tile_k1_ = 32; + int tile_k0max_ = 128; + + int wave_m0_ = 2; + int wave_n0_ = 2; + int wave_k0_ = 1; + int wave_m1_ = 2; + int wave_n1_ = 2; + int wave_k1_ = 1; + int wave_m2_ = 1; + int wave_n2_ = 1; + int wave_k2_ = 1; + + int warp_m0_ = 32; + int warp_n0_ = 32; + int warp_k0_ = 16; + int warp_m1_ = 32; + int warp_n1_ = 32; + int warp_k1_ = 16; + int warp_m2_ = 16; + int warp_n2_ = 16; + int warp_k2_ = 16; + + std::string pipeline_ = "qr"; + bool pad_s_ = true; + bool pad_sk_ = true; + bool pad_d_ = true; + bool pad_dv_ = true; + bool use_trload_ = false; + int hdim_q_alignment_ = 0; + int hdim_v_alignment_ = 0; + int block_per_cu_ = 1; + int num_wave_groups_ = 1; + int max_splits_log2_ = 0; + int max_seq_len_q_ = 0; + int selection_rank_ = 0; + std::string constraint_tag_; + + // Bulk setters (positional, for backward compatibility) + FmhaAlgorithm& tile(int m0, int n0, int k0, int n1, int k1, int k0max) + { + tile_m0_ = m0; + tile_n0_ = n0; + tile_k0_ = k0; + tile_n1_ = n1; + tile_k1_ = k1; + tile_k0max_ = k0max; + return *this; + } + + FmhaAlgorithm& wave(int m0, + int n0, + int k0, + int m1 = 2, + int n1 = 2, + int k1 = 1, + int m2 = 1, + int n2 = 1, + int k2 = 1) + { + wave_m0_ = m0; + wave_n0_ = n0; + wave_k0_ = k0; + wave_m1_ = m1; + wave_n1_ = n1; + wave_k1_ = k1; + wave_m2_ = m2; + wave_n2_ = n2; + wave_k2_ = k2; + return *this; + } + + FmhaAlgorithm& warp(int m0, + int n0, + int k0, + int m1 = 32, + int n1 = 32, + int k1 = 16, + int m2 = 16, + int n2 = 16, + int k2 = 16) + { + warp_m0_ = m0; + warp_n0_ = n0; + warp_k0_ = k0; + warp_m1_ = m1; + warp_n1_ = n1; + warp_k1_ = k1; + warp_m2_ = m2; + warp_n2_ = n2; + warp_k2_ = k2; + return *this; + } + + // Named individual setters for clarity (preferred over positional bulk setters) + // Stage 0: Q * K^T (seqlen_q x seqlen_k x hdim_q) + FmhaAlgorithm& tile_m0(int v) + { + tile_m0_ = v; + return *this; + } + FmhaAlgorithm& tile_n0(int v) + { + tile_n0_ = v; + return *this; + } + FmhaAlgorithm& tile_k0(int v) + { + tile_k0_ = v; + return *this; + } + // Stage 1: Attn * V (seqlen_q x hdim_v x seqlen_k) + FmhaAlgorithm& tile_n1(int v) + { + tile_n1_ = v; + return *this; + } + FmhaAlgorithm& tile_k1(int v) + { + tile_k1_ = v; + return *this; + } + FmhaAlgorithm& tile_k0max(int v) + { + tile_k0max_ = v; + return *this; + } + + FmhaAlgorithm& wave_m0(int v) + { + wave_m0_ = v; + return *this; + } + FmhaAlgorithm& wave_n0(int v) + { + wave_n0_ = v; + return *this; + } + FmhaAlgorithm& wave_k0(int v) + { + wave_k0_ = v; + return *this; + } + FmhaAlgorithm& wave_m1(int v) + { + wave_m1_ = v; + return *this; + } + FmhaAlgorithm& wave_n1(int v) + { + wave_n1_ = v; + return *this; + } + FmhaAlgorithm& wave_k1(int v) + { + wave_k1_ = v; + return *this; + } + + FmhaAlgorithm& warp_m0(int v) + { + warp_m0_ = v; + return *this; + } + FmhaAlgorithm& warp_n0(int v) + { + warp_n0_ = v; + return *this; + } + FmhaAlgorithm& warp_k0(int v) + { + warp_k0_ = v; + return *this; + } + FmhaAlgorithm& warp_m1(int v) + { + warp_m1_ = v; + return *this; + } + FmhaAlgorithm& warp_n1(int v) + { + warp_n1_ = v; + return *this; + } + FmhaAlgorithm& warp_k1(int v) + { + warp_k1_ = v; + return *this; + } + + FmhaAlgorithm& pipeline(const std::string& pipeline) + { + pipeline_ = pipeline; + return *this; + } + + FmhaAlgorithm& padding(bool s, bool sk, bool d, bool dv) + { + pad_s_ = s; + pad_sk_ = sk; + pad_d_ = d; + pad_dv_ = dv; + return *this; + } + + FmhaAlgorithm& trload(bool value = true) + { + use_trload_ = value; + return *this; + } + + FmhaAlgorithm& alignments(int q_alignment, int v_alignment) + { + hdim_q_alignment_ = q_alignment; + hdim_v_alignment_ = v_alignment; + return *this; + } + + FmhaAlgorithm& block_per_cu(int value) + { + block_per_cu_ = value; + return *this; + } + + FmhaAlgorithm& num_wave_groups(int value) + { + num_wave_groups_ = value; + return *this; + } + + FmhaAlgorithm& max_splits_log2(int value) + { + max_splits_log2_ = value; + return *this; + } + + FmhaAlgorithm& max_seq_len_q(int value) + { + max_seq_len_q_ = value; + return *this; + } + + FmhaAlgorithm& selection_rank(int value) + { + selection_rank_ = value; + return *this; + } + + FmhaAlgorithm& constraint(const std::string& tag) + { + constraint_tag_ = tag; + return *this; + } + + void auto_fill() + { + if(tile_n1_ <= 0) + { + tile_n1_ = tile_n0_; + } + if(tile_k1_ <= 0) + { + tile_k1_ = tile_k0_; + } + if(tile_k0max_ <= 0) + { + tile_k0max_ = tile_k0_; + } + if(hdim_q_alignment_ <= 0) + { + hdim_q_alignment_ = tile_k0max_; + } + if(hdim_v_alignment_ <= 0) + { + hdim_v_alignment_ = tile_k0max_; + } + } +}; + +struct FmhaKernelDecl +{ + FmhaSignature signature; + FmhaAlgorithm algorithm; + std::string arch = "gfx942"; + + FmhaKernelDecl() = default; + FmhaKernelDecl(const FmhaSignature& sig, + const FmhaAlgorithm& algo, + const std::string& target_arch = "gfx942") + : signature(sig), algorithm(algo), arch(target_arch) + { + } + + std::string name() const + { + std::ostringstream oss; + oss << "fmha_" << signature.family_ << "_" << signature.data_type_ << "_" << signature.mode_ + << "_dq" << signature.hdim_q_ << "_dv" << signature.hdim_v_ << "_" << signature.vlayout_ + << "_" << algorithm.pipeline_; + return oss.str(); + } + + bool has_wildcards() const { return arch == "*"; } +}; + +class FmhaKernelSet +{ + public: + FmhaKernelSet() = default; + + FmhaKernelSet& + add(const FmhaSignature& sig, const FmhaAlgorithm& algo, const std::string& arch = "gfx942") + { + decls_.emplace_back(sig, algo, arch); + return *this; + } + + FmhaKernelSet& add(const FmhaKernelDecl& decl) + { + decls_.push_back(decl); + return *this; + } + + FmhaKernelSet& merge(const FmhaKernelSet& other) + { + decls_.insert(decls_.end(), other.decls_.begin(), other.decls_.end()); + return *this; + } + + const std::vector& declarations() const { return decls_; } + std::size_t size() const { return decls_.size(); } + + bool needs_expansion() const + { + for(const auto& d : decls_) + { + if(d.has_wildcards()) + return true; + } + return false; + } + + void print(std::ostream& os = std::cout) const + { + os << "FmhaKernelSet (" << size() << " declarations):\n"; + for(const auto& decl : decls_) + { + os << " - " << decl.name(); + if(decl.has_wildcards()) + os << " [expands]"; + os << "\n"; + } + } + + FmhaKernelSet& tag(const std::string& tag) + { + tag_ = tag; + return *this; + } + + const std::string& tag() const { return tag_; } + + private: + std::vector decls_; + std::string tag_; +}; + +/// Singleton registry for declarative kernel sets. +/// Thread safety: only populated during static initialization (pre-main) +/// via DECL_FMHA_KERNEL_SET macros. Do NOT call add() after main() starts. +class FmhaKernelSetRegistry +{ + public: + static FmhaKernelSetRegistry& instance() + { + static FmhaKernelSetRegistry registry; + return registry; + } + + void add(const std::string& name, const FmhaKernelSet& set) + { + sets_[name] = set; + if(std::find(order_.begin(), order_.end(), name) == order_.end()) + { + order_.push_back(name); + } + } + + const FmhaKernelSet& get(const std::string& name) const + { + static FmhaKernelSet empty; + auto it = sets_.find(name); + return it != sets_.end() ? it->second : empty; + } + + bool has(const std::string& name) const { return sets_.find(name) != sets_.end(); } + + const std::vector& names() const { return order_; } + + std::size_t size() const { return sets_.size(); } + + void clear() + { + sets_.clear(); + order_.clear(); + } + + void print() const + { + std::cout << "FMHA Kernel Sets (" << sets_.size() << "):\n"; + for(const auto& name : order_) + { + const auto& set = sets_.at(name); + std::cout << " " << name << ": " << set.size() << " declarations\n"; + } + } + + private: + std::unordered_map sets_; + std::vector order_; +}; + +struct FmhaKernelSetRegistrar +{ + FmhaKernelSetRegistrar(const std::string& name, const FmhaKernelSet& set) + { + FmhaKernelSetRegistry::instance().add(name, set); + } +}; + +} // namespace fmha_decl + +using FmhaSignature = fmha_decl::FmhaSignature; +using FmhaAlgorithm = fmha_decl::FmhaAlgorithm; +using FmhaKernelDecl = fmha_decl::FmhaKernelDecl; +using FmhaKernelSet = fmha_decl::FmhaKernelSet; +using FmhaKernelSetRegistry = fmha_decl::FmhaKernelSetRegistry; + +} // namespace dispatcher +} // namespace ck_tile + +#define CK_FMHA_DECL_CAT_(a, b) CK_FMHA_DECL_CAT_IMPL_(a, b) +#define CK_FMHA_DECL_CAT_IMPL_(a, b) a##b + +#if defined(__GNUC__) || defined(__clang__) +#define CK_FMHA_DECL_EXT_ __extension__ +#else +#define CK_FMHA_DECL_EXT_ +#endif + +#define DECL_FMHA_KERNEL_SET(name, ...) \ + CK_FMHA_DECL_EXT_ static ::ck_tile::dispatcher::fmha_decl::FmhaKernelSetRegistrar \ + CK_FMHA_DECL_CAT_(_fmha_kset_reg_, __COUNTER__)( \ + #name, ::ck_tile::dispatcher::fmha_decl::FmhaKernelSet() __VA_ARGS__.tag(#name)) diff --git a/dispatcher/include/ck_tile/dispatcher/fmha_kernel_instance.hpp b/dispatcher/include/ck_tile/dispatcher/fmha_kernel_instance.hpp new file mode 100644 index 0000000000..5d24b615da --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/fmha_kernel_instance.hpp @@ -0,0 +1,45 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/dispatcher/fmha_kernel_key.hpp" +#include "ck_tile/dispatcher/fmha_problem.hpp" + +#include "ck_tile/host/kernel_launch.hpp" + +#include +#include + +namespace ck_tile { +namespace dispatcher { + +class FmhaKernelInstance +{ + public: + virtual ~FmhaKernelInstance() = default; + + [[nodiscard]] virtual const FmhaKernelKey& get_key() const = 0; + [[nodiscard]] virtual bool supports(const FmhaProblem& problem) const = 0; + [[nodiscard]] virtual std::string get_name() const = 0; + + // Short aliases (preferred for new code) + [[nodiscard]] const FmhaKernelKey& key() const { return get_key(); } + [[nodiscard]] std::string name() const { return get_name(); } + + virtual void launch(const FmhaInvocation& invocation, + const ck_tile::stream_config& stream_config) const = 0; + + [[nodiscard]] virtual float run(const FmhaInvocation& invocation, + const ck_tile::stream_config& stream_config) const + { + return ck_tile::launch_kernel( + stream_config, + [this, &invocation](const ck_tile::stream_config& sc) { launch(invocation, sc); }); + } +}; + +using FmhaKernelInstancePtr = std::shared_ptr; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/fmha_kernel_key.hpp b/dispatcher/include/ck_tile/dispatcher/fmha_kernel_key.hpp new file mode 100644 index 0000000000..b065ad7646 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/fmha_kernel_key.hpp @@ -0,0 +1,216 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/dispatcher/fmha_problem.hpp" + +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +struct FmhaKernelKey +{ + // Runtime signature -- corresponds to fmha_decl::FmhaSignature (build-time). + // FmhaSignature uses strings for enums; Signature uses ints for matching speed. + // When adding fields here, also update FmhaSignature and tie(). + struct Signature + { + FmhaKernelFamily family = FmhaKernelFamily::Fwd; + std::string data_type; + bool is_group_mode = false; + bool is_v_rowmajor = true; + bool has_logits_soft_cap = false; + int mask_type = 0; + int bias_type = 0; + bool has_lse = false; + bool has_dropout = false; + int qscale_type = 0; + int rope_type = 0; + bool use_paged_kv = false; + bool do_fp8_static_quant = false; + bool skip_min_seqlen_q = false; + bool has_sink = false; + bool has_dbias = false; + bool is_store_randval = false; + bool is_deterministic = false; + int kv_memory_layout = 0; + int kv_lookup_table = 0; + int page_size = 1; + std::uint16_t hdim_q = 0; + std::uint16_t hdim_v = 0; + int receipt = -1; + } signature; + + struct Algorithm + { + struct TileShape + { + std::uint16_t m0 = 0; + std::uint16_t n0 = 0; + std::uint16_t k0 = 0; + std::uint16_t n1 = 0; + std::uint16_t k1 = 0; + std::uint16_t k0max = 0; + } tile_shape; + + struct WaveShape + { + std::uint8_t m0 = 1; + std::uint8_t n0 = 1; + std::uint8_t k0 = 1; + std::uint8_t m1 = 1; + std::uint8_t n1 = 1; + std::uint8_t k1 = 1; + std::uint8_t m2 = 1; + std::uint8_t n2 = 1; + std::uint8_t k2 = 1; + } wave_shape; + + struct WarpTileShape + { + std::uint16_t m0 = 0; + std::uint16_t n0 = 0; + std::uint16_t k0 = 0; + std::uint16_t m1 = 0; + std::uint16_t n1 = 0; + std::uint16_t k1 = 0; + std::uint16_t m2 = 0; + std::uint16_t n2 = 0; + std::uint16_t k2 = 0; + } warp_tile_shape; + + std::string pipeline; + bool pad_s = true; + bool pad_sk = true; + bool pad_d = true; + bool pad_dv = true; + bool use_trload = false; + std::uint8_t block_per_cu = 1; + std::uint8_t num_wave_groups = 1; + std::uint8_t max_splits_log2 = 0; + std::uint16_t max_seq_len_q = 0; + std::uint16_t hdim_q_alignment = 0; + std::uint16_t hdim_v_alignment = 0; + std::int32_t selection_rank = 0; + std::string constraint_tag; + } algorithm; + + std::string gfx_arch; + + [[nodiscard]] std::string encode_identifier() const + { + std::ostringstream oss; + oss << "fmha_" << to_string(signature.family) << "_" << signature.data_type << "_" + << (signature.is_group_mode ? "group" : "batch") << "_" + << (signature.is_v_rowmajor ? "vr" : "vc") << "_hq" << signature.hdim_q << "_hv" + << signature.hdim_v << "_p" << algorithm.pipeline << "_m" << signature.mask_type << "_b" + << signature.bias_type << "_lse" << signature.has_lse << "_do" << signature.has_dropout + << "_qs" << signature.qscale_type << "_rp" << signature.rope_type << "_pkv" + << signature.use_paged_kv << "_sq" << signature.do_fp8_static_quant << "_sk" + << signature.skip_min_seqlen_q << "_sink" << signature.has_sink << "_db" + << signature.has_dbias << "_sr" << signature.is_store_randval << "_det" + << signature.is_deterministic << "_km" << signature.kv_memory_layout << "_kl" + << signature.kv_lookup_table << "_ps" << signature.page_size << "_t" + << algorithm.tile_shape.m0 << "x" << algorithm.tile_shape.n0 << "x" + << algorithm.tile_shape.k0 << "x" << algorithm.tile_shape.n1 << "x" + << algorithm.tile_shape.k1 << "x" << algorithm.tile_shape.k0max << "_w0" + << unsigned(algorithm.wave_shape.m0) << "x" << unsigned(algorithm.wave_shape.n0) << "x" + << unsigned(algorithm.wave_shape.k0) << "_w1" << unsigned(algorithm.wave_shape.m1) + << "x" << unsigned(algorithm.wave_shape.n1) << "x" << unsigned(algorithm.wave_shape.k1) + << "_wt0" << algorithm.warp_tile_shape.m0 << "x" << algorithm.warp_tile_shape.n0 << "x" + << algorithm.warp_tile_shape.k0 << "_wt1" << algorithm.warp_tile_shape.m1 << "x" + << algorithm.warp_tile_shape.n1 << "x" << algorithm.warp_tile_shape.k1 << "_pads" + << algorithm.pad_s << algorithm.pad_sk << algorithm.pad_d << algorithm.pad_dv << "_tr" + << algorithm.use_trload << "_bpc" << unsigned(algorithm.block_per_cu) << "_wg" + << unsigned(algorithm.num_wave_groups) << "_ms" << unsigned(algorithm.max_splits_log2) + << "_mq" << algorithm.max_seq_len_q << "_aq" << algorithm.hdim_q_alignment << "_av" + << algorithm.hdim_v_alignment << "_r" << algorithm.selection_rank << "_rc" + << signature.receipt; + return oss.str(); + } + + auto tie() const + { + return std::tie(signature.family, + signature.data_type, + signature.is_group_mode, + signature.is_v_rowmajor, + signature.has_logits_soft_cap, + signature.mask_type, + signature.bias_type, + signature.has_lse, + signature.has_dropout, + signature.qscale_type, + signature.rope_type, + signature.use_paged_kv, + signature.do_fp8_static_quant, + signature.skip_min_seqlen_q, + signature.has_sink, + signature.has_dbias, + signature.is_store_randval, + signature.is_deterministic, + signature.kv_memory_layout, + signature.kv_lookup_table, + signature.page_size, + signature.hdim_q, + signature.hdim_v, + algorithm.tile_shape.m0, + algorithm.tile_shape.n0, + algorithm.tile_shape.k0, + algorithm.tile_shape.n1, + algorithm.tile_shape.k1, + algorithm.tile_shape.k0max, + algorithm.wave_shape.m0, + algorithm.wave_shape.n0, + algorithm.wave_shape.k0, + algorithm.wave_shape.m1, + algorithm.wave_shape.n1, + algorithm.wave_shape.k1, + algorithm.wave_shape.m2, + algorithm.wave_shape.n2, + algorithm.wave_shape.k2, + algorithm.warp_tile_shape.m0, + algorithm.warp_tile_shape.n0, + algorithm.warp_tile_shape.k0, + algorithm.warp_tile_shape.m1, + algorithm.warp_tile_shape.n1, + algorithm.warp_tile_shape.k1, + algorithm.warp_tile_shape.m2, + algorithm.warp_tile_shape.n2, + algorithm.warp_tile_shape.k2, + algorithm.pipeline, + algorithm.pad_s, + algorithm.pad_sk, + algorithm.pad_d, + algorithm.pad_dv, + algorithm.use_trload, + algorithm.block_per_cu, + algorithm.num_wave_groups, + algorithm.max_splits_log2, + algorithm.max_seq_len_q, + algorithm.hdim_q_alignment, + algorithm.hdim_v_alignment, + algorithm.selection_rank, + algorithm.constraint_tag, + gfx_arch, + signature.receipt); + } + + friend bool operator==(const FmhaKernelKey& lhs, const FmhaKernelKey& rhs) + { + return lhs.tie() == rhs.tie(); + } + + friend bool operator!=(const FmhaKernelKey& lhs, const FmhaKernelKey& rhs) + { + return !(lhs == rhs); + } +}; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/fmha_problem.hpp b/dispatcher/include/ck_tile/dispatcher/fmha_problem.hpp new file mode 100644 index 0000000000..0eca65a48b --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/fmha_problem.hpp @@ -0,0 +1,794 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/dispatcher/fmha_types.hpp" + +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +enum class FmhaApiFamily : std::uint8_t +{ + Fwd, + FwdPagedKv, + FwdSplitKv, + FwdAppendKv, + BatchPrefill, + Bwd +}; + +enum class FmhaKernelFamily : std::uint8_t +{ + Fwd, + FwdPagedKv, + FwdSplitKv, + FwdSplitKvCombine, + FwdAppendKv, + BatchPrefill, + BwdDotDoO, + BwdDqDkDv, + BwdConvertDq +}; + +inline std::string to_string(FmhaApiFamily family) +{ + switch(family) + { + case FmhaApiFamily::Fwd: return "fwd"; + case FmhaApiFamily::FwdPagedKv: return "fwd_pagedkv"; + case FmhaApiFamily::FwdSplitKv: return "fwd_splitkv"; + case FmhaApiFamily::FwdAppendKv: return "fwd_appendkv"; + case FmhaApiFamily::BatchPrefill: return "batch_prefill"; + case FmhaApiFamily::Bwd: return "bwd"; + default: return "unknown"; + } +} + +inline std::string to_string(FmhaKernelFamily family) +{ + switch(family) + { + case FmhaKernelFamily::Fwd: return "fwd"; + case FmhaKernelFamily::FwdPagedKv: return "fwd_pagedkv"; + case FmhaKernelFamily::FwdSplitKv: return "fwd_splitkv"; + case FmhaKernelFamily::FwdSplitKvCombine: return "fwd_splitkv_combine"; + case FmhaKernelFamily::FwdAppendKv: return "fwd_appendkv"; + case FmhaKernelFamily::BatchPrefill: return "batch_prefill"; + case FmhaKernelFamily::BwdDotDoO: return "bwd_dot_do_o"; + case FmhaKernelFamily::BwdDqDkDv: return "bwd_dq_dk_dv"; + case FmhaKernelFamily::BwdConvertDq: return "bwd_convert_dq"; + default: return "unknown"; + } +} + +// Combined variants containing both forward and backward types. +// Both fwd and bwd types are always available via fallback definitions +// in fmha_types.hpp (they are conditionally guarded but the fallback +// provides them when the example headers don't). +using FmhaTraitsVariant = std::variant; + +using FmhaArgsVariant = std::variant; + +struct FmhaInvocation +{ + FmhaApiFamily api_family = FmhaApiFamily::Fwd; + FmhaTraitsVariant traits; + FmhaArgsVariant args; + + static FmhaInvocation make(fmha_fwd_traits t, fmha_fwd_args a) + { + return {FmhaApiFamily::Fwd, std::move(t), std::move(a)}; + } + + static FmhaInvocation make(fmha_fwd_pagedkv_traits t, fmha_fwd_pagedkv_args a) + { + return {FmhaApiFamily::FwdPagedKv, std::move(t), std::move(a)}; + } + + static FmhaInvocation make(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a) + { + return {FmhaApiFamily::FwdSplitKv, std::move(t), std::move(a)}; + } + + static FmhaInvocation make(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a) + { + return {FmhaApiFamily::FwdAppendKv, std::move(t), std::move(a)}; + } + + static FmhaInvocation make(fmha_batch_prefill_traits t, fmha_batch_prefill_args a) + { + return {FmhaApiFamily::BatchPrefill, std::move(t), std::move(a)}; + } + + static FmhaInvocation make(fmha_bwd_traits t, fmha_bwd_args a) + { + return {FmhaApiFamily::Bwd, std::move(t), std::move(a)}; + } +}; + +struct FmhaProblem +{ + FmhaApiFamily api_family = FmhaApiFamily::Fwd; + FmhaKernelFamily requested_family = FmhaKernelFamily::Fwd; + std::string gfx_arch; + std::string data_type; + + bool is_group_mode = false; + bool is_v_rowmajor = true; + bool has_logits_soft_cap = false; + int mask_type = 0; + int bias_type = 0; + bool has_lse = false; + bool has_dropout = false; + int qscale_type = 0; + int rope_type = 0; + bool use_paged_kv = false; + bool do_fp8_static_quant = false; + bool skip_min_seqlen_q = false; + bool has_sink = false; + bool has_dbias = false; + bool is_store_randval = false; + bool is_deterministic = false; + int kv_memory_layout = 0; + int kv_lookup_table = 0; + int page_size = 1; + + std::int64_t seqlen_q = 0; + std::int64_t seqlen_k = 0; + std::int64_t max_seqlen_q = 0; + std::int64_t max_seqlen_k = 0; + std::int64_t batch = 0; + std::int64_t hdim_q = 0; + std::int64_t hdim_v = 0; + std::int64_t nhead_q = 0; + std::int64_t nhead_k = 0; + std::int64_t num_splits = 1; + std::int64_t window_size_left = 0; + std::int64_t window_size_right = 0; + std::int64_t sink_size = 0; + std::int64_t min_seqlen_q = 0; + std::int64_t rotary_dim = 0; + + bool has_seqstart_q_ptr = false; + bool has_seqstart_k_ptr = false; + bool has_seqlen_q_ptr = false; + bool has_seqlen_k_ptr = false; + bool has_cu_seqlen_q_ptr = false; + bool has_cu_seqlen_k_ptr = false; + bool has_block_table_ptr = false; + bool has_cache_batch_idx = false; + bool is_gappy = false; + bool has_rotary_cos_sin = false; + + [[nodiscard]] bool is_valid() const + { + return !data_type.empty() && batch > 0 && hdim_q > 0 && hdim_v > 0 && nhead_q > 0 && + nhead_k > 0; + } + + [[nodiscard]] std::int64_t effective_max_seqlen_q() const + { + return max_seqlen_q > 0 ? max_seqlen_q : seqlen_q; + } + + [[nodiscard]] std::int64_t effective_max_seqlen_k() const + { + return max_seqlen_k > 0 ? max_seqlen_k : seqlen_k; + } + + [[nodiscard]] bool has_variable_seqlen_q() const + { + return has_seqstart_q_ptr || has_seqlen_q_ptr || has_cu_seqlen_q_ptr; + } + + [[nodiscard]] bool has_variable_seqlen_k() const + { + return has_seqstart_k_ptr || has_seqlen_k_ptr || has_cu_seqlen_k_ptr || is_gappy; + } + + [[nodiscard]] std::uint64_t num_ops() const + { + const auto sq = effective_max_seqlen_q(); + const auto sk = effective_max_seqlen_k(); + if(batch <= 0 || nhead_q <= 0 || sq <= 0 || sk <= 0 || hdim_q <= 0 || hdim_v <= 0) + return 0; + return 2ULL * static_cast(batch) * static_cast(nhead_q) * + static_cast(sq) * static_cast(sk) * + static_cast(hdim_q + hdim_v); + } + + [[nodiscard]] std::string to_string() const + { + std::string s; + s += "FmhaProblem("; + s += "api=" + ck_tile::dispatcher::to_string(api_family); + s += ", family=" + ck_tile::dispatcher::to_string(requested_family); + s += ", dtype=" + data_type; + s += ", arch=" + gfx_arch; + s += ", batch=" + std::to_string(batch); + s += ", sq=" + std::to_string(seqlen_q); + s += ", sk=" + std::to_string(seqlen_k); + s += ", dq=" + std::to_string(hdim_q); + s += ", dv=" + std::to_string(hdim_v); + s += ", hq=" + std::to_string(nhead_q); + s += ", hk=" + std::to_string(nhead_k); + s += ", group=" + std::string(is_group_mode ? "y" : "n"); + s += ", mask=" + std::to_string(mask_type); + s += ", bias=" + std::to_string(bias_type); + s += ")"; + return s; + } + + /// Canonical key for caching -- includes ALL fields used by fmha_signature_matches(). + /// Safe to use as a cache key (unlike to_string() which omits many fields). + [[nodiscard]] std::string canonical_key() const + { + constexpr char S = '\x1f'; // ASCII unit separator -- unambiguous delimiter + std::string k; + k.reserve(256); + k += ck_tile::dispatcher::to_string(api_family); + k += S; + k += ck_tile::dispatcher::to_string(requested_family); + k += S; + k += data_type; + k += S; + k += gfx_arch; + k += S; + k += std::to_string(hdim_q); + k += ','; + k += std::to_string(hdim_v); + k += S; + k += is_group_mode ? '1' : '0'; + k += is_v_rowmajor ? '1' : '0'; + k += has_logits_soft_cap ? '1' : '0'; + k += has_lse ? '1' : '0'; + k += has_dropout ? '1' : '0'; + k += use_paged_kv ? '1' : '0'; + k += do_fp8_static_quant ? '1' : '0'; + k += skip_min_seqlen_q ? '1' : '0'; + k += has_sink ? '1' : '0'; + k += has_dbias ? '1' : '0'; + k += is_store_randval ? '1' : '0'; + k += is_deterministic ? '1' : '0'; + k += S; + k += std::to_string(mask_type); + k += ','; + k += std::to_string(bias_type); + k += ','; + k += std::to_string(qscale_type); + k += ','; + k += std::to_string(rope_type); + k += S; + k += std::to_string(kv_memory_layout); + k += ','; + k += std::to_string(kv_lookup_table); + k += ','; + k += std::to_string(page_size); + return k; + } + + [[nodiscard]] static FmhaProblem from_invocation(const FmhaInvocation& invocation, + const std::string& gfx_arch = "") + { + FmhaProblem p; + p.api_family = invocation.api_family; + p.gfx_arch = gfx_arch; + + std::visit( + [&](const auto& traits) { + using T = std::decay_t; + + if constexpr(std::is_same_v) + { + p.requested_family = FmhaKernelFamily::Fwd; + p.data_type = traits.data_type; + p.is_group_mode = traits.is_group_mode; + p.is_v_rowmajor = traits.is_v_rowmajor; + p.has_logits_soft_cap = traits.has_logits_soft_cap; + p.mask_type = static_cast(traits.mask_type); + p.bias_type = static_cast(traits.bias_type); + p.has_lse = traits.has_lse; + p.has_dropout = traits.has_dropout; + p.qscale_type = static_cast(traits.qscale_type); + p.skip_min_seqlen_q = traits.skip_min_seqlen_q; + p.has_sink = traits.has_sink; + p.hdim_q = traits.hdim_q; + p.hdim_v = traits.hdim_v; + } + else if constexpr(std::is_same_v) + { + p.requested_family = FmhaKernelFamily::FwdPagedKv; + p.data_type = traits.data_type; + p.is_group_mode = traits.is_group_mode; + p.is_v_rowmajor = traits.is_v_rowmajor; + p.has_logits_soft_cap = traits.has_logits_soft_cap; + p.mask_type = static_cast(traits.mask_type); + p.bias_type = static_cast(traits.bias_type); + p.has_lse = traits.has_lse; + p.use_paged_kv = traits.use_pagedkv; + p.do_fp8_static_quant = traits.do_fp8_static_quant; + p.skip_min_seqlen_q = traits.skip_min_seqlen_q; + p.has_sink = traits.has_sink; + p.hdim_q = traits.hdim_q; + p.hdim_v = traits.hdim_v; + } + else if constexpr(std::is_same_v) + { + p.requested_family = FmhaKernelFamily::FwdSplitKv; + p.data_type = traits.data_type; + p.is_group_mode = traits.is_group_mode; + p.is_v_rowmajor = traits.is_v_rowmajor; + p.has_logits_soft_cap = traits.has_logits_soft_cap; + p.mask_type = static_cast(traits.mask_type); + p.bias_type = static_cast(traits.bias_type); + p.has_lse = traits.has_lse; + p.do_fp8_static_quant = traits.do_fp8_static_quant; + p.has_sink = traits.has_sink; + p.hdim_q = traits.hdim_q; + p.hdim_v = traits.hdim_v; + // Explicit defaults for fields not in splitkv traits + p.has_dropout = false; + p.skip_min_seqlen_q = false; + p.use_paged_kv = false; + p.has_dbias = false; + p.is_store_randval = false; + p.is_deterministic = false; + } + else if constexpr(std::is_same_v) + { + p.requested_family = FmhaKernelFamily::FwdAppendKv; + p.data_type = traits.data_type; + p.is_group_mode = false; + p.is_v_rowmajor = traits.is_v_rowmajor; + p.rope_type = static_cast(traits.rope_type); + p.hdim_q = traits.hdim_q; + p.hdim_v = traits.hdim_v; + // Explicit defaults for fields not in appendkv traits + p.has_logits_soft_cap = false; + p.mask_type = 0; + p.bias_type = 0; + p.has_lse = false; + p.has_dropout = false; + p.has_sink = false; + p.skip_min_seqlen_q = false; + p.use_paged_kv = false; + p.has_dbias = false; + p.is_store_randval = false; + p.is_deterministic = false; + } + else if constexpr(std::is_same_v) + { + p.requested_family = FmhaKernelFamily::BatchPrefill; + p.data_type = traits.data_type; + p.is_group_mode = traits.is_group_mode; + p.is_v_rowmajor = traits.is_v_rowmajor; + p.has_logits_soft_cap = traits.has_logits_soft_cap; + p.mask_type = static_cast(traits.mask_type); + p.bias_type = static_cast(traits.bias_type); + p.has_lse = traits.has_lse; + p.has_dropout = traits.has_dropout; + p.qscale_type = static_cast(traits.qscale_type); + p.skip_min_seqlen_q = traits.skip_min_seqlen_q; + p.has_sink = traits.has_sink; + p.kv_memory_layout = static_cast(traits.kv_memory_layout); + p.kv_lookup_table = static_cast(traits.kv_lookup_table); + p.page_size = traits.page_size; + p.use_paged_kv = true; + p.hdim_q = traits.hdim_q; + p.hdim_v = traits.hdim_v; + } + else if constexpr(std::is_same_v) + { + p.requested_family = FmhaKernelFamily::BwdDqDkDv; + p.seqlen_q = traits.seqlen_q; + p.seqlen_k = traits.seqlen_k; + p.batch = traits.batch; + p.max_seqlen_q = traits.max_seqlen_q; + p.max_seqlen_k = traits.max_seqlen_k; + p.hdim_q = traits.hdim_q; + p.hdim_v = traits.hdim_v; + p.nhead_q = traits.nhead_q; + p.nhead_k = traits.nhead_k; + p.data_type = traits.data_type; + p.is_group_mode = traits.is_group_mode; + p.mask_type = static_cast(traits.mask_type); + p.bias_type = static_cast(traits.bias_type); + p.has_dbias = traits.has_dbias; + p.has_dropout = traits.has_dropout; + p.is_store_randval = traits.is_store_randval; + p.is_deterministic = traits.is_deterministic; + // Explicit defaults for fields not in bwd traits + p.is_v_rowmajor = true; + p.has_logits_soft_cap = false; + p.has_lse = false; + p.has_sink = false; + p.skip_min_seqlen_q = false; + p.use_paged_kv = false; + } + }, + invocation.traits); + + std::visit( + [&](const auto& args) { + using T = std::decay_t; + + if constexpr(std::is_same_v) + { + p.seqlen_q = args.seqlen_q; + p.seqlen_k = args.seqlen_k; + p.batch = args.batch; + p.max_seqlen_q = args.max_seqlen_q; + p.nhead_q = args.nhead_q; + p.nhead_k = args.nhead_k; + p.window_size_left = args.window_size_left; + p.window_size_right = args.window_size_right; + p.sink_size = args.sink_size; + p.min_seqlen_q = args.min_seqlen_q; + p.has_seqstart_q_ptr = args.seqstart_q_ptr != nullptr; + p.has_seqstart_k_ptr = args.seqstart_k_ptr != nullptr; + p.has_seqlen_q_ptr = args.seqlen_q_ptr != nullptr; + p.has_seqlen_k_ptr = args.seqlen_k_ptr != nullptr; + p.has_cu_seqlen_q_ptr = args.cu_seqlen_q_ptr != nullptr; + p.has_cu_seqlen_k_ptr = args.cu_seqlen_k_ptr != nullptr; + } + else if constexpr(std::is_same_v) + { + p.seqlen_q = args.seqlen_q; + p.seqlen_k = args.seqlen_k; + p.batch = args.batch; + p.max_seqlen_q = args.max_seqlen_q; + p.nhead_q = args.nhead_q; + p.nhead_k = args.nhead_k; + p.page_size = args.page_block_size; + p.window_size_left = args.window_size_left; + p.window_size_right = args.window_size_right; + p.sink_size = args.sink_size; + p.min_seqlen_q = args.min_seqlen_q; + p.has_seqstart_q_ptr = args.seqstart_q_ptr != nullptr; + p.has_seqstart_k_ptr = args.seqstart_k_ptr != nullptr; + p.has_seqlen_k_ptr = args.seqlen_k_ptr != nullptr; + p.has_block_table_ptr = args.block_table_ptr != nullptr; + p.has_cache_batch_idx = args.cache_batch_idx != nullptr; + p.is_gappy = args.is_gappy; + } + else if constexpr(std::is_same_v) + { + p.seqlen_q = args.seqlen_q; + p.seqlen_k = args.seqlen_k; + p.batch = args.batch; + p.max_seqlen_q = args.max_seqlen_q; + p.nhead_q = args.nhead_q; + p.nhead_k = args.nhead_k; + p.num_splits = args.num_splits; + p.page_size = args.page_block_size; + p.window_size_left = args.window_size_left; + p.window_size_right = args.window_size_right; + p.sink_size = args.sink_size; + p.has_seqstart_q_ptr = args.seqstart_q_ptr != nullptr; + p.has_seqstart_k_ptr = args.seqstart_k_ptr != nullptr; + p.has_seqlen_k_ptr = args.seqlen_k_ptr != nullptr; + p.has_block_table_ptr = args.block_table_ptr != nullptr; + p.has_cache_batch_idx = args.cache_batch_idx != nullptr; + p.is_gappy = args.is_gappy; + p.use_paged_kv = args.block_table_ptr != nullptr; + } + else if constexpr(std::is_same_v) + { + p.seqlen_q = args.seqlen_q; + p.seqlen_k = args.seqlen_knew; + p.batch = args.batch; + p.nhead_q = args.nhead_q; + p.nhead_k = args.nhead_k; + p.page_size = args.page_block_size; + p.rotary_dim = args.rotary_dim; + p.has_seqlen_k_ptr = args.seqlen_k_ptr != nullptr; + p.has_block_table_ptr = args.block_table_ptr != nullptr; + p.has_cache_batch_idx = args.cache_batch_idx != nullptr; + p.has_rotary_cos_sin = + args.rotary_cos_ptr != nullptr && args.rotary_sin_ptr != nullptr; + p.use_paged_kv = args.block_table_ptr != nullptr; + } + else if constexpr(std::is_same_v) + { + p.seqlen_q = args.seqlen_q; + p.seqlen_k = args.seqlen_k; + p.batch = args.batch; + p.max_seqlen_q = args.max_seqlen_q; + p.nhead_q = args.nhead_q; + p.nhead_k = args.nhead_k; + p.page_size = args.page_block_size; + p.kv_memory_layout = static_cast(args.kv_memory_layout); + p.kv_lookup_table = static_cast(args.kv_lookup_table); + p.window_size_left = args.window_size_left; + p.window_size_right = args.window_size_right; + p.sink_size = args.sink_size; + p.has_seqstart_q_ptr = args.seqstart_q_ptr != nullptr; + p.has_seqlen_k_ptr = args.seqlen_k_ptr != nullptr; + p.use_paged_kv = true; + } + else if constexpr(std::is_same_v) + { + p.seqlen_q = args.seqlen_q; + p.seqlen_k = args.seqlen_k; + p.batch = args.batch; + p.max_seqlen_q = args.max_seqlen_q; + p.max_seqlen_k = args.max_seqlen_k; + p.nhead_q = args.nhead_q; + p.nhead_k = args.nhead_k; + p.window_size_left = args.window_size_left; + p.window_size_right = args.window_size_right; + p.has_seqstart_q_ptr = args.seqstart_q_ptr != nullptr; + p.has_seqstart_k_ptr = args.seqstart_k_ptr != nullptr; + p.has_seqlen_q_ptr = args.seqlen_q_ptr != nullptr; + p.has_seqlen_k_ptr = args.seqlen_k_ptr != nullptr; + p.has_cu_seqlen_q_ptr = args.cu_seqlen_q_ptr != nullptr; + p.has_cu_seqlen_k_ptr = args.cu_seqlen_k_ptr != nullptr; + } + }, + invocation.args); + + return p; + } +}; + +class FmhaProblemBuilder +{ + public: + FmhaProblemBuilder() = default; + + FmhaProblemBuilder& api_family(FmhaApiFamily family) + { + problem_.api_family = family; + return *this; + } + + FmhaProblemBuilder& kernel_family(FmhaKernelFamily family) + { + problem_.requested_family = family; + return *this; + } + + FmhaProblemBuilder& gfx_arch(const std::string& arch) + { + problem_.gfx_arch = arch; + return *this; + } + + FmhaProblemBuilder& data_type(const std::string& dtype) + { + problem_.data_type = dtype; + return *this; + } + + FmhaProblemBuilder& dims(std::int64_t hdim_q, + std::int64_t hdim_v, + std::int64_t batch, + std::int64_t seqlen_q, + std::int64_t seqlen_k) + { + problem_.hdim_q = hdim_q; + problem_.hdim_v = hdim_v; + problem_.batch = batch; + problem_.seqlen_q = seqlen_q; + problem_.seqlen_k = seqlen_k; + return *this; + } + + FmhaProblemBuilder& nheads(std::int64_t q, std::int64_t k) + { + problem_.nhead_q = q; + problem_.nhead_k = k; + return *this; + } + + FmhaProblemBuilder& mask_type(int mask) + { + problem_.mask_type = mask; + return *this; + } + + FmhaProblemBuilder& bias_type(int bias) + { + problem_.bias_type = bias; + return *this; + } + + FmhaProblemBuilder& lse(bool value) + { + problem_.has_lse = value; + return *this; + } + + FmhaProblemBuilder& dropout(bool value) + { + problem_.has_dropout = value; + return *this; + } + + FmhaProblemBuilder& qscale_type(int qscale) + { + problem_.qscale_type = qscale; + return *this; + } + + FmhaProblemBuilder& rope_type(int rope) + { + problem_.rope_type = rope; + return *this; + } + + FmhaProblemBuilder& logits_soft_cap(bool value) + { + problem_.has_logits_soft_cap = value; + return *this; + } + + FmhaProblemBuilder& v_rowmajor(bool value) + { + problem_.is_v_rowmajor = value; + return *this; + } + + FmhaProblemBuilder& group_mode(bool value) + { + problem_.is_group_mode = value; + return *this; + } + + FmhaProblemBuilder& paged_kv(bool value) + { + problem_.use_paged_kv = value; + return *this; + } + + FmhaProblemBuilder& fp8_static_quant(bool value) + { + problem_.do_fp8_static_quant = value; + return *this; + } + + FmhaProblemBuilder& skip_min_seqlen_q(bool value) + { + problem_.skip_min_seqlen_q = value; + return *this; + } + + FmhaProblemBuilder& sink(bool value) + { + problem_.has_sink = value; + return *this; + } + + FmhaProblemBuilder& kv_cache(int memory_layout, int lookup_table, int page_size) + { + problem_.kv_memory_layout = memory_layout; + problem_.kv_lookup_table = lookup_table; + problem_.page_size = page_size; + return *this; + } + + FmhaProblemBuilder& window(std::int64_t left, std::int64_t right) + { + problem_.window_size_left = left; + problem_.window_size_right = right; + return *this; + } + + FmhaProblemBuilder& sink_size(std::int64_t value) + { + problem_.sink_size = value; + problem_.has_sink = (value > 0); + return *this; + } + + FmhaProblemBuilder& max_seqlen(std::int64_t q, std::int64_t k) + { + problem_.max_seqlen_q = q; + problem_.max_seqlen_k = k; + return *this; + } + + FmhaProblemBuilder& num_splits(std::int64_t value) + { + problem_.num_splits = value; + return *this; + } + + FmhaProblemBuilder& bwd_flags(bool dbias, bool store_randval, bool deterministic) + { + problem_.has_dbias = dbias; + problem_.is_store_randval = store_randval; + problem_.is_deterministic = deterministic; + return *this; + } + + [[nodiscard]] FmhaProblem build() const + { + if(!problem_.is_valid()) + { + throw std::invalid_argument("Invalid FMHA problem: " + problem_.to_string()); + } + + const auto fam = problem_.api_family; + if(fam == FmhaApiFamily::Bwd) + { + if(problem_.has_lse == false) + { + throw std::invalid_argument( + "FMHA BWD requires has_lse=true (LSE from forward pass)"); + } + } + + if(problem_.is_group_mode && problem_.max_seqlen_q <= 0) + { + throw std::invalid_argument("FMHA group mode requires max_seqlen_q > 0"); + } + + return problem_; + } + + private: + FmhaProblem problem_; +}; + +// ============================================================================= +// Backward workspace sizing +// ============================================================================= + +struct FmhaBwdWorkspaceInfo +{ + size_t d_bytes = 0; // B * Hq * Sq * sizeof(float) + size_t dq_acc_bytes = 0; // B * Hq * Sq * Dq * sizeof(float) + size_t rand_val_bytes = 0; // 0 unless is_store_randval + size_t total_bytes = 0; // aligned sum + size_t d_offset = 0; // always 0 + size_t dq_acc_offset = 0; // align(d_bytes, 256) + size_t rand_val_offset = 0; // align(d_bytes + dq_acc_bytes, 256) +}; + +inline FmhaBwdWorkspaceInfo bwd_workspace_info(const FmhaProblem& problem) +{ + constexpr size_t kAlign = 256; + auto align_up = [](size_t n, size_t a) -> size_t { return (n + a - 1) / a * a; }; + + FmhaBwdWorkspaceInfo info; + const auto B = static_cast(problem.batch); + const auto Hq = static_cast(problem.nhead_q); + const auto Sq = static_cast(problem.seqlen_q); + const auto Dq = static_cast(problem.hdim_q); + const auto Sk = static_cast(problem.seqlen_k); + + info.d_bytes = B * Hq * Sq * sizeof(float); + info.dq_acc_bytes = B * Hq * Sq * Dq * sizeof(float); + + if(problem.is_store_randval) + info.rand_val_bytes = B * Hq * Sq * Sk * sizeof(uint8_t); + + info.d_offset = 0; + info.dq_acc_offset = align_up(info.d_bytes, kAlign); + info.rand_val_offset = align_up(info.dq_acc_offset + info.dq_acc_bytes, kAlign); + info.total_bytes = info.rand_val_bytes > 0 + ? align_up(info.rand_val_offset + info.rand_val_bytes, kAlign) + : align_up(info.dq_acc_offset + info.dq_acc_bytes, kAlign); + + return info; +} + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/fmha_registry.hpp b/dispatcher/include/ck_tile/dispatcher/fmha_registry.hpp new file mode 100644 index 0000000000..6c5302d54f --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/fmha_registry.hpp @@ -0,0 +1,63 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/dispatcher/base_registry.hpp" +#include "ck_tile/dispatcher/fmha_kernel_instance.hpp" + +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +class FmhaRegistry : public BaseRegistry +{ + using Base = BaseRegistry; + + public: + using Priority = ck_tile::dispatcher::Priority; + + FmhaRegistry() = default; + + bool register_kernel(FmhaKernelInstancePtr instance, Priority priority = Priority::Normal); + + [[nodiscard]] FmhaKernelInstancePtr lookup(const std::string& identifier) const; + [[nodiscard]] FmhaKernelInstancePtr lookup(const FmhaKernelKey& key) const; + [[nodiscard]] std::vector get_all() const; + + [[nodiscard]] std::vector + filter(std::function predicate) const; + + [[nodiscard]] std::string export_json(bool include_statistics = true) const; + bool export_json_to_file(const std::string& filename, bool include_statistics = true) const; + + std::size_t filter_by_arch(const std::string& gpu_arch); + + /// Remove kernels whose signature receipt does not match the given receipt_id. + /// Returns the number of kernels removed. + std::size_t filter_by_receipt(int receipt_id); + + /// Return the set of distinct receipt IDs present in the registry. + [[nodiscard]] std::vector available_receipts() const; + + static FmhaRegistry& instance(); +}; + +using FmhaRegistryPtr = std::shared_ptr; + +inline FmhaRegistryPtr make_fmha_registry(const std::string& name = "") +{ + auto reg = std::make_shared(); + if(!name.empty()) + { + reg->set_name(name); + } + return reg; +} + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/fmha_types.hpp b/dispatcher/include/ck_tile/dispatcher/fmha_types.hpp new file mode 100644 index 0000000000..63bd90ec2a --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/fmha_types.hpp @@ -0,0 +1,605 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// FMHA type definitions for the dispatcher. +// +// Fine-grained guards prevent redefinition when example headers are present: +// CK_TILE_FMHA_FWD_TYPES_FROM_EXAMPLE -- set by fwd kernel wrappers +// CK_TILE_FMHA_BWD_TYPES_FROM_EXAMPLE -- set by bwd kernel wrappers +// +// fmha_fwd.hpp provides: mask_enum, bias_enum, quant_scale_enum, rope_enum, +// all fwd args/traits, FmhaMasks +// fmha_bwd.hpp provides: mask_enum, bias_enum, bwd args/traits, FmhaMasks +// (but NOT quant_scale_enum, rope_enum) + +#pragma once + +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" + +#include +#include +#include +#include + +// ========================================================================= +// Shared enums: mask_enum and bias_enum +// Provided by both fmha_fwd.hpp and fmha_bwd.hpp (via mask.hpp, bias.hpp). +// Skipped when EITHER example header was included. +// ========================================================================= +#if !defined(CK_TILE_FMHA_FWD_TYPES_FROM_EXAMPLE) && !defined(CK_TILE_FMHA_BWD_TYPES_FROM_EXAMPLE) + +enum class mask_enum +{ + no_mask = 0, + mask_top_left, + mask_bottom_right, + window_generic, +}; + +enum class bias_enum +{ + no_bias = 0, + elementwise_bias = 1, + alibi = 2, +}; + +#endif // shared enums + +// ========================================================================= +// Fwd-only enums: quant_scale_enum, rope_enum +// Only provided by fmha_fwd.hpp (via quant.hpp, rotary.hpp). +// Skipped when fmha_fwd.hpp was included; always provided in bwd-only TUs. +// ========================================================================= +#ifndef CK_TILE_FMHA_FWD_TYPES_FROM_EXAMPLE + +enum class quant_scale_enum +{ + no_scale = 0, + pertensor = 1, + blockscale = 2, + kv_blockscale = 3, +}; + +enum class rope_enum +{ + none = 0, + interleaved = 1, + half_rotated = 2, +}; + +#endif // fwd-only enums + +// ========================================================================= +// Forward args + traits: skipped when fmha_fwd.hpp was included +// ========================================================================= +#ifndef CK_TILE_FMHA_FWD_TYPES_FROM_EXAMPLE + +struct fmha_fwd_args +{ + const void* q_ptr = nullptr; + const void* k_ptr = nullptr; + const void* v_ptr = nullptr; + const void* bias_ptr = nullptr; + const void* q_descale_ptr = nullptr; + const void* k_descale_ptr = nullptr; + const void* v_descale_ptr = nullptr; + void* rand_val_ptr = nullptr; + void* lse_ptr = nullptr; + void* o_ptr = nullptr; + + const void* seqstart_q_ptr = nullptr; + const void* seqstart_k_ptr = nullptr; + const void* seqlen_q_ptr = nullptr; + const void* seqlen_k_ptr = nullptr; + const void* cu_seqlen_q_ptr = nullptr; + const void* cu_seqlen_k_ptr = nullptr; + const void* block_scale_seqstart_q_ptr = nullptr; + const void* block_scale_seqstart_k_ptr = nullptr; + const void* sink_ptr = nullptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + float scale_s; + float logits_soft_cap; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; + ck_tile::index_t stride_randval; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_randval; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t nhead_stride_q_descale; + ck_tile::index_t nhead_stride_k_descale; + ck_tile::index_t nhead_stride_v_descale; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_randval; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_o; + ck_tile::index_t batch_stride_q_descale; + ck_tile::index_t batch_stride_k_descale; + ck_tile::index_t batch_stride_v_descale; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t sink_size; + ck_tile::index_t mask_type; + ck_tile::index_t min_seqlen_q; + + float p_drop; + bool s_randval; + + std::variant, std::pair> + drop_seed_offset; + + ck_tile::index_t block_scale_size_q; + ck_tile::index_t block_scale_size_kv; +}; + +struct fmha_fwd_pagedkv_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; + void* lse_ptr; + void* o_ptr; + + void* block_table_ptr; + ck_tile::index_t batch_stride_block_table; + ck_tile::index_t page_block_size; + bool is_gappy; + + const void* cache_batch_idx; + + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* seqlen_k_ptr; + const void* sink_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + float scale_s; + float scale_p; + float scale_o; + + float logits_soft_cap; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_o; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t sink_size; + ck_tile::index_t mask_type; + ck_tile::index_t min_seqlen_q; +}; + +struct fmha_fwd_splitkv_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; + void* lse_acc_ptr; + void* o_acc_ptr; + void* lse_ptr; + void* o_ptr; + + void* block_table_ptr; + ck_tile::index_t batch_stride_block_table; + ck_tile::index_t page_block_size; + bool is_gappy; + + const void* cache_batch_idx; + + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* seqlen_k_ptr; + const void* sink_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + ck_tile::index_t num_splits; + + float scale_s; + float scale_p; + float scale_o; + + float logits_soft_cap; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; + ck_tile::index_t stride_o_acc; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_lse_acc; + ck_tile::index_t nhead_stride_o_acc; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_lse_acc; + ck_tile::index_t batch_stride_o_acc; + ck_tile::index_t batch_stride_o; + ck_tile::index_t split_stride_lse_acc; + ck_tile::index_t split_stride_o_acc; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t sink_size; + ck_tile::index_t mask_type; +}; + +struct fmha_fwd_appendkv_args +{ + void* q_ptr; + void* k_ptr; + const void* knew_ptr; + void* v_ptr; + const void* vnew_ptr; + + const void* seqlen_k_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_knew; + ck_tile::index_t batch; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + const void* rotary_cos_ptr; + const void* rotary_sin_ptr; + ck_tile::index_t rotary_dim; + bool has_mask; + + void* block_table_ptr; + ck_tile::index_t batch_stride_block_table; + ck_tile::index_t page_block_size; + + const void* cache_batch_idx; + const void* sink_ptr; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_knew; + ck_tile::index_t stride_v; + ck_tile::index_t stride_vnew; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_knew; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_vnew; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_knew; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_vnew; +}; + +struct fmha_batch_prefill_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; + const void* q_descale_ptr; + const void* k_descale_ptr; + const void* v_descale_ptr; + void* rand_val_ptr; + void* lse_ptr; + void* o_ptr; + + const void* seqstart_q_ptr; + const void* sink_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + int32_t num_total_pages; + ck_tile::index_t page_block_size; + ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kv_memory_layout; + ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table; + void* kv_indptr; + void* kv_page_indices; + void* kv_last_page_lens; + void* seqlen_k_ptr; + ck_tile::index_t batch_stride_block_table; + + float scale_s; + float scale_p; + float scale_o; + + float logits_soft_cap; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; + ck_tile::index_t stride_randval; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_randval; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_randval; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_o; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t sink_size; + ck_tile::index_t mask_type; + + float p_drop; + bool s_randval; + + std::variant, std::pair> + drop_seed_offset; + + ck_tile::index_t nblock_stride_kv_block_descale = 0; + ck_tile::index_t nhead_stride_kv_block_descale = 0; +}; + +struct fmha_fwd_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + bool is_v_rowmajor; + bool has_logits_soft_cap; + mask_enum mask_type; + bias_enum bias_type; + bool has_lse; + bool has_dropout; + quant_scale_enum qscale_type; + bool skip_min_seqlen_q = false; + bool has_sink = false; +}; + +struct fmha_fwd_pagedkv_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + bool is_v_rowmajor; + bool has_logits_soft_cap; + mask_enum mask_type; + bias_enum bias_type; + bool has_lse = false; + bool use_pagedkv = true; + bool do_fp8_static_quant = false; + bool skip_min_seqlen_q = false; + bool has_sink = false; +}; + +struct fmha_fwd_splitkv_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + bool is_v_rowmajor; + bool has_logits_soft_cap; + mask_enum mask_type; + bias_enum bias_type; + bool has_lse; + bool do_fp8_static_quant = false; + bool has_sink = false; +}; + +struct fmha_fwd_appendkv_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_v_rowmajor; + rope_enum rope_type; +}; + +struct fmha_batch_prefill_traits : public fmha_fwd_traits +{ + ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kv_memory_layout = + ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT; + ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table = + ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D; + int page_size = 1; +}; + +#endif // CK_TILE_FMHA_FWD_TYPES_FROM_EXAMPLE + +// ========================================================================= +// Backward args + traits: skipped when fmha_bwd.hpp was included +// ========================================================================= +#ifndef CK_TILE_FMHA_BWD_TYPES_FROM_EXAMPLE + +struct fmha_bwd_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; + const void* o_ptr; + const void* lse_ptr; + const void* do_ptr; + void* d_ptr; + void* rand_val_ptr; + void* dq_ptr; + void* dk_ptr; + void* dv_ptr; + void* dbias_ptr; + void* dq_acc_ptr; + + const void* seqstart_q_ptr = nullptr; + const void* seqstart_k_ptr = nullptr; + const void* seqlen_q_ptr = nullptr; + const void* seqlen_k_ptr = nullptr; + const void* cu_seqlen_q_ptr = nullptr; + const void* cu_seqlen_k_ptr = nullptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t max_seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + float scale; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; + ck_tile::index_t stride_o; + ck_tile::index_t stride_randval; + ck_tile::index_t stride_do; + ck_tile::index_t stride_dq_acc; + ck_tile::index_t stride_dq; + ck_tile::index_t stride_dk; + ck_tile::index_t stride_dv; + ck_tile::index_t stride_dbias; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t nhead_stride_randval; + ck_tile::index_t nhead_stride_do; + ck_tile::index_t nhead_stride_lsed; + ck_tile::long_index_t nhead_stride_dq_acc; + ck_tile::index_t nhead_stride_dq; + ck_tile::index_t nhead_stride_dk; + ck_tile::index_t nhead_stride_dv; + ck_tile::index_t nhead_stride_dbias; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_o; + ck_tile::index_t batch_stride_randval; + ck_tile::index_t batch_stride_do; + ck_tile::index_t batch_stride_lsed; + ck_tile::long_index_t batch_stride_dq_acc; + ck_tile::index_t batch_stride_dq; + ck_tile::index_t batch_stride_dk; + ck_tile::index_t batch_stride_dv; + ck_tile::index_t batch_stride_dbias; + ck_tile::index_t split_stride_dq_acc; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + + float p_drop; + float p_undrop; + std::variant, std::pair> + drop_seed_offset; +}; + +struct fmha_bwd_traits +{ + int seqlen_q; + int seqlen_k; + int batch; + int max_seqlen_q; + int max_seqlen_k; + int hdim_q; + int hdim_v; + int nhead_q; + int nhead_k; + std::string data_type; + bool is_group_mode; + mask_enum mask_type; + bias_enum bias_type; + bool has_dbias; + bool has_dropout; + bool is_store_randval; + bool is_deterministic; +}; + +#endif // CK_TILE_FMHA_BWD_TYPES_FROM_EXAMPLE + +// ABI safety: when example headers ARE included (in generated kernel TUs), +// verify that the upstream types have the same size as our fallback definitions +// would produce. This catches silent struct drift between the dispatcher's +// fallback types and the upstream example headers. +#if defined(CK_TILE_FMHA_FWD_TYPES_FROM_EXAMPLE) +static_assert(sizeof(fmha_fwd_traits) >= 40, "fmha_fwd_traits layout may have changed upstream"); +static_assert(sizeof(fmha_fwd_args) >= 300, "fmha_fwd_args layout may have changed upstream"); +#endif +#if defined(CK_TILE_FMHA_BWD_TYPES_FROM_EXAMPLE) +static_assert(sizeof(fmha_bwd_traits) >= 32, "fmha_bwd_traits layout may have changed upstream"); +static_assert(sizeof(fmha_bwd_args) >= 350, "fmha_bwd_args layout may have changed upstream"); +#endif diff --git a/dispatcher/include/ck_tile/dispatcher/kernel_instance.hpp b/dispatcher/include/ck_tile/dispatcher/kernel_instance.hpp index 4a734f4c3f..b6ef76e4f8 100644 --- a/dispatcher/include/ck_tile/dispatcher/kernel_instance.hpp +++ b/dispatcher/include/ck_tile/dispatcher/kernel_instance.hpp @@ -59,6 +59,15 @@ class KernelInstance const void** d_ptrs, const Problem& problem, float tolerance = 1e-3f) const = 0; + + /// Enable or disable GPU benchmarking (timing) for this kernel. + /// When disabled, the kernel executes once with no timing overhead + /// (one-shot mode for production use). + void set_benchmarking(bool enable) { benchmarking_ = enable; } + [[nodiscard]] bool benchmarking() const { return benchmarking_; } + + protected: + bool benchmarking_ = true; }; /// Shared pointer type for kernel instances diff --git a/dispatcher/include/ck_tile/dispatcher_fmha.hpp b/dispatcher/include/ck_tile/dispatcher_fmha.hpp new file mode 100644 index 0000000000..55d79bdbf6 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher_fmha.hpp @@ -0,0 +1,17 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +/// FMHA-only dispatcher header. Does not pull in GEMM or Conv types. + +#include "ck_tile/dispatcher/base_registry.hpp" +#include "ck_tile/dispatcher/dispatcher_error.hpp" +#include "ck_tile/dispatcher/fmha_types.hpp" +#include "ck_tile/dispatcher/fmha_problem.hpp" +#include "ck_tile/dispatcher/fmha_kernel_key.hpp" +#include "ck_tile/dispatcher/fmha_kernel_instance.hpp" +#include "ck_tile/dispatcher/fmha_registry.hpp" +#include "ck_tile/dispatcher/fmha_dispatcher.hpp" +#include "ck_tile/dispatcher/fmha_kernel_decl.hpp" +#include "ck_tile/dispatcher/backends/generated_fmha_backend.hpp" diff --git a/dispatcher/include/ck_tile/dispatcher_gemm.hpp b/dispatcher/include/ck_tile/dispatcher_gemm.hpp index 79317c7399..e9e48f1d4e 100644 --- a/dispatcher/include/ck_tile/dispatcher_gemm.hpp +++ b/dispatcher/include/ck_tile/dispatcher_gemm.hpp @@ -1,6 +1,22 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT +#pragma once + +/// GEMM-only dispatcher header. Does not pull in Conv or FMHA types. + +#include "ck_tile/dispatcher/base_registry.hpp" +#include "ck_tile/dispatcher/dispatcher_error.hpp" +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/kernel_config.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/problem.hpp" +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/arch_filter.hpp" +#include "ck_tile/dispatcher/backends/tile_backend.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" /// GEMM-only dispatcher header -- minimal include for GEMM operations. #pragma once @@ -9,14 +25,5 @@ #include "ck_tile/dispatcher/base_registry.hpp" #include "ck_tile/dispatcher/dispatcher_error.hpp" #include "ck_tile/dispatcher/example_args.hpp" - -// GEMM -#include "ck_tile/dispatcher/kernel_key.hpp" -#include "ck_tile/dispatcher/kernel_config.hpp" -#include "ck_tile/dispatcher/kernel_decl.hpp" -#include "ck_tile/dispatcher/kernel_instance.hpp" -#include "ck_tile/dispatcher/problem.hpp" -#include "ck_tile/dispatcher/registry.hpp" -#include "ck_tile/dispatcher/dispatcher.hpp" #include "ck_tile/dispatcher/json_export.hpp" #include "ck_tile/dispatcher/utils.hpp" diff --git a/dispatcher/python/dispatcher_common.py b/dispatcher/python/dispatcher_common.py index a19ecbdb49..3388e6bf68 100644 --- a/dispatcher/python/dispatcher_common.py +++ b/dispatcher/python/dispatcher_common.py @@ -57,6 +57,22 @@ def get_codegen_dir() -> Path: return get_dispatcher_root() / "codegen" +def detect_gpu_arch(fallback: str = "gfx942") -> str: + """Detect the GPU architecture from rocminfo. Falls back to the given default.""" + import subprocess + + try: + out = subprocess.check_output( + ["rocminfo"], text=True, stderr=subprocess.DEVNULL + ) + for line in out.splitlines(): + if "Name:" in line and "gfx" in line: + return line.split()[-1].strip() + except Exception: + pass + return fallback + + # ============================================================================ # Architecture Filter Data # ============================================================================ diff --git a/dispatcher/python/fmha_utils.py b/dispatcher/python/fmha_utils.py new file mode 100644 index 0000000000..5d3d085496 --- /dev/null +++ b/dispatcher/python/fmha_utils.py @@ -0,0 +1,1842 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +FMHA Dispatcher Python Utilities + +Provides Python wrappers for FMHA dispatcher kernels via ctypes. +Mirrors ctypes_utils.py (GEMM) and grouped_conv_utils.py (Conv). + +Usage: + from fmha_utils import FmhaDispatcherLib, FmhaRunner, FmhaProblem, cpu_attention_fwd + + runner = FmhaRunner.from_prebuilt() + result = runner.run(Q, K, V, problem) +""" + +import ctypes +import json +import os +import subprocess +import sys +from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor +from dataclasses import dataclass +from pathlib import Path +from typing import List, Optional, Tuple + +import numpy as np + + +# ============================================================================= +# Utility helpers +# ============================================================================= + + +try: + from dispatcher_common import detect_gpu_arch, get_dispatcher_root +except ImportError: + # Standalone usage without dispatcher_common on PYTHONPATH + def get_dispatcher_root() -> Path: + return Path(__file__).parent.parent + + def detect_gpu_arch(fallback: str = "gfx950") -> str: + try: + out = subprocess.check_output( + ["rocminfo"], text=True, stderr=subprocess.DEVNULL + ) + for line in out.splitlines(): + if "Name:" in line and "gfx" in line: + return line.split()[-1].strip() + except Exception: + pass + return fallback + + +# ============================================================================= +# Data types +# ============================================================================= + + +@dataclass +class FmhaResult: + success: bool + output: Optional[np.ndarray] = None + time_ms: float = 0.0 + tflops: float = 0.0 + error: str = "" + + +@dataclass +class FmhaProblem: + batch: int = 2 + nhead_q: int = 8 + nhead_k: int = 8 + seqlen_q: int = 128 + seqlen_k: int = 128 + hdim_q: int = 128 + hdim_v: int = 128 + + @property + def scale(self) -> float: + return 1.0 / (self.hdim_q**0.5) + + @property + def num_ops(self) -> int: + sq, sk = self.seqlen_q, self.seqlen_k + return 2 * self.batch * self.nhead_q * sq * sk * (self.hdim_q + self.hdim_v) + + def q_shape(self): + return (self.batch, self.nhead_q, self.seqlen_q, self.hdim_q) + + def k_shape(self): + return (self.batch, self.nhead_k, self.seqlen_k, self.hdim_q) + + def v_shape(self): + return (self.batch, self.nhead_k, self.seqlen_k, self.hdim_v) + + def o_shape(self): + return (self.batch, self.nhead_q, self.seqlen_q, self.hdim_v) + + +@dataclass +class FmhaKernelConfig: + """Complete kernel configuration for FMHA. + + All tile/wave/warp dimensions are explicitly named to match the + GEMM pattern (tile_m, tile_n, tile_k) but extended for FMHA's + two-stage computation (Q*K^T stage 0, Attn*V stage 1). + """ + + # -- Signature: what operation -- + family: str = "fwd" + data_type: str = "fp16" + mode: str = "batch" + vlayout: str = "r" + hdim_q: int = 128 + hdim_v: int = 128 + gfx_arch: str = "gfx950" + + # -- Algorithm: tile shape -- + # Stage 0 (Q * K^T): seqlen_q x seqlen_k x hdim_q + tile_m0: int = 128 # seqlen_q tile + tile_n0: int = 128 # seqlen_k tile + tile_k0: int = 32 # hdim_q tile + # Stage 1 (Attn * V): seqlen_q x hdim_v x seqlen_k + tile_n1: int = 128 # hdim_v tile + tile_k1: int = 32 # seqlen_k tile + tile_k0max: int = 128 # max k0 (alignment) + # BWD extra stages (9-element tile) + tile_bwd6: int = 0 + tile_bwd7: int = 0 + tile_bwd8: int = 0 + + # -- Algorithm: wave config (warps per block) -- + wave_m0: int = 4 + wave_n0: int = 1 + wave_k0: int = 1 + wave_m1: int = 4 + wave_n1: int = 1 + wave_k1: int = 1 + wave_m2: int = 1 + wave_n2: int = 1 + wave_k2: int = 1 + + # -- Algorithm: warp tile (elements per warp) -- + warp_m0: int = 32 + warp_n0: int = 32 + warp_k0: int = 16 + warp_m1: int = 32 + warp_n1: int = 32 + warp_k1: int = 16 + warp_m2: int = 16 + warp_n2: int = 16 + warp_k2: int = 16 + + # -- Algorithm: padding -- + # Values: 0=no pad, 1=pad, 8=pad with 8-byte alignment (BWD-specific) + pad_s: int = 1 + pad_sk: int = 1 + pad_d: int = 1 + pad_dv: int = 1 + + # -- Algorithm: pipeline -- + pipeline: str = "qr_async" + block_per_cu: int = -1 + num_wave_groups: int = 1 + + # -- Signature: features -- + mask: str = "no" + bias: str = "no" + lse: bool = False + dropout: bool = False + qscale: str = "no" + rope: str = "none" + logits: bool = False + paged_kv: bool = False + sink: bool = False + skip_min_seqlen_q: bool = False + page_size: int = 1 + kv_memory_layout: str = "vectorized" + kv_lookup_table: str = "sglang" + deterministic: bool = False + dbias: bool = False + dropout_variant: str = "" # BWD: "no"/"dropout_wg16"/"dropout_wg16_storerandval" + tile_tag: str = "" # extra tile variant discriminator (e.g. "trload", "small") + use_trload: bool = False # BWD dq_dk_dv: use trload pipeline path + + @property + def tile(self) -> Tuple[int, ...]: + base = ( + self.tile_m0, + self.tile_n0, + self.tile_k0, + self.tile_n1, + self.tile_k1, + self.tile_k0max, + ) + if self.family == "bwd_dq_dk_dv" and self.tile_bwd6 > 0: + return base + (self.tile_bwd6, self.tile_bwd7, self.tile_bwd8) + return base + + @property + def wave(self) -> Tuple[int, ...]: + return ( + self.wave_m0, + self.wave_n0, + self.wave_k0, + self.wave_m1, + self.wave_n1, + self.wave_k1, + self.wave_m2, + self.wave_n2, + self.wave_k2, + ) + + @property + def warp(self) -> Tuple[int, ...]: + return ( + self.warp_m0, + self.warp_n0, + self.warp_k0, + self.warp_m1, + self.warp_n1, + self.warp_k1, + self.warp_m2, + self.warp_n2, + self.warp_k2, + ) + + @property + def padding(self) -> Tuple[bool, ...]: + return (self.pad_s, self.pad_sk, self.pad_d, self.pad_dv) + + @property + def name(self) -> str: + s = self.pad_s + k = self.pad_sk + d = self.pad_d + v = self.pad_dv + parts = [ + f"fmha_{self.family}_{self.data_type}", + self.mode, + f"h{self.hdim_q}x{self.hdim_v}" + if self.hdim_q != self.hdim_v + else f"h{self.hdim_q}", + self.pipeline, + f"t{self.tile_m0}x{self.tile_n0}x{self.tile_k0}x{self.tile_n1}x{self.tile_k1}x{self.tile_k0max}" + + (f".{self.tile_tag}" if self.tile_tag else ""), + ] + # Always include warp class for uniform naming + parts.append(f"w{self.warp_m0}x{self.warp_n0}x{self.warp_k0}") + parts.extend( + [ + f"pad{s}{k}{d}{v}", + f"mask={self.mask}", + f"bias={self.bias}", + ] + ) + if self.lse: + parts.append("lse=1") + if self.dropout: + parts.append("drop=1") + if self.logits: + parts.append("logits=1") + if self.sink: + parts.append("sink=1") + if self.skip_min_seqlen_q: + parts.append("skip=1") + if self.qscale != "no": + parts.append(f"qs={self.qscale}") + if self.paged_kv: + parts.append("pkv=1") + if self.rope != "none": + parts.append(f"rope={self.rope}") + if self.page_size != 1: + parts.append(f"ps={self.page_size}") + if self.kv_memory_layout != "vectorized": + parts.append(f"kvl={self.kv_memory_layout}") + if self.kv_lookup_table != "sglang": + parts.append(f"kvt={self.kv_lookup_table}") + if self.deterministic: + parts.append("det=1") + if self.dbias: + parts.append("dbias=1") + if self.dropout_variant and self.dropout_variant != "no": + parts.append(f"drv={self.dropout_variant}") + # Always include block_per_cu for uniform naming + parts.append(f"bpc={self.block_per_cu}") + return "_".join(parts) + + def to_codegen_json(self) -> str: + return json.dumps( + { + "arch": self.gfx_arch, + "signature": { + "family": self.family, + "data_type": self.data_type, + "mode": self.mode, + "vlayout": self.vlayout, + "hdim_q": self.hdim_q, + "hdim_v": self.hdim_v, + "mask": self.mask, + "bias": self.bias, + "lse": self.lse, + "dropout": self.dropout, + "qscale": self.qscale, + "rope": self.rope, + "logits": self.logits, + "paged_kv": self.paged_kv, + "fp8_static_quant": False, + "skip_min_seqlen_q": self.skip_min_seqlen_q, + "sink": self.sink, + "dbias": self.dbias, + "store_randval": "storerandval" in self.dropout_variant, + "deterministic": self.deterministic, + "dropout_variant": self.dropout_variant, + "kv_memory_layout": self.kv_memory_layout, + "kv_lookup_table": self.kv_lookup_table, + "page_size": self.page_size, + }, + "algorithm": { + "pipeline": self.pipeline, + "tile": list(self.tile), + "wave": list(self.wave), + "warp": list(self.warp), + "padding": list(self.padding), + "block_per_cu": self.block_per_cu, + "num_wave_groups": self.num_wave_groups, + "max_splits_log2": 0, + "max_seq_len_q": 0, + "use_trload": self.use_trload, + }, + } + ) + + +# ============================================================================= +# CPU reference +# ============================================================================= + + +def _float32_to_bf16(arr: np.ndarray) -> np.ndarray: + """Convert float32 array to bf16 stored as uint16 (truncate lower 16 bits).""" + return arr.astype(np.float32).view(np.uint32).__rshift__(16).astype(np.uint16) + + +def _bf16_to_float32(arr: np.ndarray) -> np.ndarray: + """Convert bf16 (uint16) array back to float32.""" + return (arr.astype(np.uint32) << 16).view(np.float32) + + +def cpu_attention_fwd( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + scale: float, + mask_type: int = 0, +) -> np.ndarray: + """CPU reference: scaled dot-product attention (supports GQA and causal mask). + + Args: + Q: [batch, nhead_q, seqlen_q, hdim_q] float32 + K: [batch, nhead_k, seqlen_k, hdim_q] float32 + V: [batch, nhead_k, seqlen_k, hdim_v] float32 + mask_type: 0=no mask, 1=causal top-left, 2=causal bottom-right + + Returns: + O: [batch, nhead_q, seqlen_q, hdim_v] float32 + """ + nhead_q = Q.shape[1] + nhead_k = K.shape[1] + if nhead_q != nhead_k: + ratio = nhead_q // nhead_k + K = np.repeat(K, ratio, axis=1) + V = np.repeat(V, ratio, axis=1) + S = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale + if mask_type in (1, 2): + sq, sk = S.shape[-2], S.shape[-1] + row = np.arange(sq).reshape(sq, 1) + col = np.arange(sk).reshape(1, sk) + if mask_type == 1: # top-left causal + causal_mask = col <= row + else: # bottom-right causal + causal_mask = col <= (row + sk - sq) + S = np.where(causal_mask, S, -1e9) + S_max = S.max(axis=-1, keepdims=True) + S_exp = np.exp(S - S_max) + P = S_exp / S_exp.sum(axis=-1, keepdims=True) + return np.matmul(P, V) + + +def cpu_attention_fwd_with_intermediates( + Q: np.ndarray, K: np.ndarray, V: np.ndarray, scale: float +) -> tuple: + """CPU reference forward returning (output, P) for backward use. + + Same as cpu_attention_fwd but also returns the softmax probability matrix P. + """ + nhead_q = Q.shape[1] + nhead_k = K.shape[1] + if nhead_q != nhead_k: + ratio = nhead_q // nhead_k + K = np.repeat(K, ratio, axis=1) + V = np.repeat(V, ratio, axis=1) + S = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale + S_max = S.max(axis=-1, keepdims=True) + S_exp = np.exp(S - S_max) + P = S_exp / S_exp.sum(axis=-1, keepdims=True) + out = np.matmul(P, V) + return out, P + + +def cpu_attention_bwd( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + out: np.ndarray, + dO: np.ndarray, + P: np.ndarray, + scale: float, +) -> tuple: + """CPU reference backward. Returns (dQ, dK, dV). + + Args: + Q, K, V: forward inputs [batch, heads, seq, dim] + out: forward output + dO: gradient of output + P: softmax probabilities from forward + scale: attention scale factor + """ + D = (dO * out).sum(axis=-1, keepdims=True) + dP = np.matmul(dO, V.transpose(0, 1, 3, 2)) + dS = P * (dP - D) + dQ = np.matmul(dS, K) * scale + dK = np.matmul(dS.transpose(0, 1, 3, 2), Q) * scale + dV = np.matmul(P.transpose(0, 1, 3, 2), dO) + return dQ, dK, dV + + +# ============================================================================= +# Low-level ctypes wrapper +# ============================================================================= + + +class FmhaDispatcherLib: + """Wrapper for the FMHA dispatcher shared library (libdispatcher_fmha_lib.so).""" + + SEARCH_PATHS = [ + "build/examples/libdispatcher_fmha_lib.so", + "build/libdispatcher_fmha_lib.so", + "build/lib/libdispatcher_fmha_lib.so", + ] + + def __init__(self, lib: ctypes.CDLL, path: Path): + self._lib = lib + self.path = path + self._setup() + + def _setup(self): + lib = self._lib + lib.fmha_dispatcher_initialize.argtypes = [ctypes.c_char_p] + lib.fmha_dispatcher_initialize.restype = ctypes.c_int + lib.fmha_dispatcher_run_fwd.argtypes = [ + ctypes.c_void_p, # q + ctypes.c_void_p, # k + ctypes.c_void_p, # v + ctypes.c_void_p, # o + ctypes.c_int, # batch + ctypes.c_int, # nhead_q + ctypes.c_int, # nhead_k + ctypes.c_int, # seqlen_q + ctypes.c_int, # seqlen_k + ctypes.c_int, # hdim_q + ctypes.c_int, # hdim_v + ctypes.c_float, # scale + ctypes.c_int, # mask_type + ctypes.c_int, # bias_type + ctypes.c_int, # has_lse + ctypes.c_int, # has_dropout + ctypes.c_int, # traits_hdim_q (0=same as hdim_q) + ctypes.c_int, # traits_hdim_v (0=same as hdim_v) + ctypes.c_int, # is_v_rowmajor (1=row, 0=col) + ctypes.c_int, # perm (1=BHSD, 0=BSHD) + ctypes.c_char_p, # data_type ("fp16", "bf16") + ctypes.c_int, # is_group_mode + ctypes.c_int, # window_left (-1=no window) + ctypes.c_int, # window_right (-1=no window, 0=causal) + ctypes.c_int, # has_logits + ctypes.c_int, # has_sink + ctypes.c_int, # has_skip + ctypes.POINTER(ctypes.c_float), # time_ms_out + ] + lib.fmha_dispatcher_run_fwd.restype = ctypes.c_int + lib.fmha_dispatcher_run_bwd.argtypes = [ + ctypes.c_void_p, # q + ctypes.c_void_p, # k + ctypes.c_void_p, # v + ctypes.c_void_p, # o + ctypes.c_void_p, # lse + ctypes.c_void_p, # do + ctypes.c_void_p, # dq + ctypes.c_void_p, # dk + ctypes.c_void_p, # dv + ctypes.c_int, # batch + ctypes.c_int, # nhead_q + ctypes.c_int, # nhead_k + ctypes.c_int, # seqlen_q + ctypes.c_int, # seqlen_k + ctypes.c_int, # hdim_q + ctypes.c_int, # hdim_v + ctypes.c_float, # scale + ctypes.c_char_p, # data_type_str + ctypes.c_int, # mask_type_int + ctypes.c_int, # bias_type_int + ctypes.c_int, # has_dropout + ctypes.c_int, # has_dbias + ctypes.c_int, # is_deterministic + ctypes.c_int, # is_group_mode + ctypes.c_int, # is_store_randval + ctypes.c_int, # tile_n0 (kN0 for nsplits computation) + ctypes.POINTER(ctypes.c_float), # time_ms_out + ] + lib.fmha_dispatcher_run_bwd.restype = ctypes.c_int + + # Split-KV forward + lib.fmha_dispatcher_run_splitkv.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_float, + ctypes.c_int, # mask_type + ctypes.c_int, # num_splits + ctypes.c_int, # is_v_rowmajor + ctypes.c_char_p, + ctypes.c_int, # has_lse + ctypes.c_int, # is_group_mode + ctypes.c_int, # perm + ctypes.c_int, # has_logits + ctypes.c_int, # bias_type + ctypes.c_int, # has_sink + ctypes.c_int, # paged_kv + ctypes.c_int, # page_block_size + ctypes.c_int, # window_left + ctypes.c_int, # window_right + ctypes.POINTER(ctypes.c_float), + ] + lib.fmha_dispatcher_run_splitkv.restype = ctypes.c_int + + # Paged-KV forward + lib.fmha_dispatcher_run_pagedkv.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_float, + ctypes.c_int, # mask_type + ctypes.c_int, # page_block_size + ctypes.c_int, # is_v_rowmajor + ctypes.c_char_p, + ctypes.c_int, # has_lse + ctypes.c_int, # has_logits + ctypes.c_int, # has_sink + ctypes.c_int, # skip_min_seqlen_q + ctypes.c_int, # bias_type + ctypes.POINTER(ctypes.c_float), + ] + lib.fmha_dispatcher_run_pagedkv.restype = ctypes.c_int + + # Append-KV + lib.fmha_dispatcher_run_appendkv.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, # is_v_rowmajor + ctypes.c_int, # rope_type + ctypes.c_int, # paged_kv + ctypes.c_int, # page_block_size + ctypes.c_char_p, + ctypes.POINTER(ctypes.c_float), + ] + lib.fmha_dispatcher_run_appendkv.restype = ctypes.c_int + + # Batch Prefill + lib.fmha_dispatcher_run_batch_prefill.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_float, + ctypes.c_int, # mask_type + ctypes.c_int, # bias_type + ctypes.c_int, # page_block_size + ctypes.c_int, # kv_layout_int + ctypes.c_int, # kv_lookup_int + ctypes.c_int, # is_v_rowmajor + ctypes.c_char_p, + ctypes.c_int, # has_lse + ctypes.c_int, # has_dropout + ctypes.c_int, # has_logits + ctypes.c_int, # has_sink + ctypes.c_int, # skip_min_seqlen_q + ctypes.POINTER(ctypes.c_float), + ] + lib.fmha_dispatcher_run_batch_prefill.restype = ctypes.c_int + + lib.fmha_dispatcher_kernel_count.argtypes = [] + lib.fmha_dispatcher_kernel_count.restype = ctypes.c_int + lib.fmha_dispatcher_cleanup.argtypes = [] + lib.fmha_dispatcher_cleanup.restype = None + + @classmethod + def find(cls) -> Optional["FmhaDispatcherLib"]: + root = get_dispatcher_root() + for rel in cls.SEARCH_PATHS: + path = root / rel + if path.exists(): + try: + lib = ctypes.CDLL(str(path)) + return cls(lib, path) + except OSError: + continue + return None + + @classmethod + def load(cls, path: str) -> "FmhaDispatcherLib": + lib = ctypes.CDLL(path) + return cls(lib, Path(path)) + + def initialize(self, arch: str = "gfx950") -> bool: + return self._lib.fmha_dispatcher_initialize(arch.encode()) == 0 + + def run_bwd( + self, + q: ctypes.c_void_p, + k: ctypes.c_void_p, + v: ctypes.c_void_p, + o: ctypes.c_void_p, + lse: ctypes.c_void_p, + do_grad: ctypes.c_void_p, + dq: ctypes.c_void_p, + dk: ctypes.c_void_p, + dv: ctypes.c_void_p, + prob: FmhaProblem, + data_type: str = "fp16", + mask_type: int = 0, + bias_type: int = 0, + has_dropout: bool = False, + has_dbias: bool = False, + is_deterministic: bool = False, + is_group_mode: bool = False, + is_store_randval: bool = False, + tile_n0: int = 128, + ) -> Tuple[int, float]: + time_ms = ctypes.c_float(0.0) + rc = self._lib.fmha_dispatcher_run_bwd( + q, + k, + v, + o, + lse, + do_grad, + dq, + dk, + dv, + prob.batch, + prob.nhead_q, + prob.nhead_k, + prob.seqlen_q, + prob.seqlen_k, + prob.hdim_q, + prob.hdim_v, + prob.scale, + data_type.encode(), + ctypes.c_int(mask_type), + ctypes.c_int(bias_type), + ctypes.c_int(int(has_dropout)), + ctypes.c_int(int(has_dbias)), + ctypes.c_int(int(is_deterministic)), + ctypes.c_int(int(is_group_mode)), + ctypes.c_int(int(is_store_randval)), + ctypes.c_int(tile_n0), + ctypes.byref(time_ms), + ) + return rc, time_ms.value + + def kernel_count(self) -> int: + return self._lib.fmha_dispatcher_kernel_count() + + def cleanup(self): + self._lib.fmha_dispatcher_cleanup() + + +# ============================================================================= +# High-level GPU runner (mirrors GpuGroupedConvRunner) +# ============================================================================= + + +class FmhaRunner: + """High-level FMHA runner with NumPy interface and HIP memory management.""" + + HIP_MEMCPY_H2D = 1 + HIP_MEMCPY_D2H = 2 + + def __init__(self, dispatch_lib: FmhaDispatcherLib, arch: str = "gfx950"): + self._lib = dispatch_lib + self._arch = arch + self._hip = None + self._load_hip() + if not dispatch_lib.initialize(arch): + raise RuntimeError("Failed to initialize FMHA dispatcher") + + def _load_hip(self): + for name in ["libamdhip64.so", "libamdhip64.so.6"]: + try: + self._hip = ctypes.CDLL(name) + self._hip.hipMalloc.argtypes = [ + ctypes.POINTER(ctypes.c_void_p), + ctypes.c_size_t, + ] + self._hip.hipMalloc.restype = ctypes.c_int + self._hip.hipFree.argtypes = [ctypes.c_void_p] + self._hip.hipFree.restype = ctypes.c_int + self._hip.hipMemcpy.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_size_t, + ctypes.c_int, + ] + self._hip.hipMemcpy.restype = ctypes.c_int + self._hip.hipMemset.argtypes = [ + ctypes.c_void_p, + ctypes.c_int, + ctypes.c_size_t, + ] + self._hip.hipMemset.restype = ctypes.c_int + return + except OSError: + continue + raise RuntimeError("Could not load libamdhip64.so") + + @classmethod + def from_prebuilt(cls, arch: Optional[str] = None) -> "FmhaRunner": + arch = arch or detect_gpu_arch() + lib = FmhaDispatcherLib.find() + if lib is None: + raise RuntimeError( + "FMHA dispatcher library not found. Build with:\n" + " cd dispatcher/build && cmake .. -DBUILD_DISPATCHER_EXAMPLES=ON && make dispatcher_fmha_lib" + ) + return cls(lib, arch) + + @classmethod + def from_library(cls, path: str, arch: Optional[str] = None) -> "FmhaRunner": + arch = arch or detect_gpu_arch() + return cls(FmhaDispatcherLib.load(path), arch) + + def run( + self, + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + prob: FmhaProblem, + mask_type: int = 0, + bias_type: int = 0, + has_lse: int = 0, + has_dropout: int = 0, + has_logits: int = 0, + has_sink: int = 0, + has_skip: int = 0, + api_family: str = "fwd", + data_type: str = "fp16", + **kwargs, + ) -> "FmhaResult": + """Run FMHA forward on GPU with automatic HIP memory management. + + Args: + Q: [batch, nhead_q, seqlen_q, hdim_q] float16 + K: [batch, nhead_k, seqlen_k, hdim_q] float16 + V: [batch, nhead_k, seqlen_k, hdim_v] float16 + + Returns: + FmhaResult with output array, timing, TFLOPS + """ + # Map CK dtype to numpy dtype for buffer allocation. + # bf16 is stored as uint16 (upper 16 bits of float32). + # fp8 uses uint8 (1 byte per element). + _NP_DTYPE = { + "fp16": np.float16, + "bf16": np.uint16, + "fp32": np.float32, + "fp8bf16": np.uint8, + "fp8fp32": np.uint8, + "bf8": np.uint8, + } + _NP_OUT_DTYPE = { + "fp16": np.float16, + "bf16": np.uint16, + "fp32": np.float32, + "fp8bf16": np.float16, + "fp8fp32": np.float32, + "bf8": np.uint8, + } + in_dt = _NP_DTYPE.get(data_type, np.float16) + out_dt = _NP_OUT_DTYPE.get(data_type, np.float16) + if data_type == "bf16": + Q_c = _float32_to_bf16(np.ascontiguousarray(Q.astype(np.float32))) + K_c = _float32_to_bf16(np.ascontiguousarray(K.astype(np.float32))) + V_c = _float32_to_bf16(np.ascontiguousarray(V.astype(np.float32))) + else: + Q_c = np.ascontiguousarray(Q.astype(in_dt)) + K_c = np.ascontiguousarray(K.astype(in_dt)) + V_c = np.ascontiguousarray(V.astype(in_dt)) + O_c = np.zeros(prob.o_shape(), dtype=out_dt) + + d_q, d_k, d_v, d_o = (ctypes.c_void_p() for _ in range(4)) + + try: + self._hip.hipMalloc(ctypes.byref(d_q), Q_c.nbytes) + self._hip.hipMalloc(ctypes.byref(d_k), K_c.nbytes) + self._hip.hipMalloc(ctypes.byref(d_v), V_c.nbytes) + self._hip.hipMalloc(ctypes.byref(d_o), O_c.nbytes) + + self._hip.hipMemcpy(d_q, Q_c.ctypes.data, Q_c.nbytes, self.HIP_MEMCPY_H2D) + self._hip.hipMemcpy(d_k, K_c.ctypes.data, K_c.nbytes, self.HIP_MEMCPY_H2D) + self._hip.hipMemcpy(d_v, V_c.ctypes.data, V_c.nbytes, self.HIP_MEMCPY_H2D) + self._hip.hipMemset(d_o, 0, O_c.nbytes) + + time_ms = ctypes.c_float(0.0) + lib = self._lib._lib + + is_v_rowmajor = kwargs.get("is_v_rowmajor", 1) + is_group_mode = kwargs.get("is_group_mode", 0) + perm = kwargs.get("perm", 1) + window_left = kwargs.get("window_left", -1) + window_right = kwargs.get("window_right", -1) + num_splits = kwargs.get("num_splits", 4) + page_size = kwargs.get("page_size", 64) + kv_layout = kwargs.get("kv_layout", 0) + kv_lookup = kwargs.get("kv_lookup", 0) + traits_hdim_q = kwargs.get("traits_hdim_q", 0) + traits_hdim_v = kwargs.get("traits_hdim_v", 0) + + if api_family == "splitkv": + paged_kv = kwargs.get("paged_kv", 0) + rc = lib.fmha_dispatcher_run_splitkv( + d_q, + d_k, + d_v, + d_o, + prob.batch, + prob.nhead_q, + prob.nhead_k, + prob.seqlen_q, + prob.seqlen_k, + prob.hdim_q, + prob.hdim_v, + prob.scale, + mask_type, + num_splits, + is_v_rowmajor, + data_type.encode(), + has_lse, + is_group_mode, + perm, + has_logits, + bias_type, + has_sink, + paged_kv, + page_size, + window_left, + window_right, + ctypes.byref(time_ms), + ) + elif api_family == "pagedkv": + rc = lib.fmha_dispatcher_run_pagedkv( + d_q, + d_k, + d_v, + d_o, + prob.batch, + prob.nhead_q, + prob.nhead_k, + prob.seqlen_q, + prob.seqlen_k, + prob.hdim_q, + prob.hdim_v, + prob.scale, + mask_type, + page_size, + is_v_rowmajor, + data_type.encode(), + has_lse, + has_logits, + has_sink, + has_skip, + bias_type, + ctypes.byref(time_ms), + ) + elif api_family == "appendkv": + seqlen_knew = kwargs.get("seqlen_knew", prob.seqlen_k) + rc = lib.fmha_dispatcher_run_appendkv( + Q_c.ctypes.data, + K_c.ctypes.data, + V_c.ctypes.data, + prob.batch, + prob.nhead_q, + prob.nhead_k, + prob.seqlen_q, + seqlen_knew, + prob.hdim_q, + prob.hdim_v, + is_v_rowmajor, + kwargs.get("rope_type", 0), + kwargs.get("paged_kv", 0), + page_size, + data_type.encode(), + ctypes.byref(time_ms), + ) + elif api_family == "batch_prefill": + skip_min_sq = kwargs.get("skip_min_seqlen_q", 0) + rc = lib.fmha_dispatcher_run_batch_prefill( + d_q, + d_k, + d_v, + d_o, + prob.batch, + prob.nhead_q, + prob.nhead_k, + prob.seqlen_q, + prob.seqlen_k, + prob.hdim_q, + prob.hdim_v, + prob.scale, + mask_type, + bias_type, + page_size, + kv_layout, + kv_lookup, + is_v_rowmajor, + data_type.encode(), + has_lse, + has_dropout, + has_logits, + has_sink, + skip_min_sq, + ctypes.byref(time_ms), + ) + else: + rc = lib.fmha_dispatcher_run_fwd( + d_q, + d_k, + d_v, + d_o, + prob.batch, + prob.nhead_q, + prob.nhead_k, + prob.seqlen_q, + prob.seqlen_k, + prob.hdim_q, + prob.hdim_v, + prob.scale, + mask_type, + bias_type, + has_lse, + has_dropout, + traits_hdim_q, + traits_hdim_v, + is_v_rowmajor, + perm, + data_type.encode(), + is_group_mode, + window_left, + window_right, + has_logits, + has_sink, + has_skip, + ctypes.byref(time_ms), + ) + + if rc != 0: + return FmhaResult(success=False, error=f"Kernel failed (rc={rc})") + + self._hip.hipMemcpy(O_c.ctypes.data, d_o, O_c.nbytes, self.HIP_MEMCPY_D2H) + + # Convert bf16 output (uint16) back to float32 for comparison + if data_type == "bf16": + O_c = _bf16_to_float32(O_c) + + # appendkv is a memory op (KV cache copy), not compute -- no TFLOPS + ops = 0 if api_family == "appendkv" else prob.num_ops + tflops = ( + ops / (time_ms.value * 1e-3) / 1e12 + if time_ms.value > 0 and ops > 0 + else 0.0 + ) + return FmhaResult( + success=True, output=O_c, time_ms=time_ms.value, tflops=tflops + ) + + finally: + for d in [d_q, d_k, d_v, d_o]: + if d.value: + self._hip.hipFree(d) + + def run_bwd( + self, + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + out: np.ndarray, + LSE: np.ndarray, + dO: np.ndarray, + prob: FmhaProblem, + data_type: str = "fp16", + mask_type: int = 0, + bias_type: int = 0, + has_dropout: bool = False, + has_dbias: bool = False, + is_deterministic: bool = False, + is_group_mode: bool = False, + is_store_randval: bool = False, + tile_n0: int = 128, + ) -> "FmhaResult": + """Run FMHA backward on GPU with automatic HIP memory management. + + Returns FmhaResult with dQ, dK, dV packed in output as a tuple. + """ + _NP_DTYPE = { + "fp16": np.float16, + "bf16": np.float16, + "fp32": np.float32, + "fp8bf16": np.uint8, + "fp8fp32": np.uint8, + "bf8": np.uint8, + } + in_dt = _NP_DTYPE.get(data_type, np.float16) + Q_c = np.ascontiguousarray(Q.astype(in_dt)) + K_c = np.ascontiguousarray(K.astype(in_dt)) + V_c = np.ascontiguousarray(V.astype(in_dt)) + O_c = np.ascontiguousarray(out.astype(in_dt)) + LSE_c = np.ascontiguousarray(LSE.astype(np.float32)) + dO_c = np.ascontiguousarray(dO.astype(in_dt)) + dQ_c = np.zeros_like(Q_c) + dK_c = np.zeros_like(K_c) + dV_c = np.zeros_like(V_c) + + ptrs = [ctypes.c_void_p() for _ in range(9)] + d_q, d_k, d_v, d_o, d_lse, d_do, d_dq, d_dk, d_dv = ptrs + + try: + for d, arr in zip(ptrs[:6], [Q_c, K_c, V_c, O_c, LSE_c, dO_c]): + self._hip.hipMalloc(ctypes.byref(d), arr.nbytes) + self._hip.hipMemcpy(d, arr.ctypes.data, arr.nbytes, self.HIP_MEMCPY_H2D) + for d, arr in zip(ptrs[6:], [dQ_c, dK_c, dV_c]): + self._hip.hipMalloc(ctypes.byref(d), arr.nbytes) + self._hip.hipMemset(d, 0, arr.nbytes) + + rc, elapsed = self._lib.run_bwd( + d_q, + d_k, + d_v, + d_o, + d_lse, + d_do, + d_dq, + d_dk, + d_dv, + prob, + data_type, + mask_type=mask_type, + bias_type=bias_type, + has_dropout=has_dropout, + has_dbias=has_dbias, + is_deterministic=is_deterministic, + is_group_mode=is_group_mode, + is_store_randval=is_store_randval, + tile_n0=tile_n0, + ) + + if rc != 0: + return FmhaResult(success=False, error=f"BWD kernel failed (rc={rc})") + + for d, arr in zip(ptrs[6:], [dQ_c, dK_c, dV_c]): + self._hip.hipMemcpy(arr.ctypes.data, d, arr.nbytes, self.HIP_MEMCPY_D2H) + + tflops = prob.num_ops / (elapsed * 1e-3) / 1e12 if elapsed > 0 else 0.0 + return FmhaResult( + success=True, + output=(dQ_c, dK_c, dV_c), + time_ms=elapsed, + tflops=tflops, + ) + finally: + for d in ptrs: + if d.value: + self._hip.hipFree(d) + + @property + def kernel_count(self) -> int: + return self._lib.kernel_count() + + @property + def library_path(self) -> str: + return str(self._lib.path) + + def cleanup(self): + self._lib.cleanup() + + +# ============================================================================= +# JIT Build Support (mirrors setup_multiple_gemm_dispatchers) +# ============================================================================= + + +@dataclass +class FmhaSetupResult: + success: bool + config: Optional[FmhaKernelConfig] = None + runner: Optional[FmhaRunner] = None + library_path: str = "" + error: str = "" + build_time_s: float = 0.0 + + +def _build_static_lib(root: Path) -> Optional[Path]: + """Build libck_tile_dispatcher.a via cmake if not already present.""" + build_dir = root / "build" + build_dir.mkdir(parents=True, exist_ok=True) + hipcc = _find_hipcc() + cmake_cmd = ["cmake", str(root), f"-DCMAKE_CXX_COMPILER={hipcc}"] + r = subprocess.run(cmake_cmd, cwd=str(build_dir), capture_output=True, text=True) + if r.returncode != 0: + print( + f"Warning: cmake failed for dispatcher lib: {r.stderr[:200]}", + file=sys.stderr, + ) + return None + make_cmd = ["make", "ck_tile_dispatcher", f"-j{os.cpu_count() or 4}"] + r = subprocess.run(make_cmd, cwd=str(build_dir), capture_output=True, text=True) + if r.returncode != 0: + print( + f"Warning: make failed for dispatcher lib: {r.stderr[:200]}", + file=sys.stderr, + ) + return None + lib_path = build_dir / "libck_tile_dispatcher.a" + return lib_path if lib_path.exists() else None + + +def _find_static_lib() -> Optional[Path]: + root = get_dispatcher_root() + for rel in ["build/libck_tile_dispatcher.a", "build/lib/libck_tile_dispatcher.a"]: + p = root / rel + if p.exists(): + return p + # Auto-build if not found + print(" Building libck_tile_dispatcher.a (first time)...", file=sys.stderr) + return _build_static_lib(root) + + +def _find_hipcc() -> str: + for path in ["/opt/rocm/bin/hipcc", "/usr/bin/hipcc"]: + if os.path.exists(path): + return path + return "hipcc" + + +def fmha_compile_flags(arch: str, hipcc: str = "", family: str = "") -> List[str]: + """Base hipcc flags for compiling FMHA kernels. Shared by JIT and tile engine. + + Source: example/ck_tile/01_fmha/CMakeLists.txt — mirrors CK's own build + flags to ensure parity. Key defines: + - CK_TILE_FMHA_FWD_FAST_EXP2: enables fast exp2 on gfx9 (CDNA) + - CK_TILE_USE_OCP_FP8: uses OCP standard fp8 format + - CK_GFX950_SUPPORT / CK_USE_GFX950: enables gfx950-specific code paths + - CK_USE_XDL: enables MFMA (matrix fused multiply-add) instructions + - CK_TILE_USE_WMMA: 0 for CDNA (uses MFMA instead) + - CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3: BWD bf16 conversion mode + """ + if not hipcc: + hipcc = _find_hipcc() + root = get_dispatcher_root() + flags = [ + hipcc, + "-c", + "-fPIC", + "-O3", + "-DNDEBUG", + f"--offload-arch={arch}", + "-std=c++17", + f"-I{root.parent / 'include'}", + f"-I{root / 'include'}", + f"-I{root.parent}", + "-Wno-undefined-func-template", + "-Wno-float-equal", + "-fgpu-flush-denormals-to-zero", + "-fno-offload-uniform-block", + "-mllvm", + "--lsr-drop-solution=1", + "-mllvm", + "-enable-post-misched=0", + "-mllvm", + "-amdgpu-early-inline-all=true", + "-mllvm", + "-amdgpu-function-calls=false", + ] + if arch.startswith("gfx9"): + flags.append("-DCK_TILE_FMHA_FWD_FAST_EXP2=1") + flags.append("-DCK_TILE_USE_OCP_FP8") + flags.append("-DCK_GFX950_SUPPORT") + flags.append("-DCK_USE_GFX950") + flags.append("-DCK_USE_GFX94") + flags.append("-DCK_USE_XDL") + flags.append("-DCK_TILE_USE_WMMA=0") + else: + flags.append("-DCK_TILE_FMHA_FWD_FAST_EXP2=0") + + # API enablement flags (match CMakeLists.txt conditional defines) + flags.append("-DCK_TILE_FMHA_FWD_SPLITKV_API=1") + flags.append("-DCK_TILE_FMHA_FWD_APPENDKV_API=1") + flags.append("-DCK_TILE_FMHA_FWD_PAGEDKV_API=1") + + # BWD-specific flags + if family.startswith("bwd"): + flags.append("-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3") + + return flags + + +def _make_splitkv_combine_config(splitkv_cfg: FmhaKernelConfig) -> FmhaKernelConfig: + """Create a matching fwd_splitkv_combine config for a fwd_splitkv config. + + Source: fmha_fwd.py splitkv_combine tile — fixed (32, hdim_v, 32, 32) tile. + The combine_bn1=32 comes from specs.py load_arch_specs() splitkv_combine dict. + The combine kernel merges partial results from the split stage into the + final output. Must be in the same .so as the split kernel for the + 2-stage splitkv pipeline. + """ + import copy + + comb = copy.copy(splitkv_cfg) + comb.family = "fwd_splitkv_combine" + comb.pipeline = "splitkv_combine" + hv = splitkv_cfg.hdim_v + comb.hdim_q = hv + comb.hdim_v = hv + comb.tile_m0 = 32 + comb.tile_n0 = hv + comb.tile_k0 = 32 + comb.tile_n1 = 32 + comb.tile_k1 = 0 + comb.tile_k0max = 0 + comb.pad_s = 1 if splitkv_cfg.mode == "group" else 0 + comb.pad_sk = 1 + comb.pad_d = 1 + comb.pad_dv = 1 + comb.lse = True + # Combine doesn't use mask/bias/etc., but the dispatcher's supports() check + # matches the combine kernel's signature against the problem traits. + # Keep them from the split config so the signatures match. + comb.dropout = False + comb.skip_min_seqlen_q = False + comb.qscale = "no" + comb.rope = "none" + return comb + + +def _make_bwd_dot_do_o_config(dq_cfg: FmhaKernelConfig) -> FmhaKernelConfig: + """Create a matching bwd_dot_do_o config for a bwd_dq_dk_dv config. + + Source: fmha_bwd.py FmhaBwdDotDoOTileSize — fixed tile (64, max(hv,128), 32). + Warp tile (32,32,16) with 4 waves in M = standard fp16/bf16 MFMA config. + The dot_do_o kernel computes d = rowsum(O * dO) and must be in the same + .so as the dq_dk_dv kernel for the 2-stage BWD pipeline. + """ + import copy + + dot = copy.copy(dq_cfg) + dot.family = "bwd_dot_do_o" + dot.pipeline = "qr" + hq, hv = dq_cfg.hdim_q, dq_cfg.hdim_v + dot.tile_m0 = 64 + dot.tile_n0 = max(hv, 128) + dot.tile_k0 = 32 + dot.tile_n1 = max(hv, 128) + dot.tile_k1 = 32 + dot.tile_k0max = max(hq, 128) + dot.wave_m0 = 4 + dot.wave_n0 = 1 + dot.wave_k0 = 1 + dot.wave_m1 = 4 + dot.wave_n1 = 1 + dot.wave_k1 = 1 + dot.warp_m0 = 32 + dot.warp_n0 = 32 + dot.warp_k0 = 16 + dot.warp_m1 = 32 + dot.warp_n1 = 32 + dot.warp_k1 = 16 + dot.use_trload = False + # dot_do_o uses all-padded for maximum compatibility + dot.pad_s = 1 + dot.pad_sk = 1 + dot.pad_d = 1 + dot.pad_dv = 1 + # BWD traits don't have logits/sink/skip/lse/paged_kv -- from_invocation + # defaults them to false/0. The dot_do_o signature must match these defaults. + dot.logits = False + dot.sink = False + dot.skip_min_seqlen_q = False + dot.lse = False + dot.paged_kv = False + dot.qscale = "no" + dot.rope = "no" + # dot_do_o must match the problem's is_store_randval (from traits); + # keep dropout_variant as-is so store_randval matches + return dot + + +def setup_fmha_dispatcher( + config: FmhaKernelConfig, + output_dir: Optional[Path] = None, + verbose: bool = False, +) -> FmhaSetupResult: + """JIT-compile a single FMHA kernel and return a runner. + + Cached: if the .so already exists, loads it directly (~1ms). + Fresh build: codegen → parallel compile (kernel + ctypes) → link. + """ + import time + + t0 = time.perf_counter() + + root = get_dispatcher_root() + codegen_dir = root / "codegen" + ctypes_src = root / "bindings" / "ctypes" / "fmha_ctypes_lib.cpp" + static_lib = _find_static_lib() + hipcc = _find_hipcc() + + if output_dir is None: + output_dir = root / "build" / "examples" / f"fmha_jit_{config.name}" + output_dir.mkdir(parents=True, exist_ok=True) + + lib_name = f"libdispatcher_fmha_{config.name}.so" + lib_path = output_dir / lib_name + + # Cache hit: .so already exists, just load + if lib_path.exists(): + try: + runner = FmhaRunner.from_library(str(lib_path), config.gfx_arch) + return FmhaSetupResult( + success=True, + config=config, + runner=runner, + library_path=str(lib_path), + build_time_s=time.perf_counter() - t0, + ) + except Exception: + pass # stale .so, rebuild + + if not static_lib: + return FmhaSetupResult( + success=False, config=config, error="libck_tile_dispatcher.a not found" + ) + if not ctypes_src.exists(): + return FmhaSetupResult( + success=False, config=config, error="fmha_ctypes_lib.cpp not found" + ) + + # Step 1: Codegen + # BWD dq_dk_dv needs a matching dot_do_o kernel in the same .so + # BWD dq_dk_dv needs matching dot_do_o kernel for the 2-stage pipeline + if config.family == "bwd_dq_dk_dv": + dot_cfg = _make_bwd_dot_do_o_config(config) + config_json_str = json.dumps( + [ + json.loads(dot_cfg.to_codegen_json()), + json.loads(config.to_codegen_json()), + ] + ) + else: + config_json_str = config.to_codegen_json() + gen_cmd = [ + sys.executable, + str(codegen_dir / "fmha" / "generate_fallback.py"), + "--output-dir", + str(output_dir), + "--gpu-target", + config.gfx_arch, + "--config-json", + config_json_str, + ] + r = subprocess.run(gen_cmd, capture_output=True, text=True, cwd=str(codegen_dir)) + if r.returncode != 0: + return FmhaSetupResult( + success=False, config=config, error=f"Codegen failed: {r.stderr[:500]}" + ) + + dispatch_header = output_dir / "fmha_python_dispatch.hpp" + if not dispatch_header.exists(): + return FmhaSetupResult( + success=False, config=config, error="Dispatch header not generated" + ) + + # Step 2: Compile kernel .cpp AND ctypes in parallel + kernel_cpps = list(output_dir.glob("fmha_*.cpp")) + base_flags = fmha_compile_flags(config.gfx_arch, hipcc, family=config.family) + + compile_jobs = [] + for cpp in kernel_cpps: + obj = cpp.with_suffix(".o") + compile_jobs.append((base_flags + [str(cpp), "-o", str(obj)], obj, "kernel")) + + ctypes_obj = output_dir / "fmha_ctypes_lib.o" + ctypes_cmd = base_flags + [ + f"-I{output_dir}", + f"-I{output_dir / 'dispatcher_wrappers'}", + f"-include{dispatch_header}", + f'-DGFX_ARCH="{config.gfx_arch}"', + str(ctypes_src), + "-o", + str(ctypes_obj), + ] + compile_jobs.append((ctypes_cmd, ctypes_obj, "ctypes")) + + def _run_compile(job): + cmd, obj, label = job + if obj.exists(): + return (True, obj, label, "") + r = subprocess.run(cmd, capture_output=True, text=True) + return (r.returncode == 0, obj, label, r.stderr[:500]) + + with ThreadPoolExecutor(max_workers=len(compile_jobs)) as pool: + results = list(pool.map(_run_compile, compile_jobs)) + + kernel_objs = [] + for ok, obj, label, err in results: + if not ok: + return FmhaSetupResult( + success=False, + config=config, + error=f"{label} compile failed: {err}", + ) + if label == "kernel": + kernel_objs.append(str(obj)) + + # Step 3: Link + link_cmd = [ + hipcc, + "-shared", + "-fPIC", + str(ctypes_obj), + *kernel_objs, + str(static_lib), + "-o", + str(lib_path), + ] + r = subprocess.run(link_cmd, capture_output=True, text=True) + if r.returncode != 0: + return FmhaSetupResult( + success=False, config=config, error=f"Link failed: {r.stderr[:500]}" + ) + + # Step 4: Load + try: + runner = FmhaRunner.from_library(str(lib_path), config.gfx_arch) + except Exception as e: + return FmhaSetupResult(success=False, config=config, error=f"Load failed: {e}") + + elapsed = time.perf_counter() - t0 + return FmhaSetupResult( + success=True, + config=config, + runner=runner, + library_path=str(lib_path), + build_time_s=elapsed, + ) + + +def _run_compile_job(job): + """Module-level compile worker -- no threads, uses file-based stderr.""" + cmd, obj_str, name, label = job + if os.path.exists(obj_str): + return (name, True, "") + err_path = obj_str + ".err" + with open(err_path, "w") as ef: + rc = subprocess.call(cmd, stdout=subprocess.DEVNULL, stderr=ef) + if rc != 0: + try: + err = open(err_path).read()[:200] + except Exception: + err = f"rc={rc}" + return (name, False, err) + try: + os.unlink(err_path) + except OSError: + pass + return (name, True, "") + + +def setup_multiple_fmha_dispatchers( + configs: List[FmhaKernelConfig], + output_dir: Optional[Path] = None, + verbose: bool = False, + max_workers: Optional[int] = None, + executor=None, + progress_callback=None, +) -> List[FmhaSetupResult]: + """3-stage pipelined JIT: codegen(parallel) -> compile(parallel) -> link+load(parallel). + + Faster than calling setup_fmha_dispatcher() per-kernel because all hipcc + compile jobs (kernel + ctypes from ALL kernels) share one thread pool. + """ + if not configs: + return [] + + root = get_dispatcher_root() + codegen_dir = root / "codegen" + ctypes_src = root / "bindings" / "ctypes" / "fmha_ctypes_lib.cpp" + static_lib = _find_static_lib() + hipcc = _find_hipcc() + arch = configs[0].gfx_arch + + if output_dir is None: + output_dir = root / "build" / "examples" + + results: dict[str, FmhaSetupResult] = {} + + # --- Stage 1: Codegen (sequential, skip cached) --- + def _codegen(cfg): + out = output_dir / f"fmha_jit_{cfg.name}" + lib_path = out / f"libdispatcher_fmha_{cfg.name}.so" + # Fast path: .so exists, register result and skip + if lib_path.exists(): + results[cfg.name] = FmhaSetupResult( + success=True, config=cfg, library_path=str(lib_path) + ) + return (cfg.name, cfg, out, True) + # Fast path: previous codegen already failed (no .hpp generated) + if out.exists() and not (out / "fmha_python_dispatch.hpp").exists(): + err_file = out / "_codegen_err.txt" + if err_file.exists(): + results[cfg.name] = FmhaSetupResult( + success=False, config=cfg, error="Codegen failed (cached)" + ) + return (cfg.name, cfg, out, False) + out.mkdir(parents=True, exist_ok=True) + # Check if codegen was already done (has .hpp but no .so yet) + if (out / "fmha_python_dispatch.hpp").exists(): + return (cfg.name, cfg, out, True) + if cfg.family == "bwd_dq_dk_dv": + dot = _make_bwd_dot_do_o_config(cfg) + config_json_str = json.dumps( + [ + json.loads(dot.to_codegen_json()), + json.loads(cfg.to_codegen_json()), + ] + ) + elif cfg.family == "fwd_splitkv": + comb = _make_splitkv_combine_config(cfg) + config_json_str = json.dumps( + [ + json.loads(cfg.to_codegen_json()), + json.loads(comb.to_codegen_json()), + ] + ) + else: + config_json_str = cfg.to_codegen_json() + err_file = out / "_codegen_err.txt" + with open(err_file, "w") as ef: + rc = subprocess.call( + [ + sys.executable, + str(codegen_dir / "fmha" / "generate_fallback.py"), + "--output-dir", + str(out), + "--gpu-target", + cfg.gfx_arch, + "--config-json", + config_json_str, + ], + stdout=subprocess.DEVNULL, + stderr=ef, + cwd=str(codegen_dir), + ) + ok = rc == 0 and (out / "fmha_python_dispatch.hpp").exists() + if not ok: + err_msg = err_file.read_text()[:200] if err_file.exists() else "unknown" + results[cfg.name] = FmhaSetupResult( + success=False, config=cfg, error=f"Codegen failed: {err_msg}" + ) + return (cfg.name, cfg, out, ok) + + codegen_results = [] + for i, cfg in enumerate(configs): + codegen_results.append(_codegen(cfg)) + if progress_callback: + progress_callback("codegen", i + 1, len(configs)) + + # --- Stage 2: Collect ALL compile jobs, run in one pool --- + # Use bwd family flag to get the superset of all flags (includes BWD-specific defines) + base_flags = fmha_compile_flags(arch, hipcc, family="bwd") + compile_jobs = [] # (cmd, obj_path, kernel_name, label) + + config_dirs: dict[str, tuple[FmhaKernelConfig, Path]] = {} + for name, cfg, out, ok in codegen_results: + if not ok or name in results: + continue + config_dirs[name] = (cfg, out) + for cpp in out.glob("fmha_*.cpp"): + obj = cpp.with_suffix(".o") + if not obj.exists(): + compile_jobs.append( + (base_flags + [str(cpp), "-o", str(obj)], str(obj), name, "kernel") + ) + ctypes_obj = out / "fmha_ctypes_lib.o" + if not ctypes_obj.exists(): + dispatch = out / "fmha_python_dispatch.hpp" + compile_jobs.append( + ( + base_flags + + [ + f"-I{out}", + f"-I{out / 'dispatcher_wrappers'}", + f"-include{dispatch}", + f'-DGFX_ARCH="{arch}"', + str(ctypes_src), + "-o", + str(ctypes_obj), + ], + str(ctypes_obj), + name, + "ctypes", + ) + ) + + failed_names: set = set() + + if compile_jobs: + _own_pool = None + _pool = executor + if _pool is None: + workers = max_workers or min(len(compile_jobs), os.cpu_count() or 4) + _own_pool = ProcessPoolExecutor(max_workers=workers) + _pool = _own_pool + try: + done_count = 0 + total_jobs = len(compile_jobs) + for name, ok, err in _pool.map(_run_compile_job, compile_jobs): + done_count += 1 + if progress_callback: + progress_callback("compile", done_count, total_jobs) + if not ok: + failed_names.add(name) + if name not in results: + cfg, _ = config_dirs[name] + results[name] = FmhaSetupResult( + success=False, config=cfg, error=f"Compile: {err}" + ) + finally: + if _own_pool is not None: + _own_pool.shutdown(wait=True) + + # --- Stage 3: Link (no GPU access -- runner loading deferred to caller) --- + def _link(item): + name, (cfg, out) = item + if name in failed_names or name in results: + return + objs = list(out.glob("*.o")) + lib_path = out / f"libdispatcher_fmha_{name}.so" + if not lib_path.exists(): + r = subprocess.run( + [ + hipcc, + "-shared", + "-fPIC", + *[str(o) for o in objs], + str(static_lib), + "-o", + str(lib_path), + ], + capture_output=True, + text=True, + ) + if r.returncode != 0: + results[name] = FmhaSetupResult( + success=False, config=cfg, error=f"Link: {r.stderr[:200]}" + ) + return + results[name] = FmhaSetupResult( + success=True, config=cfg, library_path=str(lib_path) + ) + + for item in config_dirs.items(): + _link(item) + + # Return in original order + return [ + results.get(c.name, FmhaSetupResult(success=False, config=c, error="skipped")) + for c in configs + ] + + +# ============================================================================= +# Registry (mirrors ctypes_utils.Registry) +# ============================================================================= + + +class FmhaRegistry: + """Kernel registry with parallel JIT build support.""" + + def __init__(self, name: str = "fmha"): + self._name = name + self._kernels: List[FmhaKernelConfig] = [] + + def register_kernel(self, config: FmhaKernelConfig): + self._kernels.append(config) + + def __len__(self): + return len(self._kernels) + + def build( + self, + verbose: bool = False, + max_workers: Optional[int] = None, + ) -> List[FmhaSetupResult]: + return setup_multiple_fmha_dispatchers( + self._kernels, + verbose=verbose, + max_workers=max_workers, + ) + + +# ============================================================================= +# Validator (mirrors ctypes_utils.Validator) +# ============================================================================= + + +class FmhaValidator: + """Validates FMHA GPU output against a reference. + + Usage: + validator = FmhaValidator(rtol=1e-2, atol=1e-2) + ok, max_abs, max_rel = validator.check(gpu_output, cpu_reference) + """ + + def __init__(self, rtol: float = 1e-2, atol: float = 1e-2): + self.rtol = rtol + self.atol = atol + + def check( + self, output: np.ndarray, reference: np.ndarray + ) -> Tuple[bool, float, float]: + """Check output against reference. + + Returns: + (is_valid, max_abs_error, max_rel_error) + """ + out_f32 = output.astype(np.float32) + ref_f32 = reference.astype(np.float32) + diff = np.abs(out_f32 - ref_f32) + max_abs = float(diff.max()) + max_rel = float((diff / (np.abs(ref_f32) + 1e-6)).max()) + ok = bool(np.allclose(out_f32, ref_f32, atol=self.atol, rtol=self.rtol)) + return ok, max_abs, max_rel + + +# ============================================================================= +# KernelSpec + spec_to_config (mirrors ctypes_utils.KernelSpec) +# ============================================================================= + + +@dataclass +class FmhaKernelSpec: + """High-level kernel specification for easy declaration. + + Mirrors GEMM's KernelSpec: specify name + key dimensions, get a + full FmhaKernelConfig via spec_to_config(). + """ + + name: str + hdim: int = 128 + pipeline: str = "qr_async" + # Stage 0 tile (Q*K^T) + tile_m0: int = 128 + tile_n0: int = 128 + tile_k0: int = 32 + + +def spec_to_config( + spec: FmhaKernelSpec, dtype: str = "fp16", arch: str = "gfx950" +) -> FmhaKernelConfig: + """Convert a high-level FmhaKernelSpec to a full FmhaKernelConfig.""" + hdim = spec.hdim + return FmhaKernelConfig( + data_type=dtype, + hdim_q=hdim, + hdim_v=hdim, + pipeline=spec.pipeline, + tile_m0=spec.tile_m0, + tile_n0=spec.tile_n0, + tile_k0=spec.tile_k0, + tile_n1=hdim, + tile_k1=spec.tile_k0, + tile_k0max=hdim, + gfx_arch=arch, + ) + + +# ============================================================================= +# Split-K heuristic (from fmhaarch.md Section 9.5) +# ============================================================================= diff --git a/dispatcher/scripts/example_kernel_builder.py b/dispatcher/scripts/example_kernel_builder.py index 20952cd91f..86336b8fa1 100755 --- a/dispatcher/scripts/example_kernel_builder.py +++ b/dispatcher/scripts/example_kernel_builder.py @@ -11,6 +11,7 @@ configuration parameters, and generates appropriate kernels. """ import argparse +import json import os import re import shutil @@ -156,6 +157,230 @@ def parse_conv_declarations(content: str) -> List[Dict]: return kernels +def parse_fmha_declarations(content: str) -> List[Dict]: + """Parse DECL_FMHA_KERNEL_SET declarations into config-json-ready dicts.""" + kernels = [] + + def parse_bool(value: str) -> bool: + return value.strip().lower() == "true" + + def parse_int_list(match_text: str) -> List[int]: + return [int(v.strip()) for v in match_text.split(",") if v.strip()] + + for match in re.finditer(r"DECL_FMHA_KERNEL_SET\s*\(", content): + body = extract_balanced_parens(content, match.end() - 1) + if not body: + continue + + for add_match in re.finditer(r"\.add\s*\(", body): + add_body = extract_balanced_parens(body, add_match.end() - 1) + if not add_body: + continue + + sig = { + "family": "fwd", + "data_type": "fp16", + "mode": "batch", + "vlayout": "r", + "hdim_q": 128, + "hdim_v": 128, + "mask": "no", + "bias": "no", + "lse": False, + "dropout": False, + "qscale": "no", + "rope": "none", + "logits": False, + "paged_kv": False, + "fp8_static_quant": False, + "skip_min_seqlen_q": False, + "sink": False, + "dbias": False, + "store_randval": False, + "deterministic": False, + "kv_memory_layout": "vectorized", + "kv_lookup_table": "sglang", + "page_size": 1, + } + profile = None + receipt = None + alg = { + "pipeline": "qr", + "tile": [128, 64, 32, 128, 32, 128], + "wave": [2, 2, 1, 2, 2, 1, 1, 1, 1], + "warp": [32, 32, 16, 32, 32, 16, 16, 16, 16], + "padding": [True, True, True, True], + "use_trload": False, + "hdim_q_alignment": 128, + "hdim_v_alignment": 128, + "block_per_cu": 1, + "num_wave_groups": 1, + "max_splits_log2": 0, + "max_seq_len_q": 0, + "selection_rank": 0, + "constraint_tag": "", + } + + if m := re.search(r'\.family\s*\(\s*"([^"]+)"\s*\)', add_body): + sig["family"] = m.group(1) + if m := re.search(r'\.dtype\s*\(\s*"([^"]+)"\s*\)', add_body): + sig["data_type"] = m.group(1) + if m := re.search(r'\.mode\s*\(\s*"([^"]+)"\s*\)', add_body): + sig["mode"] = m.group(1) + if m := re.search(r'\.vlayout\s*\(\s*"([^"]+)"\s*\)', add_body): + sig["vlayout"] = m.group(1) + if m := re.search(r"\.hdim\s*\(\s*(\d+)\s*(?:,\s*(\d+)\s*)?\)", add_body): + sig["hdim_q"] = int(m.group(1)) + sig["hdim_v"] = int(m.group(2)) if m.group(2) else int(m.group(1)) + if m := re.search(r'\.mask\s*\(\s*"([^"]+)"\s*\)', add_body): + sig["mask"] = m.group(1) + if m := re.search(r'\.bias\s*\(\s*"([^"]+)"\s*\)', add_body): + sig["bias"] = m.group(1) + if m := re.search(r"\.lse\s*\(\s*(true|false)\s*\)", add_body, re.I): + sig["lse"] = parse_bool(m.group(1)) + if m := re.search(r"\.dropout\s*\(\s*(true|false)\s*\)", add_body, re.I): + sig["dropout"] = parse_bool(m.group(1)) + if m := re.search(r'\.qscale\s*\(\s*"([^"]+)"\s*\)', add_body): + sig["qscale"] = m.group(1) + if m := re.search(r'\.rope\s*\(\s*"([^"]+)"\s*\)', add_body): + sig["rope"] = m.group(1) + if m := re.search(r"\.logits\s*\(\s*(true|false)\s*\)", add_body, re.I): + sig["logits"] = parse_bool(m.group(1)) + if m := re.search(r"\.paged_kv\s*\(\s*(true|false)\s*\)", add_body, re.I): + sig["paged_kv"] = parse_bool(m.group(1)) + if m := re.search( + r"\.fp8_static_quant\s*\(\s*(true|false)\s*\)", add_body, re.I + ): + sig["fp8_static_quant"] = parse_bool(m.group(1)) + if m := re.search(r"\.skip\s*\(\s*(true|false)\s*\)", add_body, re.I): + sig["skip_min_seqlen_q"] = parse_bool(m.group(1)) + if m := re.search(r"\.sink\s*\(\s*(true|false)\s*\)", add_body, re.I): + sig["sink"] = parse_bool(m.group(1)) + if m := re.search(r"\.dbias\s*\(\s*(true|false)\s*\)", add_body, re.I): + sig["dbias"] = parse_bool(m.group(1)) + if m := re.search( + r"\.store_randval\s*\(\s*(true|false)\s*\)", add_body, re.I + ): + sig["store_randval"] = parse_bool(m.group(1)) + if m := re.search( + r"\.deterministic\s*\(\s*(true|false)\s*\)", add_body, re.I + ): + sig["deterministic"] = parse_bool(m.group(1)) + if m := re.search( + r'\.kv_cache\s*\(\s*"([^"]+)"\s*,\s*"([^"]+)"\s*(?:,\s*(\d+)\s*)?\)', + add_body, + ): + sig["kv_memory_layout"] = m.group(1) + sig["kv_lookup_table"] = m.group(2) + sig["page_size"] = int(m.group(3)) if m.group(3) else 1 + if m := re.search(r'\.profile\s*\(\s*"([^"]+)"\s*\)', add_body): + profile = m.group(1) + if m := re.search(r"\.receipt\s*\(\s*(\d+)\s*\)", add_body): + receipt = int(m.group(1)) + + # Tile: bulk .tile(m0,n0,k0,n1,k1,k0max) or named .tile_m0(v)... + if m := re.search( + r"\.tile\s*\(\s*([0-9,\s]+)\)", + add_body, + ): + values = parse_int_list(m.group(1)) + if len(values) == 6: + alg["tile"] = values + for field_idx, field_name in enumerate( + ["tile_m0", "tile_n0", "tile_k0", "tile_n1", "tile_k1", "tile_k0max"] + ): + if m := re.search(rf"\.{field_name}\s*\(\s*(\d+)\s*\)", add_body): + alg["tile"][field_idx] = int(m.group(1)) + + # Wave: bulk .wave(m0,n0,k0,...) or named .wave_m0(v)... + if m := re.search(r"\.wave\s*\(\s*([0-9,\s]+)\)", add_body): + values = parse_int_list(m.group(1)) + if len(values) == 3: + values += [2, 2, 1, 1, 1, 1] + elif len(values) == 6: + values += [1, 1, 1] + if len(values) == 9: + alg["wave"] = values + for field_idx, field_name in enumerate( + [ + "wave_m0", + "wave_n0", + "wave_k0", + "wave_m1", + "wave_n1", + "wave_k1", + "wave_m2", + "wave_n2", + "wave_k2", + ] + ): + if m := re.search(rf"\.{field_name}\s*\(\s*(\d+)\s*\)", add_body): + alg["wave"][field_idx] = int(m.group(1)) + + # Warp: bulk .warp(m0,n0,k0,...) or named .warp_m0(v)... + if m := re.search(r"\.warp\s*\(\s*([0-9,\s]+)\)", add_body): + values = parse_int_list(m.group(1)) + if len(values) == 3: + values += [32, 32, 16, 16, 16, 16] + elif len(values) == 6: + values += [16, 16, 16] + if len(values) == 9: + alg["warp"] = values + for field_idx, field_name in enumerate( + [ + "warp_m0", + "warp_n0", + "warp_k0", + "warp_m1", + "warp_n1", + "warp_k1", + "warp_m2", + "warp_n2", + "warp_k2", + ] + ): + if m := re.search(rf"\.{field_name}\s*\(\s*(\d+)\s*\)", add_body): + alg["warp"][field_idx] = int(m.group(1)) + if m := re.search(r'\.pipeline\s*\(\s*"([^"]+)"\s*\)', add_body): + alg["pipeline"] = m.group(1) + if m := re.search( + r"\.padding\s*\(\s*(true|false)\s*,\s*(true|false)\s*,\s*(true|false)\s*,\s*(true|false)\s*\)", + add_body, + re.I, + ): + alg["padding"] = [parse_bool(m.group(i)) for i in range(1, 5)] + if m := re.search(r"\.trload\s*\(\s*(true|false)\s*\)", add_body, re.I): + alg["use_trload"] = parse_bool(m.group(1)) + if m := re.search(r"\.alignments\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)", add_body): + alg["hdim_q_alignment"] = int(m.group(1)) + alg["hdim_v_alignment"] = int(m.group(2)) + if m := re.search(r"\.block_per_cu\s*\(\s*(\d+)\s*\)", add_body): + alg["block_per_cu"] = int(m.group(1)) + if m := re.search(r"\.num_wave_groups\s*\(\s*(\d+)\s*\)", add_body): + alg["num_wave_groups"] = int(m.group(1)) + if m := re.search(r"\.max_splits_log2\s*\(\s*(\d+)\s*\)", add_body): + alg["max_splits_log2"] = int(m.group(1)) + if m := re.search(r"\.max_seq_len_q\s*\(\s*(\d+)\s*\)", add_body): + alg["max_seq_len_q"] = int(m.group(1)) + if m := re.search(r"\.selection_rank\s*\(\s*(\d+)\s*\)", add_body): + alg["selection_rank"] = int(m.group(1)) + if m := re.search(r'\.constraint\s*\(\s*"([^"]+)"\s*\)', add_body): + alg["constraint_tag"] = m.group(1) + + arch = "gfx942" + if m := re.search(r'"(gfx\d+)"', add_body): + arch = m.group(1) + + entry = {"arch": arch, "signature": sig, "algorithm": alg} + if profile is not None: + entry["profile"] = profile + if receipt is not None: + entry["receipt"] = receipt + kernels.append(entry) + + return kernels + + def auto_fill_conv_defaults(kernel: Dict) -> Dict: """Auto-fill missing conv parameters with sensible defaults (autofill + autocorrect). @@ -619,7 +844,12 @@ def strip_cpp_strings_and_comments(content: str) -> str: n = len(content) # Patterns that indicate a string is problematic and should be stripped - problematic_patterns = ["DECL_KERNEL_SET", "DECL_GROUPED_CONV_KERNEL_SET", ".add("] + problematic_patterns = [ + "DECL_KERNEL_SET", + "DECL_GROUPED_CONV_KERNEL_SET", + "DECL_FMHA_KERNEL_SET", + ".add(", + ] while i < n: # Check for raw string literal: R"delimiter(...)delimiter" @@ -697,7 +927,9 @@ def detect_and_parse(source_path: Path) -> Tuple[str, List[Dict]]: content = source_path.read_text() content = strip_cpp_strings_and_comments(content) - if "DECL_GROUPED_CONV_KERNEL_SET" in content: + if "DECL_FMHA_KERNEL_SET" in content: + return "fmha", parse_fmha_declarations(content) + elif "DECL_GROUPED_CONV_KERNEL_SET" in content: return "conv", parse_conv_declarations(content) elif "DECL_KERNEL_SET" in content: return "gemm", parse_gemm_declarations(content) @@ -1084,6 +1316,21 @@ def generate_conv_registration( return "\n".join(lines) +def generate_fmha_registration(wrapper_headers: List[Path], source_stem: str) -> str: + """Generate FMHA registration code using dispatcher wrapper factories.""" + if not wrapper_headers: + return " // No FMHA kernels to register" + + lines = [" (void)arch;", ""] + for header in sorted(wrapper_headers): + stem = header.stem.replace("dispatcher_wrapper_", "") + lines.append(f" // Register FMHA kernel: {stem}") + lines.append( + f" registry.register_kernel(ck_tile::dispatcher::generated::make_{stem}(arch));" + ) + return "\n".join(lines) + + def _build_conv_codegen_cmd( idx: int, k: Dict, codegen_dir: Path, output_dir: Path ) -> Tuple[int, List[str], str]: @@ -1161,6 +1408,87 @@ def _run_conv_codegen(args: Tuple) -> Tuple[int, bool, str]: return (idx, True, "") +def _build_fmha_codegen_cmd( + idx: int, k: Dict, codegen_dir: Path, output_dir: Path, gpu_target: str +) -> Tuple[int, List[str], str]: + payload = { + "arch": k.get("arch", gpu_target), + "signature": k["signature"], + "algorithm": k["algorithm"], + } + if k.get("profile") is not None: + payload["profile"] = k["profile"] + if k.get("receipt") is not None: + payload["receipt"] = k["receipt"] + + config_json = json.dumps(payload) + cmd = [ + sys.executable, + str(codegen_dir / "fmha" / "codegen.py"), + "--output-dir", + str(output_dir), + "--gpu-target", + gpu_target, + "--config-json", + config_json, + ] + return (idx, cmd, str(codegen_dir)) + + +def _run_fmha_codegen(args: Tuple) -> Tuple[int, bool, str]: + idx, cmd, cwd = args + result = subprocess.run(cmd, capture_output=True, text=True, cwd=cwd) + if result.returncode != 0: + return (idx, False, result.stderr[:400] or result.stdout[:400]) + return (idx, True, "") + + +def generate_fmha_kernels( + kernels: List[Dict], output_dir: Path, codegen_dir: Path, gpu_target: str +) -> bool: + """Generate FMHA kernels for all declarations using unified FMHA codegen.""" + if not kernels: + return False + + # FMHA generator revisions can change emitted names or wrapper content. + # Clear previously generated FMHA files for this example directory so we + # only compile the current declaration set. + for pattern in ("fmha_*.hpp", "fmha_*.cpp", "fmha_*.o"): + for path in output_dir.glob(pattern): + path.unlink(missing_ok=True) + wrapper_dir = output_dir / "dispatcher_wrappers" + if wrapper_dir.exists(): + for path in wrapper_dir.glob("dispatcher_wrapper_fmha_*.hpp"): + path.unlink(missing_ok=True) + + unique_kernels = [] + seen = set() + for k in kernels: + key = json.dumps(k, sort_keys=True) + if key in seen: + continue + seen.add(key) + unique_kernels.append(k) + + work_items = [ + _build_fmha_codegen_cmd(idx, k, codegen_dir, output_dir, gpu_target) + for idx, k in enumerate(unique_kernels) + ] + + success_count = 0 + max_workers = min(len(work_items), os.cpu_count() or 4) + with ProcessPoolExecutor(max_workers=max_workers) as executor: + futures = {executor.submit(_run_fmha_codegen, w): w[0] for w in work_items} + for future in as_completed(futures): + idx, ok, err = future.result() + if ok: + success_count += 1 + else: + print(f" FMHA codegen error for kernel {idx + 1}: {err}") + + return success_count > 0 + + def generate_conv_kernels( kernels: List[Dict], output_dir: Path, codegen_dir: Path ) -> bool: @@ -1290,19 +1618,10 @@ def compile_kernel(args: Tuple) -> Tuple[str, bool, str]: obj_file = output_dir / f"{kernel_name}.o" - cmd = [ - hipcc, - "-c", - "-fPIC", - "-std=c++17", - "-O3", - f"--offload-arch={gpu_target}", - "-mllvm", - "-enable-noalias-to-md-conversion=0", - "-Wno-undefined-func-template", - "-Wno-float-equal", - "--offload-compress", - ] + sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "python")) + from fmha_utils import fmha_compile_flags # noqa: E402 + + cmd = fmha_compile_flags(gpu_target, hipcc, family="bwd") for inc_dir in include_dirs: cmd.extend(["-I", str(inc_dir)]) @@ -1343,6 +1662,14 @@ def main(): print( f"[{target_name}] Conv {k.get('dtype', 'fp16')} {variant} {k.get('ndim', 2)}D ({len(kernels)} declarations)" ) + elif example_type == "fmha": + k = kernels[0] if kernels else {} + sig = k.get("signature", {}) + print( + f"[{target_name}] FMHA {sig.get('family', 'fwd')} {sig.get('data_type', 'fp16')} " + f"{sig.get('mode', 'batch')} hq={sig.get('hdim_q', 128)} hv={sig.get('hdim_v', 128)} " + f"({len(kernels)} declarations)" + ) elif example_type == "gemm": k = kernels[0] if kernels else {} print( @@ -1360,6 +1687,10 @@ def main(): print(f"[{target_name}] Generating kernels...") if example_type == "conv": success = generate_conv_kernels(kernels, args.output_dir, codegen_dir) + elif example_type == "fmha": + success = generate_fmha_kernels( + kernels, args.output_dir, codegen_dir, args.gpu_target + ) else: success = generate_gemm_kernels(kernels, args.output_dir, codegen_dir) @@ -1370,6 +1701,22 @@ def main(): # Find generated headers if example_type == "gemm": kernel_headers = list(args.output_dir.glob("gemm_*.hpp")) + wrapper_headers = list( + (args.output_dir / "dispatcher_wrappers").glob( + "dispatcher_wrapper_gemm_*.hpp" + ) + ) + elif example_type == "fmha": + kernel_headers = [ + h + for h in args.output_dir.glob("fmha_*.hpp") + if not h.name.startswith("dispatcher_wrapper_") + ] + wrapper_headers = list( + (args.output_dir / "dispatcher_wrappers").glob( + "dispatcher_wrapper_fmha_*.hpp" + ) + ) else: prefix_map = { "forward": "grouped_conv_fwd", @@ -1554,7 +1901,32 @@ inline void {func_name}(ck_tile::dispatcher::GroupedConvRegistry& registry, cons // Generic registration - avoids hardcoding the example name in user code // Safe for single-example executables (typical use case) #ifndef REGISTER_GENERATED_KERNELS -#define REGISTER_GENERATED_KERNELS(registry, arch) generated::{func_name}(registry, arch) +#define REGISTER_GENERATED_KERNELS(registry, arch) ::generated::{func_name}(registry, arch) +#endif +""" + elif example_type == "fmha": + wrapper_includes = "\n".join( + f'#include "dispatcher_wrappers/{h.name}"' for h in sorted(wrapper_headers) + ) + register_body = generate_fmha_registration(wrapper_headers, source_stem) + header_content = f"""// Auto-generated for {target_name} +#pragma once + +{wrapper_includes} + +#include "ck_tile/dispatcher/fmha_registry.hpp" +#include "ck_tile/dispatcher/fmha_dispatcher.hpp" + +namespace generated {{ + +inline void {func_name}(ck_tile::dispatcher::FmhaRegistry& registry, const std::string& arch) {{ +{register_body} +}} + +}} // namespace generated + +#ifndef REGISTER_GENERATED_KERNELS +#define REGISTER_GENERATED_KERNELS(registry, arch) ::generated::{func_name}(registry, arch) #endif """ else: @@ -1584,13 +1956,13 @@ inline void {func_name}(ck_tile::dispatcher::Registry& registry, const std::stri // Generic registration - avoids hardcoding the example name in user code // Safe for single-example executables (typical use case) #ifndef REGISTER_GENERATED_KERNELS -#define REGISTER_GENERATED_KERNELS(registry, arch) generated::{func_name}(registry, arch) +#define REGISTER_GENERATED_KERNELS(registry, arch) ::generated::{func_name}(registry, arch) #endif // Register a specific kernel set by name (for multi-registry patterns) // Usage: REGISTER_KERNEL_SET("compute_bound_set", registry, arch) #ifndef REGISTER_KERNEL_SET -#define REGISTER_KERNEL_SET(set_name, registry, arch) generated::register_kernel_set(set_name, registry, arch) +#define REGISTER_KERNEL_SET(set_name, registry, arch) ::generated::register_kernel_set(set_name, registry, arch) #endif """ header_path.write_text(header_content) diff --git a/dispatcher/scripts/parallel_kernel_builder.py b/dispatcher/scripts/parallel_kernel_builder.py index aef8f4ff0b..a0bb9089b4 100755 --- a/dispatcher/scripts/parallel_kernel_builder.py +++ b/dispatcher/scripts/parallel_kernel_builder.py @@ -32,7 +32,11 @@ def find_hipcc(): def compile_kernel(args): """Compile a single kernel.""" - kernel_hpp, output_dir, include_dirs, hipcc = args + if len(args) == 5: + kernel_hpp, output_dir, include_dirs, hipcc, arch = args + else: + kernel_hpp, output_dir, include_dirs, hipcc = args + arch = "gfx942" kernel_name = kernel_hpp.stem # Create wrapper .cpp @@ -45,19 +49,11 @@ namespace {{ volatile bool _k = true; }} # Compile to object obj_file = output_dir / f"{kernel_name}.o" - cmd = [ - hipcc, - "-c", - "-fPIC", - "-std=c++17", - "-O3", - "--offload-arch=gfx942", - "-mllvm", - "-enable-noalias-to-md-conversion=0", - "-Wno-undefined-func-template", - "-Wno-float-equal", - "--offload-compress", - ] + sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "python")) + from fmha_utils import fmha_compile_flags # noqa: E402 + + # arch is extracted from work tuple above + cmd = fmha_compile_flags(arch, hipcc, family="bwd") for inc_dir in include_dirs: cmd.extend(["-I", str(inc_dir)]) @@ -78,6 +74,12 @@ def main(): parser.add_argument("--output-dir", type=Path, required=True) parser.add_argument("--include-dirs", type=str, required=True) parser.add_argument("--jobs", type=int, default=os.cpu_count()) + parser.add_argument( + "--arch", + type=str, + default="gfx942", + help="GPU architecture target (default: gfx942)", + ) args = parser.parse_args() # Find kernel headers @@ -97,7 +99,9 @@ def main(): args.output_dir.mkdir(parents=True, exist_ok=True) # Prepare work items - work = [(h, args.output_dir, include_dirs, hipcc) for h in kernel_headers] + work = [ + (h, args.output_dir, include_dirs, hipcc, args.arch) for h in kernel_headers + ] # Compile in parallel obj_files = [] diff --git a/dispatcher/src/dispatcher.cpp b/dispatcher/src/dispatcher.cpp index 2cb589adf2..133485b248 100644 --- a/dispatcher/src/dispatcher.cpp +++ b/dispatcher/src/dispatcher.cpp @@ -65,6 +65,7 @@ float Dispatcher::run_fused(const void* a_ptr, throw NoKernelFound(oss.str()); } + kernel->set_benchmarking(benchmarking_); return kernel->run(a_ptr, b_ptr, c_ptr, d_ptrs, problem, stream); } @@ -90,6 +91,7 @@ float Dispatcher::run_explicit(const std::string& kernel_id, throw UnsupportedProblem(oss.str()); } + kernel->set_benchmarking(benchmarking_); return kernel->run(a_ptr, b_ptr, c_ptr, d_ptrs, problem, stream); } diff --git a/dispatcher/src/fmha_dispatcher.cpp b/dispatcher/src/fmha_dispatcher.cpp new file mode 100644 index 0000000000..2685bb5f59 --- /dev/null +++ b/dispatcher/src/fmha_dispatcher.cpp @@ -0,0 +1,369 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/dispatcher/dispatcher_error.hpp" +#include "ck_tile/dispatcher/fmha_dispatcher.hpp" + +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +FmhaDispatcher::FmhaDispatcher(FmhaRegistry* registry, const std::string& gfx_arch) + : registry_(registry ? registry : &FmhaRegistry::instance()), + heuristic_(nullptr), + strategy_(SelectionStrategy::FirstFit), + gfx_arch_(gfx_arch) +{ +} + +void FmhaDispatcher::set_heuristic(FmhaHeuristicFunction heuristic) +{ + heuristic_ = std::move(heuristic); + if(heuristic_) + { + strategy_ = SelectionStrategy::Heuristic; + } +} + +void FmhaDispatcher::set_strategy(SelectionStrategy strategy) { strategy_ = strategy; } + +void FmhaDispatcher::set_timing(int cold_niters, int nrepeat) +{ + cold_niters_ = cold_niters; + nrepeat_ = nrepeat; +} + +FmhaKernelInstancePtr FmhaDispatcher::select_kernel(const FmhaProblem& problem) const +{ + if(!problem.is_valid()) + { + return nullptr; + } + + switch(strategy_) + { + case SelectionStrategy::FirstFit: return select_first_fit(problem); + case SelectionStrategy::Heuristic: return select_heuristic(problem); + default: return nullptr; + } +} + +FmhaExecutionPlan FmhaDispatcher::plan_single_stage(const FmhaProblem& problem, + FmhaKernelFamily family) const +{ + FmhaExecutionPlan plan; + plan.api_family = problem.api_family; + + auto stage_problem = with_family(problem, family); + auto kernel = select_kernel(stage_problem); + if(kernel) + { + plan.stages.push_back({family, kernel->get_key().encode_identifier()}); + } + return plan; +} + +FmhaExecutionPlan FmhaDispatcher::plan(const FmhaProblem& problem) const +{ + switch(problem.api_family) + { + case FmhaApiFamily::Fwd: return plan_single_stage(problem, FmhaKernelFamily::Fwd); + case FmhaApiFamily::FwdPagedKv: return plan_single_stage(problem, FmhaKernelFamily::FwdPagedKv); + case FmhaApiFamily::FwdAppendKv: + return plan_single_stage(problem, FmhaKernelFamily::FwdAppendKv); + case FmhaApiFamily::BatchPrefill: + return plan_single_stage(problem, FmhaKernelFamily::BatchPrefill); + case FmhaApiFamily::FwdSplitKv: { + FmhaExecutionPlan plan; + plan.api_family = problem.api_family; + + auto split_problem = with_family(problem, FmhaKernelFamily::FwdSplitKv); + auto split_kernel = select_kernel(split_problem); + if(!split_kernel) + { + return plan; + } + + auto combine_problem = with_family(problem, FmhaKernelFamily::FwdSplitKvCombine); + auto combine_kernel = select_kernel(combine_problem); + if(!combine_kernel) + { + return {}; + } + + plan.stages.push_back( + {FmhaKernelFamily::FwdSplitKv, split_kernel->get_key().encode_identifier()}); + plan.stages.push_back( + {FmhaKernelFamily::FwdSplitKvCombine, combine_kernel->get_key().encode_identifier()}); + return plan; + } + case FmhaApiFamily::Bwd: { + FmhaExecutionPlan plan; + plan.api_family = problem.api_family; + + auto dot_problem = with_family(problem, FmhaKernelFamily::BwdDotDoO); + auto dot_kernel = select_kernel(dot_problem); + if(!dot_kernel) + { + return plan; + } + + auto dq_problem = with_family(problem, FmhaKernelFamily::BwdDqDkDv); + auto dq_kernel = select_kernel(dq_problem); + if(!dq_kernel) + { + return {}; + } + + plan.stages.push_back( + {FmhaKernelFamily::BwdDotDoO, dot_kernel->get_key().encode_identifier()}); + plan.stages.push_back( + {FmhaKernelFamily::BwdDqDkDv, dq_kernel->get_key().encode_identifier()}); + + auto convert_problem = with_family(problem, FmhaKernelFamily::BwdConvertDq); + auto convert_kernel = select_kernel(convert_problem); + if(convert_kernel) + { + plan.stages.push_back( + {FmhaKernelFamily::BwdConvertDq, convert_kernel->get_key().encode_identifier()}); + } + return plan; + } + default: return {}; + } +} + +float FmhaDispatcher::run(const FmhaInvocation& invocation, void* stream) const +{ + auto problem = FmhaProblem::from_invocation(invocation, gfx_arch_); + auto exec = plan(problem); + if(!exec.is_valid()) + { + std::ostringstream oss; + oss << "No suitable FMHA execution plan for API family " << to_string(problem.api_family) + << " and dtype " << problem.data_type; + throw NoKernelFound(oss.str()); + } + + return run_plan(exec, invocation, stream); +} + +float FmhaDispatcher::run_explicit(const std::string& kernel_id, + const FmhaInvocation& invocation, + void* stream) const +{ + auto kernel = registry_->lookup(kernel_id); + if(!kernel) + { + throw NoKernelFound("FMHA kernel not found: " + kernel_id); + } + auto sc = make_stream_config(stream); + return kernel->run(invocation, sc); +} + +float FmhaDispatcher::run_fwd(fmha_fwd_traits traits, fmha_fwd_args args, void* stream) const +{ + return run(FmhaInvocation::make(std::move(traits), std::move(args)), stream); +} + +float FmhaDispatcher::run_fwd_pagedkv(fmha_fwd_pagedkv_traits traits, + fmha_fwd_pagedkv_args args, + void* stream) const +{ + return run(FmhaInvocation::make(std::move(traits), std::move(args)), stream); +} + +float FmhaDispatcher::run_fwd_splitkv(fmha_fwd_splitkv_traits traits, + fmha_fwd_splitkv_args args, + void* stream) const +{ + return run(FmhaInvocation::make(std::move(traits), std::move(args)), stream); +} + +float FmhaDispatcher::run_fwd_appendkv(fmha_fwd_appendkv_traits traits, + fmha_fwd_appendkv_args args, + void* stream) const +{ + return run(FmhaInvocation::make(std::move(traits), std::move(args)), stream); +} + +float FmhaDispatcher::run_batch_prefill(fmha_batch_prefill_traits traits, + fmha_batch_prefill_args args, + void* stream) const +{ + return run(FmhaInvocation::make(std::move(traits), std::move(args)), stream); +} + +float FmhaDispatcher::run_bwd(fmha_bwd_traits traits, fmha_bwd_args args, void* stream) const +{ + return run(FmhaInvocation::make(std::move(traits), std::move(args)), stream); +} + +FmhaKernelInstancePtr FmhaDispatcher::select_first_fit(const FmhaProblem& problem) const +{ + // Seqtune-aware selection per fmhaarch.md Section 7.3.3: + // 1. For short sequences (seqlen_q <= tile_m0): prefer smallest fitting tile + // 2. tile_m0 == 64: unconditional fallback + // 3. Prefer unpadded over padded + // 4. Within same category: selection_rank, then smaller tile_m0 + + auto kernels = registry_->get_all(); + const auto max_sq = problem.effective_max_seqlen_q(); + + // Find max tile_m0 across all compatible kernels + int max_tile_m0_all = 0; + for(const auto& kernel : kernels) + { + if(kernel->supports(problem)) + { + max_tile_m0_all = std::max(max_tile_m0_all, + static_cast(kernel->get_key().algorithm.tile_shape.m0)); + } + } + + FmhaKernelInstancePtr best = nullptr; + std::tuple best_score = {std::numeric_limits::max(), + std::numeric_limits::max(), + std::numeric_limits::max()}; + + for(const auto& kernel : kernels) + { + if(!kernel->supports(problem)) + continue; + + const auto& key = kernel->get_key(); + int tile_m0 = key.algorithm.tile_shape.m0; + int rank = key.algorithm.selection_rank; + bool aligned = (tile_m0 > 0) && (max_sq > 0) && (max_sq % tile_m0 == 0); + + // Seqtune scoring (lower tuple is better): + // Category 0: seqlen_q <= tile_m0 AND aligned (perfect fit, smallest tile wins) + // Category 1: tile_m0 == 64 (unconditional fallback) + // Category 2: tile_m0 == max_tile_m0 (catch-all) + // Category 3: aligned (no padding needed) + // Category 4: needs padding (last resort) + int category; + if(tile_m0 > 0 && max_sq <= tile_m0 && aligned) + category = 0; + else if(tile_m0 == 64) + category = 1; + else if(tile_m0 == max_tile_m0_all) + category = 2; + else if(aligned) + category = 3; + else + category = 4; + + auto score = std::make_tuple(category, rank, tile_m0); + + if(score < best_score) + { + best = kernel; + best_score = score; + } + } + + return best; +} + +FmhaKernelInstancePtr FmhaDispatcher::select_heuristic(const FmhaProblem& problem) const +{ + if(!heuristic_) + { + return select_first_fit(problem); + } + + for(const auto& kernel_id : heuristic_(problem)) + { + auto kernel = registry_->lookup(kernel_id); + if(kernel && kernel->supports(problem)) + { + return kernel; + } + } + + return select_first_fit(problem); +} + +FmhaProblem FmhaDispatcher::with_family(const FmhaProblem& base, FmhaKernelFamily family) const +{ + auto copy = base; + copy.requested_family = family; + return copy; +} + +float FmhaDispatcher::run_plan(const FmhaExecutionPlan& plan, + const FmhaInvocation& invocation, + void* stream) const +{ + auto sc = make_stream_config(stream); + + if(plan.stages.size() == 1) + { + auto kernel = registry_->lookup(plan.stages.front().kernel_id); + if(!kernel) + { + throw NoKernelFound("Missing FMHA kernel: " + plan.stages.front().kernel_id); + } + return kernel->run(invocation, sc); + } + + // Multi-stage lambdas capture by reference. This is safe because + // launch_kernel dispatches all stages on the same HIP stream before + // returning. If launch_kernel ever becomes async, these must capture + // by value or use shared_ptr. + if(plan.stages.size() == 2) + { + auto first = registry_->lookup(plan.stages[0].kernel_id); + auto second = registry_->lookup(plan.stages[1].kernel_id); + if(!first || !second) + { + throw NoKernelFound("Missing FMHA kernel in two-stage plan"); + } + + return ck_tile::launch_kernel( + sc, + [&](const ck_tile::stream_config& inner) { first->launch(invocation, inner); }, + [&](const ck_tile::stream_config& inner) { second->launch(invocation, inner); }); + } + + if(plan.stages.size() == 3) + { + auto first = registry_->lookup(plan.stages[0].kernel_id); + auto second = registry_->lookup(plan.stages[1].kernel_id); + auto third = registry_->lookup(plan.stages[2].kernel_id); + if(!first || !second || !third) + { + throw NoKernelFound("Missing FMHA kernel in three-stage plan"); + } + + return ck_tile::launch_kernel( + sc, + [&](const ck_tile::stream_config& inner) { first->launch(invocation, inner); }, + [&](const ck_tile::stream_config& inner) { second->launch(invocation, inner); }, + [&](const ck_tile::stream_config& inner) { third->launch(invocation, inner); }); + } + + throw std::runtime_error("Unsupported FMHA execution plan length"); +} + +ck_tile::stream_config FmhaDispatcher::make_stream_config(void* stream) const +{ + ck_tile::stream_config sc; + sc.stream_id_ = reinterpret_cast(stream); + sc.time_kernel_ = benchmarking_enabled_; + sc.log_level_ = 0; + sc.cold_niters_ = benchmarking_enabled_ ? cold_niters_ : 0; + sc.nrepeat_ = benchmarking_enabled_ ? nrepeat_ : 1; + sc.is_gpu_timer_ = benchmarking_enabled_; + sc.flush_cache_ = false; + sc.rotating_count_ = 1; + return sc; +} + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/src/fmha_registry.cpp b/dispatcher/src/fmha_registry.cpp new file mode 100644 index 0000000000..0457c33e64 --- /dev/null +++ b/dispatcher/src/fmha_registry.cpp @@ -0,0 +1,302 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/dispatcher/fmha_registry.hpp" + +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +namespace { + +std::string json_escape(const std::string& str) +{ + std::ostringstream oss; + for(unsigned char c : str) + { + switch(c) + { + case '"': oss << "\\\""; break; + case '\\': oss << "\\\\"; break; + case '\b': oss << "\\b"; break; + case '\f': oss << "\\f"; break; + case '\n': oss << "\\n"; break; + case '\r': oss << "\\r"; break; + case '\t': oss << "\\t"; break; + default: + if(c < 0x20) + { + char buf[8]; + std::snprintf(buf, sizeof(buf), "\\u%04x", c); + oss << buf; + } + else + { + oss << static_cast(c); + } + break; + } + } + return oss.str(); +} + +} // namespace + +bool FmhaRegistry::register_kernel(FmhaKernelInstancePtr instance, Priority priority) +{ + if(!instance) + { + return false; + } + bool ok = Base::register_kernel( + instance->get_key().encode_identifier(), std::move(instance), priority); + if(ok) + { + perform_auto_export(); + } + return ok; +} + +FmhaKernelInstancePtr FmhaRegistry::lookup(const std::string& identifier) const +{ + std::lock_guard lock(mutex()); + auto it = entries().find(identifier); + return it != entries().end() ? it->second.instance : nullptr; +} + +FmhaKernelInstancePtr FmhaRegistry::lookup(const FmhaKernelKey& key) const +{ + return lookup(key.encode_identifier()); +} + +std::vector FmhaRegistry::get_all() const +{ + std::lock_guard lock(mutex()); + + struct RankedKernel + { + FmhaKernelInstancePtr instance; + Priority priority; + }; + + std::vector ranked; + ranked.reserve(entries().size()); + for(const auto& [name, entry] : entries()) + { + ranked.push_back({entry.instance, entry.priority}); + } + + std::stable_sort( + ranked.begin(), ranked.end(), [](const RankedKernel& lhs, const RankedKernel& rhs) { + if(lhs.priority != rhs.priority) + { + return static_cast(lhs.priority) > static_cast(rhs.priority); + } + + const auto& lkey = lhs.instance->get_key(); + const auto& rkey = rhs.instance->get_key(); + if(lkey.algorithm.selection_rank != rkey.algorithm.selection_rank) + { + return lkey.algorithm.selection_rank < rkey.algorithm.selection_rank; + } + + if(lkey.signature.hdim_q != rkey.signature.hdim_q) + { + return lkey.signature.hdim_q < rkey.signature.hdim_q; + } + + if(lkey.signature.hdim_v != rkey.signature.hdim_v) + { + return lkey.signature.hdim_v < rkey.signature.hdim_v; + } + + if(lkey.algorithm.tile_shape.m0 != rkey.algorithm.tile_shape.m0) + { + return lkey.algorithm.tile_shape.m0 < rkey.algorithm.tile_shape.m0; + } + + return lhs.instance->get_name() < rhs.instance->get_name(); + }); + + std::vector result; + result.reserve(ranked.size()); + for(const auto& entry : ranked) + { + result.push_back(entry.instance); + } + return result; +} + +std::vector +FmhaRegistry::filter(std::function predicate) const +{ + auto all = get_all(); + std::vector result; + result.reserve(all.size()); + for(const auto& instance : all) + { + if(predicate(*instance)) + { + result.push_back(instance); + } + } + return result; +} + +std::string FmhaRegistry::export_json(bool include_statistics) const +{ + auto all = get_all(); + + std::ostringstream json; + json << "{\n"; + json << " \"metadata\": {\n"; + json << " \"registry_name\": \"" << json_escape(get_name()) << "\",\n"; + json << " \"total_kernels\": " << all.size() << "\n"; + json << " }"; + + if(include_statistics) + { + std::map by_family; + std::map by_dtype; + std::map by_pipeline; + + for(const auto& kernel : all) + { + const auto& key = kernel->get_key(); + by_family[to_string(key.signature.family)]++; + by_dtype[key.signature.data_type]++; + by_pipeline[key.algorithm.pipeline]++; + } + + json << ",\n \"statistics\": {\n"; + auto emit_map = [&](const char* label, const auto& values, bool last) { + json << " \"" << label << "\": {"; + bool first = true; + for(const auto& [name, count] : values) + { + if(!first) + { + json << ","; + } + json << "\"" << json_escape(name) << "\":" << count; + first = false; + } + json << "}"; + json << (last ? "\n" : ",\n"); + }; + + emit_map("by_family", by_family, false); + emit_map("by_dtype", by_dtype, false); + emit_map("by_pipeline", by_pipeline, true); + json << " }"; + } + + json << ",\n \"kernels\": [\n"; + for(std::size_t i = 0; i < all.size(); ++i) + { + const auto& kernel = all[i]; + const auto& key = kernel->get_key(); + json << " {\n"; + json << " \"name\": \"" << json_escape(kernel->get_name()) << "\",\n"; + json << " \"identifier\": \"" << json_escape(key.encode_identifier()) << "\",\n"; + json << " \"family\": \"" << to_string(key.signature.family) << "\",\n"; + json << " \"dtype\": \"" << json_escape(key.signature.data_type) << "\",\n"; + json << " \"pipeline\": \"" << json_escape(key.algorithm.pipeline) << "\",\n"; + json << " \"gfx_arch\": \"" << json_escape(key.gfx_arch) << "\"\n"; + json << " }"; + if(i + 1 < all.size()) + { + json << ","; + } + json << "\n"; + } + json << " ]\n"; + json << "}\n"; + return json.str(); +} + +bool FmhaRegistry::export_json_to_file(const std::string& filename, bool include_statistics) const +{ + std::ofstream file(filename); + if(!file.is_open()) + { + return false; + } + file << export_json(include_statistics); + return true; +} + +std::size_t FmhaRegistry::filter_by_arch(const std::string& gpu_arch) +{ + std::lock_guard lock(mutex()); + + std::vector to_remove; + for(const auto& [name, entry] : entries()) + { + const auto& arch = entry.instance->get_key().gfx_arch; + if(!arch.empty() && arch != gpu_arch) + { + to_remove.push_back(name); + } + } + + for(const auto& name : to_remove) + { + entries_mut().erase(name); + } + + return to_remove.size(); +} + +std::size_t FmhaRegistry::filter_by_receipt(int receipt_id) +{ + std::lock_guard lock(mutex()); + std::vector to_remove; + for(const auto& [name, entry] : entries()) + { + if(entry.instance) + { + int r = entry.instance->get_key().signature.receipt; + if(r >= 0 && r != receipt_id) + { + to_remove.push_back(name); + } + } + } + for(const auto& name : to_remove) + { + entries_mut().erase(name); + } + return to_remove.size(); +} + +std::vector FmhaRegistry::available_receipts() const +{ + std::lock_guard lock(mutex()); + std::set receipts; + for(const auto& [name, entry] : entries()) + { + if(entry.instance) + { + int r = entry.instance->get_key().signature.receipt; + if(r >= 0) + receipts.insert(r); + } + } + return {receipts.begin(), receipts.end()}; +} + +FmhaRegistry& FmhaRegistry::instance() +{ + static FmhaRegistry registry; + return registry; +} + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/tests/CMakeLists.txt b/dispatcher/tests/CMakeLists.txt index a54feba284..a18663f76d 100644 --- a/dispatcher/tests/CMakeLists.txt +++ b/dispatcher/tests/CMakeLists.txt @@ -89,6 +89,43 @@ set_tests_properties(dispatcher_test_arch_support PROPERTIES ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" ) +add_test( + NAME dispatcher_test_fmha_codegen + COMMAND ${Python3_EXECUTABLE} -m unittest test_fmha_codegen -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + +set_tests_properties(dispatcher_test_fmha_codegen PROPERTIES + LABELS "dispatcher;python;fmha;codegen" + TIMEOUT 120 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +add_test( + NAME dispatcher_test_fmha_rules + COMMAND ${Python3_EXECUTABLE} -m unittest test_fmha_rules -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + +set_tests_properties(dispatcher_test_fmha_rules PROPERTIES + LABELS "dispatcher;python;fmha;rules" + TIMEOUT 60 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +# FMHA parity test (requires GPU) +add_test( + NAME dispatcher_test_fmha_parity + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test_fmha_parity.py + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/.. +) + +set_tests_properties(dispatcher_test_fmha_parity PROPERTIES + LABELS "dispatcher;python;fmha;parity;gpu" + TIMEOUT 600 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + # Stress Test Script add_test( NAME dispatcher_stress_test @@ -180,6 +217,9 @@ set(TEST_SOURCES test_registry.cpp test_dispatcher.cpp test_tile_backend.cpp + test_fmha_problem.cpp + test_fmha_dispatcher.cpp + test_fmha_registry.cpp # Extended unit tests (more comprehensive coverage) test_kernel_key_extended.cpp @@ -221,6 +261,7 @@ set(STANDALONE_TESTS test_grouped_conv_problem.cpp test_grouped_conv_kernel_decl.cpp test_grouped_conv_registry.cpp + test_fmha_kernel_decl.cpp ) foreach(test_source ${STANDALONE_TESTS}) diff --git a/dispatcher/tests/fmha_smoke_matrix.py b/dispatcher/tests/fmha_smoke_matrix.py new file mode 100644 index 0000000000..e6408d1da1 --- /dev/null +++ b/dispatcher/tests/fmha_smoke_matrix.py @@ -0,0 +1,416 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +FMHA smoke test matrix generator. + +Generates the same test cases as smoke_test_fwd.sh and smoke_test_bwd.sh +from the CK Tile 01_fmha example, for automated parity testing. +""" + +from dataclasses import dataclass +from typing import List, Set, Tuple + + +@dataclass +class TestCase: + name: str = "" + direction: str = "fwd" + prec: str = "fp16" + mode: int = 0 + batch: int = 2 + nhead_q: int = 2 + nhead_k: int = -1 + hdim_q: int = 128 + hdim_v: int = -1 + seqlen_q: int = 128 + seqlen_k: int = 128 + bias: str = "n" + mask: str = "0" + lse: int = 0 + p_drop: float = 0.0 + perm: int = 1 + num_splits: int = 1 + page_block_size: int = 0 + cache_batch_idx: int = 0 + s_kpad: str = "" + q_eff_lens: str = "" + kv_eff_lens: str = "" + s_qpad: str = "" + rotary_dim: int = 0 + rotary_interleaved: int = 0 + deterministic: int = 0 + dbias: int = 0 + + def effective_nhead_k(self): + return self.nhead_k if self.nhead_k > 0 else self.nhead_q + + def effective_hdim_v(self): + return self.hdim_v if self.hdim_v > 0 else self.hdim_q + + +def generate_fwd_fp16_bf16_matrix() -> List[TestCase]: + """Generate the run_fp16_bf16_tests matrix from smoke_test_fwd.sh.""" + cases = [] + idx = 0 + for prec in ["fp16", "bf16"]: + for mode in [1, 0]: + for perm in [0, 1]: + for hdim in [32, 64, 128, 256]: + for lse in [0, 1]: + for bias in ["n", "e", "a"]: + for p_drop in [0.0, 0.2]: + base = dict( + prec=prec, + mode=mode, + perm=perm, + lse=lse, + bias=bias, + p_drop=p_drop, + ) + subcases = [ + dict( + batch=2, + nhead_q=2, + nhead_k=1, + hdim_q=16, + hdim_v=hdim, + seqlen_q=55, + seqlen_k=256, + mask="0", + ), + dict( + batch=1, + nhead_q=3, + hdim_q=hdim, + seqlen_q=100, + seqlen_k=51, + mask="0", + ), + dict( + batch=2, + nhead_q=1, + hdim_q=16, + hdim_v=hdim, + seqlen_q=99, + seqlen_k=256, + mask="1", + ), + dict( + batch=1, + nhead_q=2, + nhead_k=1, + hdim_q=hdim, + seqlen_q=1024, + seqlen_k=256, + mask="2", + ), + dict( + batch=2, + nhead_q=1, + hdim_q=hdim, + hdim_v=24, + seqlen_q=3, + seqlen_k=99, + mask="2", + ), + dict( + batch=3, + nhead_q=2, + nhead_k=1, + hdim_q=hdim, + seqlen_q=200, + seqlen_k=520, + mask="t:128,30", + ), + dict( + batch=2, + nhead_q=1, + hdim_q=hdim, + seqlen_q=99, + seqlen_k=32, + mask="b:4,35", + ), + dict( + batch=1, + nhead_q=2, + nhead_k=1, + hdim_q=hdim, + seqlen_q=33, + seqlen_k=0, + mask="2", + ), + dict( + batch=1, + nhead_q=2, + nhead_k=1, + hdim_q=hdim, + seqlen_q=1, + seqlen_k=10, + s_kpad="32", + mask="2", + ), + ] + for sc in subcases: + idx += 1 + c = TestCase( + name=f"fwd_{idx:04d}_{prec}_m{mode}_h{hdim}", + direction="fwd", + **base, + **sc, + ) + cases.append(c) + return cases + + +def generate_bwd_matrix() -> List[TestCase]: + """Generate the bwd smoke test matrix from smoke_test_bwd.sh.""" + cases = [] + idx = 0 + base_shapes = [ + dict(batch=1, nhead_q=4, nhead_k=2, seqlen_q=259, seqlen_k=259, mask="0"), + dict(batch=2, nhead_q=2, seqlen_q=516, seqlen_k=253, mask="0"), + dict(batch=1, nhead_q=4, nhead_k=1, seqlen_q=500, seqlen_k=251, mask="1"), + dict(batch=1, nhead_q=2, seqlen_q=900, seqlen_k=258, mask="2"), + dict(batch=2, nhead_q=1, seqlen_q=987, seqlen_k=219, mask="t:128,30"), + dict(batch=2, nhead_q=3, nhead_k=1, seqlen_q=244, seqlen_k=499, mask="b:4,35"), + ] + for prec in ["fp16", "bf16"]: + for perm in [0, 1]: + for hdim in [32, 64, 128, 256]: + for mode in [0, 1]: + for bias in ["n", "a"]: + for p_drop in [0.0, 0.2]: + for shape in base_shapes: + idx += 1 + cases.append( + TestCase( + name=f"bwd_{idx:04d}_{prec}_h{hdim}", + direction="bwd", + prec=prec, + mode=mode, + perm=perm, + hdim_q=hdim, + hdim_v=hdim, + bias=bias, + p_drop=p_drop, + lse=1, + **shape, + ) + ) + return cases + + +def generate_splitkv_matrix() -> List[TestCase]: + """Generate the splitkv smoke test matrix (same subcases as fwd, with num_splits > 1).""" + cases = [] + idx = 0 + for prec in ["fp16", "bf16"]: + for mode in [0]: # splitkv only supports batch mode in smoke test + for perm in [0, 1]: + for hdim in [64, 128, 256]: + for num_splits in [2, 3]: + for bias in ["n"]: + subcases = [ + dict( + batch=2, + nhead_q=2, + nhead_k=1, + seqlen_q=55, + seqlen_k=256, + mask="0", + ), + dict( + batch=1, + nhead_q=3, + seqlen_q=100, + seqlen_k=51, + mask="0", + ), + dict( + batch=1, + nhead_q=2, + nhead_k=1, + seqlen_q=1024, + seqlen_k=256, + mask="2", + ), + dict( + batch=3, + nhead_q=2, + nhead_k=1, + seqlen_q=200, + seqlen_k=520, + mask="t:128,30", + ), + ] + for sc in subcases: + idx += 1 + cases.append( + TestCase( + name=f"splitkv_{idx:04d}_{prec}_h{hdim}_s{num_splits}", + direction="fwd_splitkv", + prec=prec, + mode=mode, + perm=perm, + hdim_q=hdim, + hdim_v=hdim, + lse=1, + bias=bias, + p_drop=0.0, + num_splits=num_splits, + page_block_size=128, + cache_batch_idx=1, + **sc, + ) + ) + return cases + + +def generate_padding_matrix() -> List[TestCase]: + """Generate padding edge-case test cases.""" + cases = [] + idx = 0 + for prec in ["fp16"]: + for hdim in [32, 64, 128]: + edge_shapes = [ + dict(batch=1, nhead_q=1, seqlen_q=1, seqlen_k=1, mask="0"), + dict(batch=1, nhead_q=1, seqlen_q=1, seqlen_k=256, mask="0"), + dict(batch=1, nhead_q=1, seqlen_q=255, seqlen_k=1, mask="0"), + dict(batch=1, nhead_q=2, seqlen_q=3, seqlen_k=5, mask="1"), + dict(batch=2, nhead_q=1, seqlen_q=17, seqlen_k=33, mask="2"), + ] + for shape in edge_shapes: + idx += 1 + cases.append( + TestCase( + name=f"pad_{idx:04d}_{prec}_h{hdim}", + direction="fwd", + prec=prec, + mode=0, + perm=1, + hdim_q=hdim, + hdim_v=hdim, + bias="n", + lse=0, + p_drop=0.0, + **shape, + ) + ) + return cases + + +def generate_fp8_matrix() -> List[TestCase]: + """Generate fp8 smoke test cases (fp8bf16 and fp8fp32).""" + cases = [] + idx = 0 + for prec in ["fp8bf16"]: + for mode in [0]: + for perm in [1]: + for hdim in [64, 128]: + for mask in ["0", "2"]: + idx += 1 + cases.append( + TestCase( + name=f"fp8_{idx:04d}_{prec}_h{hdim}", + direction="fwd", + prec=prec, + mode=mode, + perm=perm, + hdim_q=hdim, + hdim_v=hdim, + batch=2, + nhead_q=4, + nhead_k=4, + seqlen_q=128, + seqlen_k=128, + bias="n", + mask=mask, + lse=0, + p_drop=0.0, + ) + ) + return cases + + +def unique_kernel_configs(cases: List[TestCase]) -> Set[Tuple]: + """Extract unique kernel configs needed to run the test cases.""" + configs = set() + for c in cases: + dv = c.effective_hdim_v() + mask_cat = ( + "no" if c.mask == "0" else ("causal" if c.mask in ["1", "2"] else "window") + ) + bias_cat = c.bias + configs.add( + ( + c.direction, + c.prec, + c.hdim_q, + dv, + mask_cat, + bias_cat, + bool(c.lse), + c.p_drop > 0, + ) + ) + return configs + + +def to_ck_cli_args(case: TestCase) -> list: + """Convert a TestCase to CK Tile CLI arguments.""" + nk = case.effective_nhead_k() + dv = case.effective_hdim_v() + args = [ + f"-prec={case.prec}", + f"-mode={case.mode}", + f"-b={case.batch}", + f"-h={case.nhead_q}", + ] + if nk != case.nhead_q: + args.append(f"-h_k={nk}") + args += [f"-d={case.hdim_q}"] + if dv != case.hdim_q: + args.append(f"-d_v={dv}") + args += [ + f"-s={case.seqlen_q}", + f"-s_k={case.seqlen_k}", + f"-bias={case.bias}", + f"-mask={case.mask}", + f"-lse={case.lse}", + f"-p_drop={case.p_drop}", + f"-iperm={case.perm}", + f"-operm={case.perm}", + "-v=1", + "-warmup=0", + "-repeat=1", + ] + if case.s_kpad: + args.append(f"-s_kpad={case.s_kpad}") + if case.num_splits > 1: + args.append(f"-num_splits={case.num_splits}") + if case.page_block_size > 0: + args.append(f"-page_block_size={case.page_block_size}") + if case.cache_batch_idx: + args.append(f"-cache_batch_idx={case.cache_batch_idx}") + return args + + +if __name__ == "__main__": + fwd = generate_fwd_fp16_bf16_matrix() + bwd = generate_bwd_matrix() + skv = generate_splitkv_matrix() + pad = generate_padding_matrix() + fp8 = generate_fp8_matrix() + + all_cases = fwd + bwd + skv + pad + fp8 + all_configs = unique_kernel_configs(all_cases) + + print(f"Forward: {len(fwd):5d} cases") + print(f"Backward: {len(bwd):5d} cases") + print(f"SplitKV: {len(skv):5d} cases") + print(f"Padding: {len(pad):5d} cases") + print(f"FP8: {len(fp8):5d} cases") + print(f"Total: {len(all_cases):5d} cases, {len(all_configs)} unique configs") diff --git a/dispatcher/tests/full_parity_test.py b/dispatcher/tests/full_parity_test.py new file mode 100644 index 0000000000..cc5b3032a7 --- /dev/null +++ b/dispatcher/tests/full_parity_test.py @@ -0,0 +1,1020 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Full FMHA Parity Test -- parallel JIT build, sequential GPU test. + +Phase 1: JIT-compile every unique kernel config in parallel (hipcc only, no GPU). +Phase 2: Run each test case sequentially through CK Tile and the dispatcher + (each dispatcher invocation in its own subprocess for HIP isolation). + +Usage: + python3 full_parity_test.py --max-cases 100 + python3 full_parity_test.py --max-cases 0 # all ~3500 cases + python3 full_parity_test.py --workers 8 # parallel JIT build + python3 full_parity_test.py --skip-jit # reuse previous build +""" + +import sys +import os +import time +import argparse +import subprocess +import json +from pathlib import Path +from collections import Counter +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Optional, Dict, Tuple +from fmha_smoke_matrix import ( + generate_fwd_fp16_bf16_matrix, + generate_bwd_matrix, + generate_splitkv_matrix, + generate_padding_matrix, + generate_fp8_matrix, + to_ck_cli_args, + TestCase, +) + +SCRIPT_DIR = Path(__file__).resolve().parent +DISPATCHER_DIR = SCRIPT_DIR.parent +PYTHON_DIR = DISPATCHER_DIR / "python" + +sys.path.insert(0, str(SCRIPT_DIR)) + + +# ========================================================================= +# Config dedup + tile lookup +# ========================================================================= + +HDIM_TILE_TABLE = { + (32, 32): (128, 64, 16, 32, 32, 32), + (64, 64): (128, 64, 32, 64, 32, 64), + (128, 128): (128, 128, 32, 128, 32, 128), + (192, 128): (128, 128, 32, 128, 32, 192), + (192, 192): (128, 128, 32, 192, 32, 192), + (256, 256): (128, 128, 32, 256, 32, 256), + (80, 96): (128, 128, 16, 96, 32, 96), + (96, 128): (128, 128, 32, 128, 32, 96), +} + + +def _round_hdim(d: int) -> int: + for t in [32, 64, 96, 128, 192, 256]: + if d <= t: + return t + return 256 + + +def _lookup_tile(dq: int, dv: int): + key = (dq, dv) + if key in HDIM_TILE_TABLE: + return HDIM_TILE_TABLE[key] + sq = max(dq, dv) + key2 = (sq, sq) + if key2 in HDIM_TILE_TABLE: + t = list(HDIM_TILE_TABLE[key2]) + t[3] = dv + t[5] = sq + return tuple(t) + return (128, 64, 16, dv, 32, sq) + + +def _mask_str(m: str) -> str: + return "no" if m == "0" else "top_left" + + +def _bias_str(b: str) -> str: + return {"n": "no", "e": "bias", "a": "alibi"}.get(b, "no") + + +def config_key(c: TestCase) -> tuple: + tdq = _round_hdim(c.hdim_q) + tdv = _round_hdim(c.effective_hdim_v()) + # GQA (nhead_q != nhead_k) is a runtime property handled via strides, + # NOT a compile-time kernel variant. is_group_mode refers to + # variable-length batching (mode=1), not GQA. + is_varlen = c.mode == 1 + return ( + c.prec, + tdq, + tdv, + _mask_str(c.mask), + _bias_str(c.bias), + bool(c.lse), + c.p_drop > 0, + is_varlen, + ) + + +def config_name(key: tuple) -> str: + prec, dq, dv, mask, bias, lse, drop, varlen = key + n = f"{prec}_h{dq}x{dv}_{'grp' if varlen else 'bat'}_{mask}_{bias}" + if lse: + n += "_lse" + if drop: + n += "_drop" + return n + + +# Backward tile tables from CK codegen (gfx9/gfx950, fp16/bf16, tr_load=f) +# Format: tile(9), wave(9), warp(6) -- from fmha_bwd.py KernelComponentFactoryGfx9 +BWD_CONFIGS = { + 32: { + "tile": [32, 128, 32, 32, 32, 32, 64, 32, 32], + "wave": [1, 4, 1, 4, 1, 1, 2, 2, 1], + "warp": [16, 16, 32, 16, 16, 16], + }, + 64: { + "tile": [32, 128, 64, 32, 64, 32, 32, 64, 64], + "wave": [1, 4, 1, 4, 1, 1, 1, 4, 1], + "warp": [16, 16, 32, 16, 16, 16], + }, + 96: { + "tile": [32, 128, 96, 32, 96, 32, 32, 96, 96], + "wave": [1, 4, 1, 4, 1, 1, 2, 2, 1], + "warp": [16, 16, 32, 16, 16, 16], + }, + 128: { + "tile": [16, 128, 128, 16, 128, 16, 32, 128, 128], + "wave": [1, 4, 1, 4, 1, 1, 1, 4, 1], + "warp": [16, 16, 32, 16, 16, 16], + }, + 256: { + "tile": [16, 64, 256, 16, 256, 16, 32, 256, 256], + "wave": [1, 4, 1, 4, 1, 1, 1, 4, 1], + "warp": [16, 16, 32, 16, 16, 16], + }, +} + + +def config_to_codegen_json(key: tuple, arch: str) -> str: + """Produce the JSON string that generate_fmha_fallback.py expects.""" + prec, dq, dv, mask, bias, lse, drop, is_varlen = key + tile = _lookup_tile(dq, dv) + return json.dumps( + { + "arch": arch, + "signature": { + "family": "fwd", + "data_type": prec, + "mode": "group" if is_varlen else "batch", + "vlayout": "r", + "hdim_q": dq, + "hdim_v": dv, + "mask": mask, + "bias": bias, + "lse": lse, + "dropout": drop, + "qscale": "no", + "rope": "none", + "logits": False, + "paged_kv": False, + "fp8_static_quant": False, + "skip_min_seqlen_q": False, + "sink": False, + "dbias": False, + "store_randval": False, + "deterministic": False, + "kv_memory_layout": "vectorized", + "kv_lookup_table": "sglang", + "page_size": 1, + }, + "algorithm": { + "pipeline": "qr" + if "fp8" in prec + else ("qr_async" if dq >= 64 else "qr"), + "tile": list(tile), + "wave": [2, 1, 1, 2, 1, 1, 1, 1, 1] + if "fp8" in prec + else [4, 1, 1, 4, 1, 1, 1, 1, 1], + "warp": [32, 32, 32, 32, 32, 32, 16, 16, 16] + if "fp8" in prec + else [32, 32, 16, 32, 32, 16, 16, 16, 16], + "padding": [True, True, True, True], + "block_per_cu": 1, + "num_wave_groups": 1, + "max_splits_log2": 0, + "max_seq_len_q": 0, + }, + } + ) + + +def bwd_codegen_jsons(key: tuple, arch: str) -> list: + """Produce 3 JSON strings for bwd stages: dot_do_o, dq_dk_dv, convert_dq.""" + prec, dq, dv, mask, bias, lse, drop, is_varlen = key + mode = "group" if is_varlen else "batch" + cfg = BWD_CONFIGS.get(dq, BWD_CONFIGS[128]) + bwd_tile = cfg["tile"] + bwd_wave = cfg["wave"] + bwd_warp = cfg["warp"] + + base_sig = { + "data_type": prec, + "mode": mode, + "vlayout": "r", + "hdim_q": dq, + "hdim_v": dv, + "mask": mask, + "bias": bias, + "lse": True, + "dropout": drop, + "qscale": "no", + "rope": "none", + "logits": False, + "paged_kv": False, + "fp8_static_quant": False, + "skip_min_seqlen_q": False, + "sink": False, + "dbias": False, + "store_randval": False, + "deterministic": False, + "kv_memory_layout": "vectorized", + "kv_lookup_table": "sglang", + "page_size": 1, + } + base_alg = { + "pipeline": "bwd", + "padding": [True, True, True, True], + "block_per_cu": 1, + "num_wave_groups": 1, + "max_splits_log2": 0, + "max_seq_len_q": 0, + "use_trload": False, + } + + dot_bm0 = max(bwd_tile[0], 64) + dot_json = json.dumps( + { + "arch": arch, + "signature": {**base_sig, "family": "bwd_dot_do_o"}, + "algorithm": { + **base_alg, + "tile": [dot_bm0, 0, 0, 0, 0, dv], + "wave": [1, 1, 1, 1, 1, 1, 1, 1, 1], + "warp": [16, 16, 16, 16, 16, 16, 16, 16, 16], + }, + } + ) + + dqdkdv_json = json.dumps( + { + "arch": arch, + "signature": {**base_sig, "family": "bwd_dq_dk_dv"}, + "algorithm": { + **base_alg, + "tile": bwd_tile, + "wave": bwd_wave, + "warp": bwd_warp + bwd_warp[:3], + }, + } + ) + + cvt_bm0 = max(bwd_tile[0], 64) + cvt_json = json.dumps( + { + "arch": arch, + "signature": {**base_sig, "family": "bwd_convert_dq"}, + "algorithm": { + **base_alg, + "tile": [cvt_bm0, 0, 0, 0, 0, dq], + "wave": [1, 1, 1, 1, 1, 1, 1, 1, 1], + "warp": [16, 16, 16, 16, 16, 16, 16, 16, 16], + }, + } + ) + + return [dot_json, dqdkdv_json, cvt_json] + + +# ========================================================================= +# Phase 1 -- JIT build (no GPU, pure hipcc subprocesses) +# ========================================================================= + + +def _jit_one(key: tuple, out_dir: Path, arch: str) -> Tuple[bool, str, float]: + """JIT-compile a single kernel config. Runs hipcc only, never touches GPU.""" + t0 = time.perf_counter() + out_dir.mkdir(parents=True, exist_ok=True) + + codegen_dir = DISPATCHER_DIR / "codegen" + ctypes_src = DISPATCHER_DIR / "bindings" / "ctypes" / "fmha_ctypes_lib.cpp" + static_lib = DISPATCHER_DIR / "build" / "libck_tile_dispatcher.a" + if not static_lib.exists(): + return (False, "libck_tile_dispatcher.a not found", time.perf_counter() - t0) + + hipcc = "hipcc" + cfg_json = config_to_codegen_json(key, arch) + + # 1. codegen + r = subprocess.run( + [ + sys.executable, + str(codegen_dir / "fmha" / "generate_fallback.py"), + "--output-dir", + str(out_dir), + "--gpu-target", + arch, + "--config-json", + cfg_json, + ], + capture_output=True, + text=True, + cwd=str(codegen_dir), + ) + if r.returncode != 0: + return (False, f"codegen: {r.stderr[:200]}", time.perf_counter() - t0) + + dispatch_hdr = out_dir / "fmha_python_dispatch.hpp" + if not dispatch_hdr.exists(): + return (False, "no dispatch header", time.perf_counter() - t0) + + sys.path.insert(0, str(PYTHON_DIR)) + from fmha_utils import fmha_compile_flags # noqa: E402 + + inc = [ + f"-I{out_dir}", + f"-I{out_dir / 'dispatcher_wrappers'}", + ] + # fmha_compile_flags provides hipcc + all standard flags; strip hipcc (element 0) + base_flags = fmha_compile_flags(arch, family="fwd")[1:] + + # 2. compile kernel .cpp files + kernel_objs = [] + for cpp in sorted(out_dir.glob("fmha_*.cpp")): + obj = cpp.with_suffix(".o") + r = subprocess.run( + [hipcc, "-c", *base_flags, *inc, str(cpp), "-o", str(obj)], + capture_output=True, + text=True, + ) + if r.returncode != 0: + return (False, f"kernel: {r.stderr[:200]}", time.perf_counter() - t0) + kernel_objs.append(str(obj)) + + # 3. compile ctypes lib + ctypes_obj = out_dir / "fmha_ctypes_lib.o" + r = subprocess.run( + [ + hipcc, + "-c", + *base_flags, + *inc, + f"-include{dispatch_hdr}", + f'-DGFX_ARCH="{arch}"', + str(ctypes_src), + "-o", + str(ctypes_obj), + ], + capture_output=True, + text=True, + ) + if r.returncode != 0: + return (False, f"ctypes: {r.stderr[:200]}", time.perf_counter() - t0) + + # 4. link .so + name = config_name(key) + so_path = out_dir / f"libdispatcher_fmha_{name}.so" + r = subprocess.run( + [ + hipcc, + "-shared", + "-fPIC", + str(ctypes_obj), + *kernel_objs, + str(static_lib), + "-lamdhip64", + "-o", + str(so_path), + ], + capture_output=True, + text=True, + ) + if r.returncode != 0: + return (False, f"link: {r.stderr[:200]}", time.perf_counter() - t0) + + return (True, str(so_path), time.perf_counter() - t0) + + +def _jit_one_bwd(key: tuple, out_dir: Path, arch: str) -> Tuple[bool, str, float]: + """JIT-compile all 3 bwd stages into one .so.""" + t0 = time.perf_counter() + out_dir.mkdir(parents=True, exist_ok=True) + + codegen_dir = DISPATCHER_DIR / "codegen" + ctypes_src = DISPATCHER_DIR / "bindings" / "ctypes" / "fmha_ctypes_lib.cpp" + static_lib = DISPATCHER_DIR / "build" / "libck_tile_dispatcher.a" + if not static_lib.exists(): + return (False, "libck_tile_dispatcher.a not found", time.perf_counter() - t0) + + hipcc = "hipcc" + jsons = bwd_codegen_jsons(key, arch) + + # 1. codegen all 3 stages into the same dir + for stage_json in jsons: + r = subprocess.run( + [ + sys.executable, + str(codegen_dir / "fmha" / "codegen.py"), + "--output-dir", + str(out_dir), + "--gpu-target", + arch, + "--config-json", + stage_json, + ], + capture_output=True, + text=True, + cwd=str(codegen_dir), + ) + if r.returncode != 0: + return (False, f"codegen: {r.stderr[:200]}", time.perf_counter() - t0) + + # 1b. generate dispatch header combining all wrappers + wrapper_dir = out_dir / "dispatcher_wrappers" + if not wrapper_dir.exists(): + return (False, "no wrappers dir", time.perf_counter() - t0) + + sys.path.insert(0, str(codegen_dir)) + sys.path.insert(0, str(codegen_dir / "fmha")) + from generate_fallback import generate_dispatch_header + + generate_dispatch_header(out_dir, wrapper_dir) + + dispatch_hdr = out_dir / "fmha_python_dispatch.hpp" + from fmha_utils import fmha_compile_flags # noqa: E402 + + inc = [ + f"-I{out_dir}", + f"-I{wrapper_dir}", + ] + base_flags = fmha_compile_flags(arch, family="bwd")[1:] + + # 2. compile all kernel .cpp files + kernel_objs = [] + for cpp in sorted(out_dir.glob("fmha_*.cpp")): + obj = cpp.with_suffix(".o") + r = subprocess.run( + [hipcc, "-c", *base_flags, *inc, str(cpp), "-o", str(obj)], + capture_output=True, + text=True, + ) + if r.returncode != 0: + return ( + False, + f"kernel({cpp.name}): {r.stderr[:200]}", + time.perf_counter() - t0, + ) + kernel_objs.append(str(obj)) + + # 3. compile ctypes lib + ctypes_obj = out_dir / "fmha_ctypes_lib.o" + r = subprocess.run( + [ + hipcc, + "-c", + *base_flags, + *inc, + f"-include{dispatch_hdr}", + f'-DGFX_ARCH="{arch}"', + str(ctypes_src), + "-o", + str(ctypes_obj), + ], + capture_output=True, + text=True, + ) + if r.returncode != 0: + return (False, f"ctypes: {r.stderr[:200]}", time.perf_counter() - t0) + + # 4. link .so + name = config_name(key) + so_path = out_dir / f"libdispatcher_fmha_bwd_{name}.so" + r = subprocess.run( + [ + hipcc, + "-shared", + "-fPIC", + str(ctypes_obj), + *kernel_objs, + str(static_lib), + "-lamdhip64", + "-o", + str(so_path), + ], + capture_output=True, + text=True, + ) + if r.returncode != 0: + return (False, f"link: {r.stderr[:200]}", time.perf_counter() - t0) + + return (True, str(so_path), time.perf_counter() - t0) + + +# ========================================================================= +# Phase 2 -- GPU tests (sequential, each in its own subprocess) +# ========================================================================= + + +def find_ck_exe() -> Optional[str]: + for p in [ + "/tmp/ck_fmha_full/bin/tile_example_fmha_fwd", + "/tmp/ck_fmha_build/bin/tile_example_fmha_fwd", + ]: + if os.path.exists(p): + return p + return None + + +def run_ck_test(exe: str, case: TestCase) -> Tuple[bool, str]: + cmd = [exe] + to_ck_cli_args(case) + try: + r = subprocess.run(cmd, capture_output=True, text=True, timeout=60) + return (r.returncode == 0, "") + except subprocess.TimeoutExpired: + return (False, "timeout") + except Exception as e: + return (False, str(e)[:60]) + + +MASK_INT = {"0": 0, "1": 1, "2": 2} +BIAS_INT = {"n": 0, "e": 1, "a": 2} + + +def run_dispatcher_test( + so_path: str, case: TestCase, key: tuple, arch: str +) -> Tuple[bool, str]: + """Run one test in an isolated subprocess -- never touches our process's HIP.""" + dq = case.hdim_q + dv = case.effective_hdim_v() + nk = case.effective_nhead_k() + traits_dq = key[1] # tile-rounded hdim for kernel matching + traits_dv = key[2] + + if case.seqlen_k <= 0 or case.seqlen_q <= 0: + return (True, "edge-case-ok") + + mi = MASK_INT.get(case.mask, 1 if case.mask.startswith(("t:", "b:")) else 0) + bi = BIAS_INT.get(case.bias, 0) + scale = 1.0 / (dq**0.5) + + # Build a tiny runner script executed in a fresh process + runner = f"""\ +import ctypes, numpy as np, sys +lib = ctypes.CDLL("{so_path}") +lib.fmha_dispatcher_initialize.argtypes = [ctypes.c_char_p] +lib.fmha_dispatcher_initialize.restype = ctypes.c_int +lib.fmha_dispatcher_run_fwd.argtypes = [ + ctypes.c_void_p,ctypes.c_void_p,ctypes.c_void_p,ctypes.c_void_p, + ctypes.c_int,ctypes.c_int,ctypes.c_int,ctypes.c_int,ctypes.c_int, + ctypes.c_int,ctypes.c_int,ctypes.c_float, + ctypes.c_int,ctypes.c_int,ctypes.c_int,ctypes.c_int, + ctypes.c_int,ctypes.c_int,ctypes.c_int, + ctypes.c_int, + ctypes.c_char_p,ctypes.c_int, + ctypes.c_int,ctypes.c_int, + ctypes.c_int,ctypes.c_int,ctypes.c_int, + ctypes.POINTER(ctypes.c_float)] +lib.fmha_dispatcher_run_fwd.restype = ctypes.c_int +lib.fmha_dispatcher_cleanup.argtypes = [] +lib.fmha_dispatcher_cleanup.restype = None +rc = lib.fmha_dispatcher_initialize(b"{arch}") +if rc != 0: print("INIT_FAIL"); sys.exit(1) +np.random.seed(42) +grp={case.mode} +perm={case.perm} +if grp: + Q=np.ascontiguousarray((np.random.randn({case.batch}*{case.seqlen_q},{case.nhead_q},{dq})*0.3).astype(np.float16)) + K=np.ascontiguousarray((np.random.randn({case.batch}*{case.seqlen_k},{nk},{dq})*0.3).astype(np.float16)) + V=np.ascontiguousarray((np.random.randn({case.batch}*{case.seqlen_k},{nk},{dv})*0.3).astype(np.float16)) + O=np.ascontiguousarray(np.zeros(({case.batch}*{case.seqlen_q},{case.nhead_q},{dv}),dtype=np.float16)) +elif perm==1: + Q=np.ascontiguousarray((np.random.randn({case.batch},{case.nhead_q},{case.seqlen_q},{dq})*0.3).astype(np.float16)) + K=np.ascontiguousarray((np.random.randn({case.batch},{nk},{case.seqlen_k},{dq})*0.3).astype(np.float16)) + V=np.ascontiguousarray((np.random.randn({case.batch},{nk},{case.seqlen_k},{dv})*0.3).astype(np.float16)) + O=np.ascontiguousarray(np.zeros(({case.batch},{case.nhead_q},{case.seqlen_q},{dv}),dtype=np.float16)) +else: + Q=np.ascontiguousarray((np.random.randn({case.batch},{case.seqlen_q},{case.nhead_q},{dq})*0.3).astype(np.float16)) + K=np.ascontiguousarray((np.random.randn({case.batch},{case.seqlen_k},{nk},{dq})*0.3).astype(np.float16)) + V=np.ascontiguousarray((np.random.randn({case.batch},{case.seqlen_k},{nk},{dv})*0.3).astype(np.float16)) + O=np.ascontiguousarray(np.zeros(({case.batch},{case.seqlen_q},{case.nhead_q},{dv}),dtype=np.float16)) +t=ctypes.c_float(0.0) +rc=lib.fmha_dispatcher_run_fwd(Q.ctypes.data,K.ctypes.data,V.ctypes.data,O.ctypes.data,\ +{case.batch},{case.nhead_q},{nk},{case.seqlen_q},{case.seqlen_k},{dq},{dv},\ +{scale},{mi},{bi},{case.lse},{int(case.p_drop > 0)},{traits_dq},{traits_dv},1,{case.perm},b"{case.prec}",{case.mode},\ +{-1 if mi == 0 else -1},{-1 if mi == 0 else 0},0,0,0,ctypes.byref(t)) +lib.fmha_dispatcher_cleanup() +if rc!=0: print(f"RC{{rc}}"); sys.exit(1) +nz=int(np.count_nonzero(O)) +if nz==0: print("ZEROS"); sys.exit(1) +print(f"OK {{t.value:.3f}}ms nz={{nz}}") +""" + try: + r = subprocess.run( + [sys.executable, "-c", runner], + capture_output=True, + text=True, + timeout=30, + env={**os.environ, "HIP_VISIBLE_DEVICES": "0"}, + ) + out = r.stdout.strip() + err = r.stderr.strip() + if r.returncode == 0 and out.startswith("OK"): + return (True, out) + msg = out + if err: + msg = msg + " ERR:" + err[:80] if msg else err[:120] + return (False, msg[:160]) + except subprocess.TimeoutExpired: + return (False, "timeout") + + +# ========================================================================= +# Main +# ========================================================================= + + +def _run_phase( + label: str, + cases, + configs, + jit_fn, + test_fn, + ck_exe, + ck_bwd_exe, + args, + jit_root, + is_bwd=False, +): + """Run JIT + test for a set of cases. Returns (jit_time, test_time, stats_dict).""" + case_key_map: Dict[int, tuple] = {} + for i, c in enumerate(cases): + case_key_map[i] = config_key(c) + + lib_for: Dict[tuple, Optional[str]] = {} + jit_stats = Counter() + jit_t0 = time.perf_counter() + + if not args.skip_jit: + print(f"\n--- {label} JIT ({len(configs)} cfgs, {args.workers} workers) ---") + futures = {} + with ThreadPoolExecutor(max_workers=args.workers) as pool: + for key in configs: + name = ("bwd_" if is_bwd else "") + config_name(key) + out = jit_root / name + futures[pool.submit(jit_fn, key, out, args.arch)] = (key, name, out) + done = 0 + for f in as_completed(futures): + key, name, out = futures[f] + ok, msg, elapsed = f.result() + done += 1 + if ok: + lib_for[key] = msg + jit_stats["ok"] += 1 + else: + lib_for[key] = None + jit_stats["fail"] += 1 + if done % max(1, len(configs) // 20) == 0 or done <= 3 or not ok: + tag = "OK" if ok else f"FAIL({msg[:50]})" + print(f" [{done}/{len(configs)}] {name} {elapsed:.1f}s {tag}") + else: + for key in configs: + name = ("bwd_" if is_bwd else "") + config_name(key) + out = jit_root / name + sos = sorted(out.glob("libdispatcher_fmha_*.so")) if out.exists() else [] + lib_for[key] = str(sos[0]) if sos else None + jit_stats["ok" if sos else "missing"] += 1 + + jit_elapsed = time.perf_counter() - jit_t0 + print(f" JIT done: {dict(jit_stats)} ({jit_elapsed:.0f}s)") + + ck_cnt = Counter() + disp_cnt = Counter() + par_cnt = Counter() + failures = [] + test_t0 = time.perf_counter() + exe = ck_bwd_exe if is_bwd else ck_exe + + print(f"\n--- {label} tests: {len(cases)} cases (sequential) ---") + for i, case in enumerate(cases): + if (i + 1) % 50 == 0 or i == 0: + el = time.perf_counter() - test_t0 + rate = (i + 1) / max(el, 0.01) + print(f" [{i + 1}/{len(cases)}] {el:.0f}s ({rate:.1f}/s)") + + ck_ok = run_ck_test(exe, case)[0] if exe else None + key = case_key_map.get(i) + so = lib_for.get(key) if key else None + if so: + d_ok, d_msg = test_fn(so, case, key, args.arch) + else: + d_ok, d_msg = None, "no-lib" + + ck_cnt["pass" if ck_ok else ("fail" if ck_ok is False else "skip")] += 1 + disp_cnt["pass" if d_ok else ("fail" if d_ok is False else "skip")] += 1 + if ck_ok is not None and d_ok is not None: + if ck_ok == d_ok: + par_cnt["match"] += 1 + else: + par_cnt["mismatch"] += 1 + failures.append( + dict( + idx=i, + dir=label, + ck=ck_ok, + disp=d_ok, + msg=d_msg, + hq=case.hdim_q, + hv=case.effective_hdim_v(), + mask=case.mask, + bias=case.bias, + nq=case.nhead_q, + nk=case.effective_nhead_k(), + sq=case.seqlen_q, + sk=case.seqlen_k, + ) + ) + else: + par_cnt["n/a"] += 1 + if d_ok is False: + dv = case.effective_hdim_v() + nk = case.effective_nhead_k() + print( + f" FAIL[{i}] h={case.hdim_q}x{dv} m={case.mask} b={case.bias}" + f" nq={case.nhead_q} nk={nk} -> {d_msg[:80]}" + ) + + test_elapsed = time.perf_counter() - test_t0 + return ( + jit_elapsed, + test_elapsed, + dict( + jit=dict(jit_stats), + ck=dict(ck_cnt), + dispatcher=dict(disp_cnt), + parity=dict(par_cnt), + failures=failures[:100], + ), + ) + + +def find_ck_bwd_exe() -> Optional[str]: + for p in [ + "/tmp/ck_fmha_full/bin/tile_example_fmha_bwd", + "/tmp/ck_fmha_build/bin/tile_example_fmha_bwd", + ]: + if os.path.exists(p): + return p + return None + + +def run_dispatcher_bwd_test( + so_path: str, case: TestCase, key: tuple, arch: str +) -> Tuple[bool, str]: + """Backward test stub -- validates kernel loads and produces nonzero grads.""" + if case.seqlen_k <= 0 or case.seqlen_q <= 0: + return (True, "edge-case-ok") + + # For now, just verify the bwd .so loads and initializes (kernel selection). + # Full GPU bwd execution requires run_bwd ABI updates matching fwd. + runner = f"""\ +import ctypes, sys +lib = ctypes.CDLL("{so_path}") +lib.fmha_dispatcher_initialize.argtypes = [ctypes.c_char_p] +lib.fmha_dispatcher_initialize.restype = ctypes.c_int +lib.fmha_dispatcher_kernel_count.argtypes = [] +lib.fmha_dispatcher_kernel_count.restype = ctypes.c_int +lib.fmha_dispatcher_cleanup.argtypes = [] +lib.fmha_dispatcher_cleanup.restype = None +rc = lib.fmha_dispatcher_initialize(b"{arch}") +if rc != 0: print("INIT_FAIL"); sys.exit(1) +n = lib.fmha_dispatcher_kernel_count() +lib.fmha_dispatcher_cleanup() +if n < 3: print(f"KERNELS={{n}}"); sys.exit(1) +print(f"OK kernels={{n}}") +""" + try: + r = subprocess.run( + [sys.executable, "-c", runner], + capture_output=True, + text=True, + timeout=15, + env={**os.environ, "HIP_VISIBLE_DEVICES": "0"}, + ) + out = r.stdout.strip() + err = r.stderr.strip() + if r.returncode == 0 and out.startswith("OK"): + return (True, out) + msg = out + if err: + msg = msg + " ERR:" + err[:80] if msg else err[:120] + return (False, msg[:160]) + except subprocess.TimeoutExpired: + return (False, "timeout") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--max-cases", type=int, default=0, help="0 = all") + parser.add_argument("--max-configs", type=int, default=0) + parser.add_argument("--workers", type=int, default=4) + parser.add_argument("--arch", default="gfx950") + parser.add_argument("--skip-jit", action="store_true") + parser.add_argument("--skip-ck", action="store_true") + parser.add_argument("--fwd-only", action="store_true") + parser.add_argument("--bwd-only", action="store_true") + parser.add_argument("--report", default="parity_report.json") + args = parser.parse_args() + + ck_exe = find_ck_exe() if not args.skip_ck else None + ck_bwd_exe = find_ck_bwd_exe() if not args.skip_ck else None + + print("=" * 80) + print("FMHA Full Parity Test (fwd + bwd)") + print("=" * 80) + print(f" CK fwd exe: {ck_exe or 'N/A'}") + print(f" CK bwd exe: {ck_bwd_exe or 'N/A'}") + print(f" GPU arch: {args.arch}") + print(f" JIT workers: {args.workers}") + + jit_root = Path("/tmp/fmha_parity_jit") + jit_root.mkdir(parents=True, exist_ok=True) + + all_results = {} + total_jit = 0.0 + total_test = 0.0 + + # ---- Forward ---- + if not args.bwd_only: + fwd_cases = generate_fwd_fp16_bf16_matrix() + if args.max_cases > 0: + fwd_cases = fwd_cases[: args.max_cases] + fwd_configs = {} + for c in fwd_cases: + k = config_key(c) + fwd_configs[k] = True + if args.max_configs > 0: + fwd_configs = dict(list(fwd_configs.items())[: args.max_configs]) + print(f"\n FWD: {len(fwd_cases)} cases, {len(fwd_configs)} configs") + + jt, tt, stats = _run_phase( + "FWD", + fwd_cases, + fwd_configs, + _jit_one, + run_dispatcher_test, + ck_exe, + ck_bwd_exe, + args, + jit_root, + ) + all_results["fwd"] = stats + total_jit += jt + total_test += tt + + # ---- Backward ---- + if not args.fwd_only: + bwd_cases = generate_bwd_matrix() + if args.max_cases > 0: + bwd_cases = bwd_cases[: args.max_cases] + bwd_configs = {} + for c in bwd_cases: + k = config_key(c) + bwd_configs[k] = True + if args.max_configs > 0: + bwd_configs = dict(list(bwd_configs.items())[: args.max_configs]) + print(f"\n BWD: {len(bwd_cases)} cases, {len(bwd_configs)} configs") + + jt, tt, stats = _run_phase( + "BWD", + bwd_cases, + bwd_configs, + _jit_one_bwd, + run_dispatcher_bwd_test, + ck_exe, + ck_bwd_exe, + args, + jit_root, + is_bwd=True, + ) + all_results["bwd"] = stats + total_jit += jt + total_test += tt + + # ---- Padding edge cases ---- + if not args.bwd_only: + pad_cases = generate_padding_matrix() + pad_configs = {} + for c in pad_cases: + k = config_key(c) + pad_configs[k] = True + print(f"\n PAD: {len(pad_cases)} cases, {len(pad_configs)} configs") + jt, tt, stats = _run_phase( + "PAD", + pad_cases, + pad_configs, + _jit_one, + run_dispatcher_test, + ck_exe, + ck_bwd_exe, + args, + jit_root, + ) + all_results["padding"] = stats + total_jit += jt + total_test += tt + + # ---- FP8 ---- + if not args.bwd_only: + fp8_cases = generate_fp8_matrix() + fp8_configs = {} + for c in fp8_cases: + k = config_key(c) + fp8_configs[k] = True + print(f"\n FP8: {len(fp8_cases)} cases, {len(fp8_configs)} configs") + jt, tt, stats = _run_phase( + "FP8", + fp8_cases, + fp8_configs, + _jit_one, + run_dispatcher_test, + ck_exe, + ck_bwd_exe, + args, + jit_root, + ) + all_results["fp8"] = stats + total_jit += jt + total_test += tt + + # ---- SplitKV ---- + if not args.bwd_only: + skv_cases = generate_splitkv_matrix() + if args.max_cases > 0: + skv_cases = skv_cases[: args.max_cases] + skv_configs = {} + for c in skv_cases: + k = config_key(c) + skv_configs[k] = True + print(f"\n SKV: {len(skv_cases)} cases, {len(skv_configs)} configs") + jt, tt, stats = _run_phase( + "SKV", + skv_cases, + skv_configs, + _jit_one, + run_dispatcher_test, + ck_exe, + ck_bwd_exe, + args, + jit_root, + ) + all_results["splitkv"] = stats + total_jit += jt + total_test += tt + + # ---- Report ---- + print(f"\n{'=' * 80}") + print("FMHA Full Parity Report") + print(f"{'=' * 80}") + print(f" JIT total: {total_jit:.0f}s") + print(f" Test total: {total_test:.0f}s") + for direction, stats in all_results.items(): + d = stats["dispatcher"] + p = stats["parity"] + print(f"\n [{direction.upper()}]") + print(f" CK: {stats['ck']}") + print( + f" Dispatcher: {d.get('pass', 0)} pass, {d.get('fail', 0)} fail," + f" {d.get('skip', 0)} skip" + ) + print( + f" Parity: {p.get('match', 0)} match, {p.get('mismatch', 0)} mismatch" + ) + if stats.get("failures"): + print(" Failures[0:5]:") + for f in stats["failures"][:5]: + print( + f" [{f['idx']}] ck={f['ck']} disp={f['disp']}" + f" h={f['hq']}x{f['hv']} -> {f['msg'][:50]}" + ) + print(f"{'=' * 80}") + + with open(args.report, "w") as fp: + json.dump( + dict(jit_time_s=total_jit, test_time_s=total_test, results=all_results), + fp, + indent=2, + ) + print(f"\nSaved {args.report}") + + total_fail = sum( + r["dispatcher"].get("fail", 0) + r["dispatcher"].get("skip", 0) + for r in all_results.values() + ) + return 1 if total_fail > 0 else 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/tests/smoke_test_fmha_dispatcher.sh b/dispatcher/tests/smoke_test_fmha_dispatcher.sh new file mode 100755 index 0000000000..442fb33d8c --- /dev/null +++ b/dispatcher/tests/smoke_test_fmha_dispatcher.sh @@ -0,0 +1,91 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT +# +# Dispatcher FMHA smoke test - mirrors the 01_fmha smoke_test_fwd.sh matrix. +# Run from the dispatcher build directory. + +set -euo pipefail + +SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) + +GPU_ARCH=${GPU_ARCH:-gfx950} +if [ -z "${GPU_ARCH}" ]; then + GPU_ARCH=$(rocminfo 2>/dev/null | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}' || echo "gfx950") +fi + +BUILD_DIR=${BUILD_DIR:-"${SCRIPT_DIR}/../build"} +PASS=0 +FAIL=0 +TOTAL=0 + +run_example() { + local name=$1 + shift + local exe="${BUILD_DIR}/examples/${name}" + + if [ ! -x "$exe" ]; then + echo "[SKIP] $name (not built)" + return + fi + + TOTAL=$((TOTAL + 1)) + if "$exe" --arch "$GPU_ARCH" "$@" >/dev/null 2>&1; then + echo "[PASS] $name $*" + PASS=$((PASS + 1)) + else + echo "[FAIL] $name $*" + FAIL=$((FAIL + 1)) + fi +} + +echo "=== FMHA Dispatcher Smoke Test ===" +echo "GPU_ARCH=$GPU_ARCH" +echo "BUILD_DIR=$BUILD_DIR" +echo "" + +echo "--- Basic FMHA ---" +run_example fmha_01_basic +run_example fmha_02_splitkv +run_example fmha_03_kvcache +run_example fmha_04_bwd +run_example fmha_05_appendkv +run_example fmha_06_batch_prefill + +echo "" +echo "--- Profile FMHA ---" +run_example fmha_07_profile_pytorch +run_example fmha_08_profile_flash +run_example fmha_09_profile_aiter +run_example fmha_10_profile_fp32_fp8 +run_example fmha_11_receipt_aliases +run_example fmha_12_registry_json + +echo "" +echo "--- Feature Coverage ---" +run_example fmha_13_feature_coverage + +echo "" +echo "--- GPU Execution (new) ---" +run_example fmha_14_benchmark_validation --verify +run_example fmha_15_multi_shape +run_example fmha_16_heuristics +run_example fmha_17_autofill_autocorrect +run_example fmha_18_gpu_splitkv +run_example fmha_19_gpu_masks +run_example fmha_20_gpu_bias +run_example fmha_21_gpu_features +run_example fmha_22_gpu_bwd +run_example fmha_23_multi_registry +run_example fmha_24_per_receipt_registries +run_example fmha_25_gpu_appendkv_prefill +run_example fmha_26_dtypes_hdims +run_example fmha_27_padding_permutation + +echo "" +echo "=== Results: $PASS passed, $FAIL failed, $TOTAL total ===" + +if [ $FAIL -gt 0 ]; then + exit 1 +fi +exit 0 diff --git a/dispatcher/tests/test_fmha_codegen.py b/dispatcher/tests/test_fmha_codegen.py new file mode 100644 index 0000000000..fd54686adb --- /dev/null +++ b/dispatcher/tests/test_fmha_codegen.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +import json +import subprocess +import sys +import tempfile +import unittest +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT / "codegen")) + +from fmha.validation import profile_allows # noqa: E402 +from fmha.validation import validate_config # noqa: E402 + +CODEGEN = ROOT / "codegen" / "fmha" / "codegen.py" + + +def sample_config(**overrides): + config = { + "arch": "gfx942", + "signature": { + "family": "fwd", + "data_type": "fp16", + "mode": "batch", + "vlayout": "r", + "hdim_q": 128, + "hdim_v": 128, + "mask": "no", + "bias": "no", + "lse": False, + "dropout": False, + "qscale": "no", + "rope": "none", + "logits": False, + "paged_kv": False, + "fp8_static_quant": False, + "skip_min_seqlen_q": False, + "sink": False, + "dbias": False, + "store_randval": False, + "deterministic": False, + "kv_memory_layout": "vectorized", + "kv_lookup_table": "sglang", + "page_size": 1, + }, + "algorithm": { + "pipeline": "qr_async", + "tile": [128, 128, 32, 128, 32, 128], + "wave": [2, 2, 1, 2, 2, 1, 1, 1, 1], + "warp": [32, 32, 16, 32, 32, 16, 16, 16, 16], + "padding": [True, True, True, True], + "use_trload": False, + "hdim_q_alignment": 128, + "hdim_v_alignment": 128, + "block_per_cu": 1, + "num_wave_groups": 1, + "max_splits_log2": 0, + "max_seq_len_q": 0, + "selection_rank": 0, + "constraint_tag": "", + }, + } + + for section, values in overrides.items(): + if isinstance(values, dict): + config[section].update(values) + else: + config[section] = values + return config + + +class TestFmhaCodegen(unittest.TestCase): + def test_forward_codegen_emits_kernel_and_wrapper(self): + with tempfile.TemporaryDirectory() as tmpdir: + cmd = [ + sys.executable, + str(CODEGEN), + "--output-dir", + tmpdir, + "--gpu-target", + "gfx942", + "--config-json", + json.dumps(sample_config()), + ] + result = subprocess.run( + cmd, capture_output=True, text=True, cwd=str(ROOT / "codegen") + ) + self.assertEqual(result.returncode, 0, msg=result.stderr or result.stdout) + + generated = list(Path(tmpdir).glob("fmha_*.hpp")) + wrappers = list( + (Path(tmpdir) / "dispatcher_wrappers").glob( + "dispatcher_wrapper_fmha_*.hpp" + ) + ) + self.assertEqual(len(generated), 1) + self.assertEqual(len(wrappers), 1) + + def test_profile_filter_rejects_pytorch_unsupported_config(self): + config = sample_config(signature={"bias": "alibi"}) + self.assertFalse(profile_allows(config, profile="pytorch")) + self.assertTrue(profile_allows(config, profile="flash_fwd")) + + def test_batch_prefill_validation_requires_row_major(self): + config = sample_config( + signature={ + "family": "batch_prefill", + "mode": "group", + "paged_kv": True, + "vlayout": "c", + "page_size": 16, + }, + algorithm={"pipeline": "qr_async"}, + ) + result = validate_config(config) + self.assertFalse(result.valid) + self.assertTrue(any("row-major" in error for error in result.errors)) + + def test_qr_async_hdim_128_requires_bn0_128(self): + config = sample_config( + algorithm={ + "pipeline": "qr_async", + "tile": [128, 64, 32, 128, 16, 128], + } + ) + result = validate_config(config) + # Constraint-based tile rules allow various bn0 values for h128 + self.assertTrue(result.valid) + + def test_splitkv_combine_requires_bn1_32(self): + config = sample_config( + signature={"family": "fwd_splitkv_combine", "lse": True}, + algorithm={ + "pipeline": "qr", + "tile": [64, 128, 32, 128, 32, 128], + "max_splits_log2": 6, + }, + ) + result = validate_config(config) + self.assertFalse(result.valid) + self.assertTrue(any("bn1" in error for error in result.errors)) + + def test_batch_prefill_requires_group_mode(self): + config = sample_config( + signature={ + "family": "batch_prefill", + "mode": "batch", + "paged_kv": True, + "page_size": 16, + }, + algorithm={"pipeline": "qr_async"}, + ) + result = validate_config(config) + self.assertFalse(result.valid) + self.assertTrue(any("group mode" in error for error in result.errors)) + + def test_receipt_aliases_match_profiles(self): + flash = sample_config(signature={"bias": "alibi"}) + pytorch = sample_config(signature={"bias": "bias"}) + aiter = sample_config() + + self.assertTrue(profile_allows(flash, receipt=2)) + self.assertTrue(profile_allows(pytorch, receipt=4)) + self.assertTrue(profile_allows(aiter, receipt=100)) + + +if __name__ == "__main__": + unittest.main() diff --git a/dispatcher/tests/test_fmha_dispatcher.cpp b/dispatcher/tests/test_fmha_dispatcher.cpp new file mode 100644 index 0000000000..c8e14c84df --- /dev/null +++ b/dispatcher/tests/test_fmha_dispatcher.cpp @@ -0,0 +1,491 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck_tile/dispatcher.hpp" + +using namespace ck_tile::dispatcher; + +namespace { + +class MockFmhaKernel : public FmhaKernelInstance +{ + public: + MockFmhaKernel(FmhaKernelKey key, std::string name) + : key_(std::move(key)), name_(std::move(name)) + { + } + + const FmhaKernelKey& get_key() const override { return key_; } + + bool supports(const FmhaProblem& problem) const override + { + return key_.signature.family == problem.requested_family && + key_.signature.data_type == problem.data_type && + problem.hdim_q <= key_.signature.hdim_q && problem.hdim_v <= key_.signature.hdim_v; + } + + std::string get_name() const override { return name_; } + + void launch(const FmhaInvocation&, const ck_tile::stream_config&) const override {} + + private: + FmhaKernelKey key_; + std::string name_; +}; + +FmhaKernelKey make_key(FmhaKernelFamily family, const std::string& name, int rank = 0) +{ + (void)name; + FmhaKernelKey key; + key.signature.family = family; + key.signature.data_type = "fp16"; + key.signature.is_group_mode = false; + key.signature.is_v_rowmajor = true; + key.signature.hdim_q = 128; + key.signature.hdim_v = 128; + key.algorithm.selection_rank = rank; + key.algorithm.tile_shape = {128, 128, 32, 128, 32, 128}; + key.algorithm.pad_s = true; + key.algorithm.pad_sk = true; + key.algorithm.pad_d = true; + key.algorithm.pad_dv = true; + return key; +} + +FmhaProblem make_splitkv_problem() +{ + fmha_fwd_splitkv_traits traits{}; + traits.hdim_q = 128; + traits.hdim_v = 128; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + traits.has_lse = true; + + fmha_fwd_splitkv_args args{}; + args.batch = 1; + args.seqlen_q = 128; + args.seqlen_k = 1024; + args.max_seqlen_q = 128; + args.hdim_q = 128; + args.hdim_v = 128; + args.nhead_q = 16; + args.nhead_k = 16; + args.num_splits = 8; + + return FmhaProblem::from_invocation(FmhaInvocation::make(traits, args), "gfx942"); +} + +FmhaProblem make_bwd_problem() +{ + fmha_bwd_traits traits{}; + traits.hdim_q = 128; + traits.hdim_v = 128; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + + fmha_bwd_args args{}; + args.batch = 1; + args.seqlen_q = 128; + args.seqlen_k = 128; + args.max_seqlen_q = 128; + args.max_seqlen_k = 128; + args.hdim_q = 128; + args.hdim_v = 128; + args.nhead_q = 16; + args.nhead_k = 16; + + return FmhaProblem::from_invocation(FmhaInvocation::make(traits, args), "gfx942"); +} + +} // namespace + +TEST(FmhaDispatcherTest, PlansSplitKvAsTwoStages) +{ + FmhaRegistry registry; + registry.register_kernel( + std::make_shared(make_key(FmhaKernelFamily::FwdSplitKv, "split"), "split")); + registry.register_kernel(std::make_shared( + make_key(FmhaKernelFamily::FwdSplitKvCombine, "combine"), "combine")); + + FmhaDispatcher dispatcher(®istry); + auto plan = dispatcher.plan(make_splitkv_problem()); + ASSERT_TRUE(plan.is_valid()); + ASSERT_EQ(plan.stages.size(), 2u); + EXPECT_EQ(plan.stages[0].family, FmhaKernelFamily::FwdSplitKv); + EXPECT_EQ(plan.stages[1].family, FmhaKernelFamily::FwdSplitKvCombine); +} + +TEST(FmhaDispatcherTest, PlansSingleStageFwd) +{ + FmhaRegistry registry; + registry.register_kernel( + std::make_shared(make_key(FmhaKernelFamily::Fwd, "fwd"), "fwd")); + + FmhaDispatcher dispatcher(®istry); + + fmha_fwd_traits traits{}; + traits.hdim_q = 128; + traits.hdim_v = 128; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + traits.has_lse = false; + traits.has_dropout = false; + traits.qscale_type = quant_scale_enum::no_scale; + + fmha_fwd_args args{}; + args.batch = 1; + args.seqlen_q = 128; + args.seqlen_k = 128; + args.max_seqlen_q = 128; + args.hdim_q = 128; + args.hdim_v = 128; + args.nhead_q = 16; + args.nhead_k = 16; + + auto problem = FmhaProblem::from_invocation(FmhaInvocation::make(traits, args), "gfx942"); + auto plan = dispatcher.plan(problem); + ASSERT_TRUE(plan.is_valid()); + ASSERT_EQ(plan.stages.size(), 1u); + EXPECT_EQ(plan.stages[0].family, FmhaKernelFamily::Fwd); +} + +TEST(FmhaDispatcherTest, PlansSingleStagePagedKv) +{ + FmhaRegistry registry; + registry.register_kernel(std::make_shared( + make_key(FmhaKernelFamily::FwdPagedKv, "pagedkv"), "pagedkv")); + + FmhaDispatcher dispatcher(®istry); + + fmha_fwd_pagedkv_traits traits{}; + traits.hdim_q = 128; + traits.hdim_v = 128; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + + fmha_fwd_pagedkv_args args{}; + args.batch = 1; + args.seqlen_q = 128; + args.seqlen_k = 128; + args.max_seqlen_q = 128; + args.hdim_q = 128; + args.hdim_v = 128; + args.nhead_q = 16; + args.nhead_k = 16; + + auto problem = FmhaProblem::from_invocation(FmhaInvocation::make(traits, args), "gfx942"); + auto plan = dispatcher.plan(problem); + ASSERT_TRUE(plan.is_valid()); + ASSERT_EQ(plan.stages.size(), 1u); + EXPECT_EQ(plan.stages[0].family, FmhaKernelFamily::FwdPagedKv); +} + +TEST(FmhaDispatcherTest, PlansSingleStageAppendKv) +{ + FmhaRegistry registry; + auto key = make_key(FmhaKernelFamily::FwdAppendKv, "appendkv"); + registry.register_kernel(std::make_shared(key, "appendkv")); + + FmhaDispatcher dispatcher(®istry); + + fmha_fwd_appendkv_traits traits{}; + traits.hdim_q = 128; + traits.hdim_v = 128; + traits.data_type = "fp16"; + traits.is_v_rowmajor = true; + traits.rope_type = rope_enum::none; + + fmha_fwd_appendkv_args args{}; + args.batch = 1; + args.seqlen_q = 128; + args.seqlen_knew = 64; + args.hdim_q = 128; + args.hdim_v = 128; + args.nhead_q = 16; + args.nhead_k = 16; + + auto problem = FmhaProblem::from_invocation(FmhaInvocation::make(traits, args), "gfx942"); + auto plan = dispatcher.plan(problem); + ASSERT_TRUE(plan.is_valid()); + ASSERT_EQ(plan.stages.size(), 1u); + EXPECT_EQ(plan.stages[0].family, FmhaKernelFamily::FwdAppendKv); +} + +TEST(FmhaDispatcherTest, SeqtunePrefersSmallerAlignedTile) +{ + FmhaRegistry registry; + + auto key_big = make_key(FmhaKernelFamily::Fwd, "big", /*rank=*/0); + key_big.algorithm.tile_shape.m0 = 128; + key_big.algorithm.pad_s = false; + auto key_small = make_key(FmhaKernelFamily::Fwd, "small", /*rank=*/0); + key_small.algorithm.tile_shape.m0 = 64; + key_small.algorithm.pad_s = false; + + registry.register_kernel(std::make_shared(key_big, "big")); + registry.register_kernel(std::make_shared(key_small, "small")); + + FmhaDispatcher dispatcher(®istry); + + fmha_fwd_traits traits{}; + traits.hdim_q = 128; + traits.hdim_v = 128; + traits.data_type = "fp16"; + traits.is_v_rowmajor = true; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + + fmha_fwd_args args{}; + args.batch = 1; + args.seqlen_q = 128; + args.seqlen_k = 128; + args.max_seqlen_q = 128; + args.hdim_q = 128; + args.hdim_v = 128; + args.nhead_q = 16; + args.nhead_k = 16; + + auto problem = FmhaProblem::from_invocation(FmhaInvocation::make(traits, args), "gfx942"); + auto selected = dispatcher.select_kernel(problem); + ASSERT_NE(selected, nullptr); + // Both tiles align to 128; seqtune prefers the smaller tile_m0 + EXPECT_EQ(selected->get_name(), "small"); +} + +TEST(FmhaDispatcherTest, PlansBackwardAsThreeStagesWhenConvertExists) +{ + FmhaRegistry registry; + registry.register_kernel( + std::make_shared(make_key(FmhaKernelFamily::BwdDotDoO, "dot"), "dot")); + registry.register_kernel( + std::make_shared(make_key(FmhaKernelFamily::BwdDqDkDv, "dq"), "dq")); + registry.register_kernel(std::make_shared( + make_key(FmhaKernelFamily::BwdConvertDq, "convert"), "convert")); + + FmhaDispatcher dispatcher(®istry); + auto plan = dispatcher.plan(make_bwd_problem()); + ASSERT_TRUE(plan.is_valid()); + ASSERT_EQ(plan.stages.size(), 3u); + EXPECT_EQ(plan.stages[0].family, FmhaKernelFamily::BwdDotDoO); + EXPECT_EQ(plan.stages[1].family, FmhaKernelFamily::BwdDqDkDv); + EXPECT_EQ(plan.stages[2].family, FmhaKernelFamily::BwdConvertDq); +} + +// #15: BWD with asymmetric head dimensions (hdim_q != hdim_v) +TEST(FmhaDispatcherTest, PlansBackwardWithAsymmetricHdim) +{ + FmhaRegistry registry; + registry.set_name("test_bwd_asym"); + + auto asym_key = [](FmhaKernelFamily family, const std::string& n) { + auto key = make_key(family, n); + key.signature.hdim_q = 96; + key.signature.hdim_v = 128; + return key; + }; + + registry.register_kernel( + std::make_shared(asym_key(FmhaKernelFamily::BwdDotDoO, "dot96"), "dot96")); + registry.register_kernel( + std::make_shared(asym_key(FmhaKernelFamily::BwdDqDkDv, "dq96"), "dq96")); + + FmhaDispatcher dispatcher(®istry); + auto problem = make_bwd_problem(); + problem.hdim_q = 96; + problem.hdim_v = 128; + auto plan = dispatcher.plan(problem); + ASSERT_TRUE(plan.is_valid()); + EXPECT_GE(plan.stages.size(), 2u); + EXPECT_EQ(plan.stages[0].family, FmhaKernelFamily::BwdDotDoO); + EXPECT_EQ(plan.stages[1].family, FmhaKernelFamily::BwdDqDkDv); +} + +// #16: BWD negative test -- no matching kernel returns invalid plan +TEST(FmhaDispatcherTest, PlansBackwardReturnsInvalidWhenNoKernel) +{ + FmhaRegistry registry; + registry.set_name("test_bwd_neg"); + + // Register only a fwd kernel, no bwd kernels + registry.register_kernel( + std::make_shared(make_key(FmhaKernelFamily::Fwd, "fwd"), "fwd")); + + FmhaDispatcher dispatcher(®istry); + auto plan = dispatcher.plan(make_bwd_problem()); + EXPECT_FALSE(plan.is_valid()); +} + +// #17: Canonical key distinguishes dropout seed differences +TEST(FmhaDispatcherTest, CanonicalKeyDistinguishesDropout) +{ + FmhaProblem p1; + p1.data_type = "fp16"; + p1.hdim_q = 128; + p1.hdim_v = 128; + p1.has_dropout = false; + + FmhaProblem p2 = p1; + p2.has_dropout = true; + + EXPECT_NE(p1.canonical_key(), p2.canonical_key()); +} + +// Canonical key covers all signature fields +TEST(FmhaDispatcherTest, CanonicalKeyCoversAllFields) +{ + FmhaProblem base; + base.data_type = "fp16"; + base.hdim_q = 128; + base.hdim_v = 128; + + auto check_differs = [&](auto mutator) { + FmhaProblem p = base; + mutator(p); + EXPECT_NE(base.canonical_key(), p.canonical_key()); + }; + + check_differs([](FmhaProblem& p) { p.has_lse = true; }); + check_differs([](FmhaProblem& p) { p.has_dropout = true; }); + check_differs([](FmhaProblem& p) { p.has_logits_soft_cap = true; }); + check_differs([](FmhaProblem& p) { p.has_sink = true; }); + check_differs([](FmhaProblem& p) { p.is_deterministic = true; }); + check_differs([](FmhaProblem& p) { p.has_dbias = true; }); + check_differs([](FmhaProblem& p) { p.is_store_randval = true; }); + check_differs([](FmhaProblem& p) { p.mask_type = 1; }); + check_differs([](FmhaProblem& p) { p.bias_type = 2; }); + check_differs([](FmhaProblem& p) { p.is_group_mode = true; }); +} + +// BWD workspace sizing +TEST(FmhaDispatcherTest, BwdWorkspaceInfoComputation) +{ + FmhaProblem p; + p.batch = 2; + p.nhead_q = 8; + p.seqlen_q = 256; + p.seqlen_k = 256; + p.hdim_q = 128; + + auto info = bwd_workspace_info(p); + EXPECT_EQ(info.d_bytes, 2u * 8 * 256 * sizeof(float)); + EXPECT_EQ(info.dq_acc_bytes, 2u * 8 * 256 * 128 * sizeof(float)); + EXPECT_EQ(info.d_offset, 0u); + EXPECT_GT(info.dq_acc_offset, 0u); + EXPECT_GE(info.dq_acc_offset, info.d_bytes); + EXPECT_EQ(info.dq_acc_offset % 256, 0u); + EXPECT_GT(info.total_bytes, info.dq_acc_offset + info.dq_acc_bytes - 1); +} + +// Benchmarking control +TEST(FmhaDispatcherTest, SetBenchmarkingControlsTimingFlag) +{ + FmhaRegistry registry; + FmhaDispatcher dispatcher(®istry); + + EXPECT_FALSE(dispatcher.benchmarking_enabled()); + dispatcher.set_benchmarking(true); + EXPECT_TRUE(dispatcher.benchmarking_enabled()); + dispatcher.set_benchmarking(false); + EXPECT_FALSE(dispatcher.benchmarking_enabled()); +} + +// Verify tie() covers all Signature and Algorithm fields. +// If a new field is added to FmhaKernelKey but not to tie(), +// two keys differing only in that field would compare equal (silent bug). +TEST(FmhaKernelKeyTest, TieCoversAllSignatureFields) +{ + FmhaKernelKey a{}; + a.signature.data_type = "fp16"; + a.gfx_arch = "gfx950"; + + auto flip = [&](auto mutator) { + FmhaKernelKey b = a; + mutator(b); + EXPECT_NE(a, b) << "tie() does not distinguish a Signature/Algorithm field"; + }; + + flip([](FmhaKernelKey& k) { k.signature.family = FmhaKernelFamily::BwdDqDkDv; }); + flip([](FmhaKernelKey& k) { k.signature.data_type = "bf16"; }); + flip([](FmhaKernelKey& k) { k.signature.is_group_mode = true; }); + flip([](FmhaKernelKey& k) { k.signature.is_v_rowmajor = false; }); + flip([](FmhaKernelKey& k) { k.signature.has_logits_soft_cap = true; }); + flip([](FmhaKernelKey& k) { k.signature.mask_type = 1; }); + flip([](FmhaKernelKey& k) { k.signature.bias_type = 1; }); + flip([](FmhaKernelKey& k) { k.signature.has_lse = true; }); + flip([](FmhaKernelKey& k) { k.signature.has_dropout = true; }); + flip([](FmhaKernelKey& k) { k.signature.qscale_type = 1; }); + flip([](FmhaKernelKey& k) { k.signature.rope_type = 1; }); + flip([](FmhaKernelKey& k) { k.signature.use_paged_kv = true; }); + flip([](FmhaKernelKey& k) { k.signature.do_fp8_static_quant = true; }); + flip([](FmhaKernelKey& k) { k.signature.skip_min_seqlen_q = true; }); + flip([](FmhaKernelKey& k) { k.signature.has_sink = true; }); + flip([](FmhaKernelKey& k) { k.signature.has_dbias = true; }); + flip([](FmhaKernelKey& k) { k.signature.is_store_randval = true; }); + flip([](FmhaKernelKey& k) { k.signature.is_deterministic = true; }); + flip([](FmhaKernelKey& k) { k.signature.kv_memory_layout = 1; }); + flip([](FmhaKernelKey& k) { k.signature.kv_lookup_table = 1; }); + flip([](FmhaKernelKey& k) { k.signature.page_size = 64; }); + flip([](FmhaKernelKey& k) { k.signature.hdim_q = 256; }); + flip([](FmhaKernelKey& k) { k.signature.hdim_v = 256; }); + flip([](FmhaKernelKey& k) { k.signature.receipt = 1; }); + + flip([](FmhaKernelKey& k) { k.algorithm.tile_shape.m0 = 64; }); + flip([](FmhaKernelKey& k) { k.algorithm.pipeline = "qr_async"; }); + flip([](FmhaKernelKey& k) { k.algorithm.pad_s = false; }); + flip([](FmhaKernelKey& k) { k.algorithm.selection_rank = 5; }); + flip([](FmhaKernelKey& k) { k.algorithm.constraint_tag = "special"; }); + flip([](FmhaKernelKey& k) { k.gfx_arch = "gfx942"; }); +} + +TEST(FmhaDispatcherTest, SelectKernelReturnsNullptrOnEmptyRegistry) +{ + FmhaRegistry registry; + FmhaDispatcher dispatcher(®istry); + + fmha_fwd_traits traits{}; + traits.hdim_q = 128; + traits.hdim_v = 128; + traits.data_type = "fp16"; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + + auto problem = + FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_fwd_args{}), "gfx950"); + auto selected = dispatcher.select_kernel(problem); + EXPECT_EQ(selected, nullptr); +} + +TEST(FmhaDispatcherTest, SelectKernelReturnsNullptrOnNoMatch) +{ + FmhaRegistry registry; + auto key = make_fwd_key(128, 128, "fp16", "gfx950"); + auto mock = std::make_shared(key, "fp16_h128"); + registry.register_kernel(mock); + + FmhaDispatcher dispatcher(®istry); + + fmha_fwd_traits traits{}; + traits.hdim_q = 256; + traits.hdim_v = 256; + traits.data_type = "bf16"; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + + auto problem = + FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_fwd_args{}), "gfx950"); + auto selected = dispatcher.select_kernel(problem); + EXPECT_EQ(selected, nullptr); +} diff --git a/dispatcher/tests/test_fmha_kernel_decl.cpp b/dispatcher/tests/test_fmha_kernel_decl.cpp new file mode 100644 index 0000000000..c66a7dfabd --- /dev/null +++ b/dispatcher/tests/test_fmha_kernel_decl.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include + +#include "ck_tile/dispatcher.hpp" + +using namespace ck_tile::dispatcher; + +DECL_FMHA_KERNEL_SET(decl_test_fmha_kernels, + .add(FmhaSignature() + .family("fwd") + .dtype("fp16") + .mode("batch") + .vlayout("r") + .hdim(128) + .mask("no") + .bias("no"), + FmhaAlgorithm().pipeline("qr_async").tile(128, 128, 32, 128, 32, 128), + "gfx942") + .add(FmhaSignature() + .family("bwd_dq_dk_dv") + .dtype("fp16") + .mode("batch") + .hdim(128) + .mask("no") + .bias("no"), + FmhaAlgorithm().pipeline("qr").tile(128, 128, 32, 128, 32, 128), + "gfx942")); + +int main() +{ + const auto& set = FmhaKernelSetRegistry::instance().get("decl_test_fmha_kernels"); + assert(set.size() == 2); + std::cout << "FMHA decl registry contains " << set.size() << " entries\n"; + return 0; +} diff --git a/dispatcher/tests/test_fmha_parity.py b/dispatcher/tests/test_fmha_parity.py new file mode 100644 index 0000000000..a128b588e4 --- /dev/null +++ b/dispatcher/tests/test_fmha_parity.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +FMHA Parity Test: Dispatcher vs CK Tile 01_fmha vs CPU Reference + +Runs the same test configurations through: + 1. CK Tile tile_example_fmha_fwd (gold standard, if available) + 2. Dispatcher fmha_01_basic (via C++ binary) + 3. Python CPU reference (numpy) + +Compares exit codes and reports parity. + +Usage: + python3 test_fmha_parity.py + python3 test_fmha_parity.py --ck-exe /tmp/ck_fmha_build/bin/tile_example_fmha_fwd +""" + +import sys +import subprocess +import argparse +import os +from pathlib import Path +from dataclasses import dataclass +from typing import Optional + +sys.path.insert(0, str(Path(__file__).parent.parent / "python")) +import numpy as np + +from fmha_utils import FmhaProblem, cpu_attention_fwd, detect_gpu_arch + + +@dataclass +class TestCase: + name: str + prec: str = "fp16" + mode: int = 0 + batch: int = 2 + nhead: int = 2 + nhead_k: int = -1 + hdim: int = 128 + hdim_v: int = -1 + seqlen_q: int = 128 + seqlen_k: int = 128 + bias: str = "n" + mask: str = "0" + lse: int = 0 + p_drop: float = 0.0 + + +PARITY_TESTS = [ + TestCase("basic_fp16"), + TestCase("basic_bf16", prec="bf16"), + TestCase("longer_seq", seqlen_q=256, seqlen_k=256), + TestCase("small_batch", batch=1, nhead=8, seqlen_q=64, seqlen_k=64), + TestCase("gqa_2_1", nhead=4, nhead_k=2), + TestCase("gqa_4_1", nhead=8, nhead_k=2), + TestCase("causal_top_left", mask="1"), + TestCase("causal_bottom_right", mask="2"), + TestCase("bias_elementwise", bias="e"), + TestCase("bias_alibi", bias="a"), + TestCase("with_lse", lse=1), + TestCase("big_batch", batch=4, nhead=8, seqlen_q=128, seqlen_k=128), + TestCase("asymmetric_seq", seqlen_q=64, seqlen_k=256), + TestCase("single_query", batch=1, nhead=4, seqlen_q=1, seqlen_k=128), +] + + +def find_ck_exe() -> Optional[str]: + for path in [ + "/tmp/ck_fmha_build/bin/tile_example_fmha_fwd", + "/workspace/rocm-libraries/projects/composablekernel/build/bin/tile_example_fmha_fwd", + ]: + if os.path.exists(path): + return path + return None + + +def find_dispatcher_exe() -> Optional[str]: + root = Path(__file__).parent.parent + for rel in ["build/examples/fmha_01_basic"]: + p = root / rel + if p.exists(): + return str(p) + return None + + +def run_ck_test(exe: str, tc: TestCase) -> bool: + nhead_k = tc.nhead_k if tc.nhead_k > 0 else tc.nhead + hdim_v = tc.hdim_v if tc.hdim_v > 0 else tc.hdim + cmd = [ + exe, + f"-prec={tc.prec}", + f"-mode={tc.mode}", + f"-b={tc.batch}", + f"-h={tc.nhead}", + f"-h_k={nhead_k}", + f"-d={tc.hdim}", + f"-d_v={hdim_v}", + f"-s={tc.seqlen_q}", + f"-s_k={tc.seqlen_k}", + f"-bias={tc.bias}", + f"-mask={tc.mask}", + f"-lse={tc.lse}", + f"-p_drop={tc.p_drop}", + "-v=1", + "-warmup=0", + "-repeat=1", + ] + try: + result = subprocess.run(cmd, capture_output=True, text=True, timeout=120) + return result.returncode == 0 + except (subprocess.TimeoutExpired, FileNotFoundError): + return False + + +def run_dispatcher_test(exe: str, tc: TestCase) -> bool: + cmd = [ + exe, + f"--arch={detect_gpu_arch()}", + f"--batch={tc.batch}", + f"--nhead={tc.nhead}", + f"--seqlen={tc.seqlen_q}", + f"--hdim={tc.hdim}", + "--validate", + ] + try: + result = subprocess.run(cmd, capture_output=True, text=True, timeout=120) + return result.returncode == 0 + except (subprocess.TimeoutExpired, FileNotFoundError): + return False + + +def run_cpu_test(tc: TestCase) -> bool: + nhead_k = tc.nhead_k if tc.nhead_k > 0 else tc.nhead + hdim_v = tc.hdim_v if tc.hdim_v > 0 else tc.hdim + prob = FmhaProblem( + batch=tc.batch, + nhead_q=tc.nhead, + nhead_k=nhead_k, + seqlen_q=tc.seqlen_q, + seqlen_k=tc.seqlen_k, + hdim_q=tc.hdim, + hdim_v=hdim_v, + ) + np.random.seed(42) + Q = (np.random.randn(*prob.q_shape()) * 0.5).astype(np.float32) + K = (np.random.randn(*prob.k_shape()) * 0.5).astype(np.float32) + V = (np.random.randn(*prob.v_shape()) * 0.5).astype(np.float32) + out = cpu_attention_fwd(Q, K, V, prob.scale) + return out.size > 0 and np.isfinite(out).all() + + +def main(): + parser = argparse.ArgumentParser(description="FMHA Parity Test") + parser.add_argument("--ck-exe", default=None, help="Path to tile_example_fmha_fwd") + parser.add_argument("--dispatcher-exe", default=None, help="Path to fmha_01_basic") + args = parser.parse_args() + + ck_exe = args.ck_exe or find_ck_exe() + disp_exe = args.dispatcher_exe or find_dispatcher_exe() + + print("=" * 80) + print("FMHA Parity Test: CK Tile vs Dispatcher vs CPU Reference") + print("=" * 80) + print(f" CK Tile exe: {ck_exe or 'NOT FOUND'}") + print(f" Dispatcher exe: {disp_exe or 'NOT FOUND'}") + print(f" Test cases: {len(PARITY_TESTS)}") + + header = f" {'#':<3} {'Name':<22} {'CK':>6} {'Disp':>6} {'CPU':>6} {'Parity':>8}" + print(f"\n{header}") + print(" " + "-" * 56) + + total_ck = 0 + total_disp = 0 + total_cpu = 0 + total_parity = 0 + + for i, tc in enumerate(PARITY_TESTS, 1): + ck_ok = run_ck_test(ck_exe, tc) if ck_exe else None + disp_ok = run_dispatcher_test(disp_exe, tc) if disp_exe else None + cpu_ok = run_cpu_test(tc) + + ck_str = "PASS" if ck_ok else ("FAIL" if ck_ok is False else "N/A") + disp_str = "PASS" if disp_ok else ("FAIL" if disp_ok is False else "N/A") + cpu_str = "PASS" if cpu_ok else "FAIL" + + parity = True + if ck_ok is not None and disp_ok is not None: + parity = ck_ok == disp_ok + parity_str = "MATCH" if parity else "DIFF" + + if ck_ok: + total_ck += 1 + if disp_ok: + total_disp += 1 + if cpu_ok: + total_cpu += 1 + if parity: + total_parity += 1 + + print( + f" {i:<3} {tc.name:<22} {ck_str:>6} {disp_str:>6} {cpu_str:>6} {parity_str:>8}" + ) + + print(f"\n{'=' * 80}") + print(f" CK Tile: {total_ck}/{len(PARITY_TESTS)} passed") + print(f" Dispatcher: {total_disp}/{len(PARITY_TESTS)} passed") + print(f" CPU Ref: {total_cpu}/{len(PARITY_TESTS)} passed") + print(f" Parity: {total_parity}/{len(PARITY_TESTS)} matching") + print(f"{'=' * 80}") + + return 0 if total_parity == len(PARITY_TESTS) else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/tests/test_fmha_problem.cpp b/dispatcher/tests/test_fmha_problem.cpp new file mode 100644 index 0000000000..deeeb9e5ef --- /dev/null +++ b/dispatcher/tests/test_fmha_problem.cpp @@ -0,0 +1,144 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck_tile/dispatcher.hpp" + +using namespace ck_tile::dispatcher; + +TEST(FmhaProblemTest, BuildsForwardProblemFromInvocation) +{ + fmha_fwd_traits traits{}; + traits.hdim_q = 128; + traits.hdim_v = 128; + traits.data_type = "fp16"; + traits.is_group_mode = false; + traits.is_v_rowmajor = true; + traits.has_logits_soft_cap = false; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + traits.has_lse = false; + traits.has_dropout = false; + traits.qscale_type = quant_scale_enum::no_scale; + + fmha_fwd_args args{}; + args.batch = 2; + args.seqlen_q = 128; + args.seqlen_k = 256; + args.max_seqlen_q = 128; + args.hdim_q = 128; + args.hdim_v = 128; + args.nhead_q = 16; + args.nhead_k = 8; + + auto problem = FmhaProblem::from_invocation(FmhaInvocation::make(traits, args), "gfx942"); + EXPECT_TRUE(problem.is_valid()); + EXPECT_EQ(problem.api_family, FmhaApiFamily::Fwd); + EXPECT_EQ(problem.requested_family, FmhaKernelFamily::Fwd); + EXPECT_EQ(problem.data_type, "fp16"); + EXPECT_EQ(problem.hdim_q, 128); + EXPECT_EQ(problem.hdim_v, 128); + EXPECT_EQ(problem.batch, 2); + EXPECT_EQ(problem.seqlen_q, 128); + EXPECT_EQ(problem.seqlen_k, 256); + EXPECT_EQ(problem.nhead_q, 16); + EXPECT_EQ(problem.nhead_k, 8); +} + +TEST(FmhaProblemTest, BuilderCreatesValidProblem) +{ + auto problem = FmhaProblemBuilder() + .api_family(FmhaApiFamily::Fwd) + .kernel_family(FmhaKernelFamily::Fwd) + .gfx_arch("gfx950") + .data_type("fp16") + .dims(128, 128, 2, 256, 512) + .nheads(16, 4) + .mask_type(static_cast(mask_enum::mask_bottom_right)) + .bias_type(static_cast(bias_enum::elementwise_bias)) + .lse(true) + .dropout(false) + .v_rowmajor(true) + .group_mode(false) + .window(128, 0) + .build(); + + EXPECT_TRUE(problem.is_valid()); + EXPECT_EQ(problem.gfx_arch, "gfx950"); + EXPECT_EQ(problem.data_type, "fp16"); + EXPECT_EQ(problem.nhead_q, 16); + EXPECT_EQ(problem.nhead_k, 4); + EXPECT_EQ(problem.mask_type, static_cast(mask_enum::mask_bottom_right)); + EXPECT_EQ(problem.bias_type, static_cast(bias_enum::elementwise_bias)); + EXPECT_TRUE(problem.has_lse); + EXPECT_EQ(problem.window_size_left, 128); +} + +TEST(FmhaProblemTest, NumOpsIsNonZero) +{ + auto problem = FmhaProblemBuilder() + .api_family(FmhaApiFamily::Fwd) + .kernel_family(FmhaKernelFamily::Fwd) + .data_type("fp16") + .dims(128, 128, 2, 256, 512) + .nheads(16, 16) + .build(); + + EXPECT_GT(problem.num_ops(), 0); + // 2*batch*nhead*(sq*sk*dq + sq*sk*dv) = 2*2*16*(256*512*128 + 256*512*128) + std::int64_t expected = 2LL * 2 * 16 * 256 * 512 * (128 + 128); + EXPECT_EQ(problem.num_ops(), expected); +} + +TEST(FmhaProblemTest, ToStringContainsKeyFields) +{ + auto problem = FmhaProblemBuilder() + .api_family(FmhaApiFamily::Fwd) + .data_type("bf16") + .dims(64, 64, 1, 32, 32) + .nheads(8, 8) + .gfx_arch("gfx950") + .build(); + + auto s = problem.to_string(); + EXPECT_NE(s.find("bf16"), std::string::npos); + EXPECT_NE(s.find("gfx950"), std::string::npos); + EXPECT_NE(s.find("fwd"), std::string::npos); +} + +TEST(FmhaProblemTest, TracksSplitKvAndPagedKvFlags) +{ + fmha_fwd_splitkv_traits traits{}; + traits.hdim_q = 128; + traits.hdim_v = 128; + traits.data_type = "fp16"; + traits.is_group_mode = true; + traits.is_v_rowmajor = true; + traits.has_logits_soft_cap = false; + traits.mask_type = mask_enum::no_mask; + traits.bias_type = bias_enum::no_bias; + traits.has_lse = true; + traits.do_fp8_static_quant = false; + + fmha_fwd_splitkv_args args{}; + args.batch = 1; + args.seqlen_q = 64; + args.seqlen_k = 1024; + args.max_seqlen_q = 64; + args.hdim_q = 128; + args.hdim_v = 128; + args.nhead_q = 16; + args.nhead_k = 16; + args.num_splits = 4; + args.block_table_ptr = reinterpret_cast(0x1); + args.page_block_size = 16; + + auto problem = FmhaProblem::from_invocation(FmhaInvocation::make(traits, args), "gfx942"); + EXPECT_TRUE(problem.is_valid()); + EXPECT_EQ(problem.api_family, FmhaApiFamily::FwdSplitKv); + EXPECT_TRUE(problem.use_paged_kv); + EXPECT_TRUE(problem.has_block_table_ptr); + EXPECT_EQ(problem.num_splits, 4); + EXPECT_EQ(problem.page_size, 16); +} diff --git a/dispatcher/tests/test_fmha_registry.cpp b/dispatcher/tests/test_fmha_registry.cpp new file mode 100644 index 0000000000..975dbe7ab6 --- /dev/null +++ b/dispatcher/tests/test_fmha_registry.cpp @@ -0,0 +1,124 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck_tile/dispatcher.hpp" + +using namespace ck_tile::dispatcher; + +namespace { + +class StubFmhaKernel : public FmhaKernelInstance +{ + public: + StubFmhaKernel(FmhaKernelKey key, std::string name) + : key_(std::move(key)), name_(std::move(name)) + { + } + + const FmhaKernelKey& get_key() const override { return key_; } + bool supports(const FmhaProblem& problem) const override + { + return key_.signature.family == problem.requested_family && + key_.signature.data_type == problem.data_type; + } + std::string get_name() const override { return name_; } + void launch(const FmhaInvocation&, const ck_tile::stream_config&) const override {} + + private: + FmhaKernelKey key_; + std::string name_; +}; + +FmhaKernelKey +make_stub_key(FmhaKernelFamily family, const std::string& dtype, const std::string& arch) +{ + FmhaKernelKey key; + key.signature.family = family; + key.signature.data_type = dtype; + key.signature.hdim_q = 128; + key.signature.hdim_v = 128; + key.gfx_arch = arch; + key.algorithm.tile_shape = {128, 128, 32, 128, 32, 128}; + key.algorithm.pad_s = true; + key.algorithm.pad_sk = true; + key.algorithm.pad_d = true; + key.algorithm.pad_dv = true; + return key; +} + +} // namespace + +TEST(FmhaRegistryTest, RegisterAndLookup) +{ + FmhaRegistry reg; + auto key = make_stub_key(FmhaKernelFamily::Fwd, "fp16", "gfx950"); + auto kernel = std::make_shared(key, "test_fwd_fp16"); + EXPECT_TRUE(reg.register_kernel(kernel)); + EXPECT_EQ(reg.size(), 1u); + auto found = reg.lookup(key); + ASSERT_NE(found, nullptr); + EXPECT_EQ(found->get_name(), "test_fwd_fp16"); +} + +TEST(FmhaRegistryTest, GetAllReturnsSorted) +{ + FmhaRegistry reg; + auto key_a = make_stub_key(FmhaKernelFamily::Fwd, "fp16", "gfx950"); + key_a.algorithm.selection_rank = 1; + auto key_b = make_stub_key(FmhaKernelFamily::BwdDqDkDv, "fp16", "gfx950"); + key_b.algorithm.selection_rank = 0; + + reg.register_kernel(std::make_shared(key_a, "rank1")); + reg.register_kernel(std::make_shared(key_b, "rank0")); + + auto all = reg.get_all(); + ASSERT_EQ(all.size(), 2u); + EXPECT_EQ(all[0]->get_name(), "rank0"); + EXPECT_EQ(all[1]->get_name(), "rank1"); +} + +TEST(FmhaRegistryTest, FilterByArch) +{ + FmhaRegistry reg; + reg.register_kernel(std::make_shared( + make_stub_key(FmhaKernelFamily::Fwd, "fp16", "gfx950"), "k950")); + reg.register_kernel(std::make_shared( + make_stub_key(FmhaKernelFamily::Fwd, "fp16", "gfx942"), "k942")); + EXPECT_EQ(reg.size(), 2u); + + auto removed = reg.filter_by_arch("gfx950"); + EXPECT_EQ(removed, 1u); + EXPECT_EQ(reg.size(), 1u); + EXPECT_NE(reg.lookup(make_stub_key(FmhaKernelFamily::Fwd, "fp16", "gfx950")), nullptr); +} + +TEST(FmhaRegistryTest, FilterByPredicate) +{ + FmhaRegistry reg; + reg.register_kernel(std::make_shared( + make_stub_key(FmhaKernelFamily::Fwd, "fp16", "gfx950"), "fwd_fp16")); + reg.register_kernel(std::make_shared( + make_stub_key(FmhaKernelFamily::Fwd, "bf16", "gfx950"), "fwd_bf16")); + reg.register_kernel(std::make_shared( + make_stub_key(FmhaKernelFamily::BwdDqDkDv, "fp16", "gfx950"), "bwd_fp16")); + + auto fwd_only = reg.filter([](const FmhaKernelInstance& k) { + return k.get_key().signature.family == FmhaKernelFamily::Fwd; + }); + EXPECT_EQ(fwd_only.size(), 2u); +} + +TEST(FmhaRegistryTest, ExportJsonContainsMetadata) +{ + FmhaRegistry reg; + reg.set_name("test_registry"); + reg.register_kernel(std::make_shared( + make_stub_key(FmhaKernelFamily::Fwd, "fp16", "gfx950"), "fwd_fp16")); + + auto json = reg.export_json(); + EXPECT_NE(json.find("test_registry"), std::string::npos); + EXPECT_NE(json.find("total_kernels"), std::string::npos); + EXPECT_NE(json.find("fwd_fp16"), std::string::npos); +} diff --git a/dispatcher/tests/test_fmha_rules.py b/dispatcher/tests/test_fmha_rules.py new file mode 100644 index 0000000000..b2bcd99c09 --- /dev/null +++ b/dispatcher/tests/test_fmha_rules.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +import sys +import os +import unittest + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "codegen")) + +from fmha.validation import validate_config, load_arch_specs + +SPECS = load_arch_specs() + + +def _base_config( + family="fwd", + dtype="fp16", + arch="gfx950", + pipeline="qr_async", + hdim_q=128, + hdim_v=128, + **sig_overrides, +): + sig = { + "family": family, + "data_type": dtype, + "mode": "batch", + "vlayout": "r", + "hdim_q": hdim_q, + "hdim_v": hdim_v, + "mask": "no", + "bias": "no", + "lse": False, + "dropout": False, + "qscale": "no", + "rope": "none", + "logits": False, + "paged_kv": False, + "fp8_static_quant": False, + "skip_min_seqlen_q": False, + "sink": False, + "dbias": False, + "store_randval": False, + "deterministic": False, + "kv_memory_layout": "vectorized", + "kv_lookup_table": "sglang", + "page_size": 1, + } + sig.update(sig_overrides) + alg = { + "pipeline": pipeline, + "tile": [128, 128, 32, 128, 32, 128], + "wave": [4, 1, 1, 4, 1, 1, 1, 1, 1], + "warp": [32, 32, 16, 32, 32, 16, 16, 16, 16], + "padding": [True, True, True, True], + "block_per_cu": 1, + "num_wave_groups": 1, + "max_splits_log2": 0, + "max_seq_len_q": 0, + } + return {"signature": sig, "algorithm": alg, "arch": arch} + + +class TestValidateConfig(unittest.TestCase): + def test_valid_basic_config(self): + r = validate_config(_base_config(), SPECS) + self.assertTrue(r.valid, r.errors) + + def test_unsupported_arch(self): + r = validate_config(_base_config(arch="gfx000"), SPECS) + self.assertFalse(r.valid) + self.assertTrue(any("architecture" in e for e in r.errors)) + + def test_v3_hdim128_valid(self): + r = validate_config(_base_config(pipeline="v3", hdim_q=128, hdim_v=128), SPECS) + self.assertTrue(r.valid, r.errors) + + def test_hdim_not_multiple_of_8(self): + r = validate_config(_base_config(hdim_q=65, hdim_v=128), SPECS) + self.assertFalse(r.valid) + self.assertTrue(any("multiples of 8" in e for e in r.errors)) + + def test_bias_plus_logits_soft_cap(self): + r = validate_config(_base_config(bias="bias", logits=True), SPECS) + self.assertFalse(r.valid) + self.assertTrue(any("logits_soft_cap" in e for e in r.errors)) + + def test_hdim_192_128_with_bias(self): + r = validate_config(_base_config(hdim_q=192, hdim_v=128, bias="bias"), SPECS) + has_issue = any("(192,128)" in e for e in r.errors) or any( + "(192,128)" in w for w in r.warnings + ) + self.assertTrue(has_issue) + + def test_hdim_192_128_with_dropout(self): + r = validate_config(_base_config(hdim_q=192, hdim_v=128, dropout=True), SPECS) + has_issue = any("(192,128)" in e for e in r.errors) or any( + "(192,128)" in w for w in r.warnings + ) + self.assertTrue(has_issue) + + def test_appendkv_must_use_appendkv_pipeline(self): + cfg = _base_config(family="fwd_appendkv", pipeline="qr_async") + r = validate_config(cfg, SPECS) + self.assertFalse(r.valid) + self.assertTrue(any("appendkv pipeline" in e for e in r.errors)) + + def test_pagedkv_requires_qr_pagedkv_pipeline(self): + cfg = _base_config(family="fwd_pagedkv", pipeline="qr_async", paged_kv=True) + r = validate_config(cfg, SPECS) + self.assertFalse(r.valid) + self.assertTrue(any("qr_pagedkv" in e for e in r.errors)) + + def test_batch_prefill_requires_group_mode(self): + cfg = _base_config( + family="batch_prefill", + pipeline="qr_async", + mode="batch", + paged_kv=True, + page_size=64, + ) + cfg["signature"]["mode"] = "batch" + r = validate_config(cfg, SPECS) + self.assertFalse(r.valid) + self.assertTrue(any("group mode" in e for e in r.errors)) + + def test_batch_prefill_valid_group(self): + cfg = _base_config( + family="batch_prefill", pipeline="qr_async", paged_kv=True, page_size=64 + ) + cfg["signature"]["mode"] = "group" + r = validate_config(cfg, SPECS) + self.assertTrue(r.valid, r.errors) + + def test_splitkv_combine_bn1_must_be_32(self): + cfg = _base_config(family="fwd_splitkv_combine", pipeline="qr") + cfg["algorithm"]["tile"][3] = 64 + r = validate_config(cfg, SPECS) + self.assertFalse(r.valid) + self.assertTrue(any("bn1" in e for e in r.errors)) + + def test_bwd_dot_do_o_bm0_128_accepted(self): + cfg = _base_config(family="bwd_dot_do_o", pipeline="qr") + cfg["algorithm"]["tile"][0] = 128 + r = validate_config(cfg, SPECS) + # bwd_dot_do_o with bm0=128 is now valid (relaxed from strict bm0=64) + self.assertTrue(r.valid, r.errors) + + def test_mask_types_all_valid(self): + for mask in ["no", "top_left", "bottom_right", "generic"]: + r = validate_config(_base_config(mask=mask), SPECS) + self.assertTrue(r.valid, f"mask={mask}: {r.errors}") + + +class TestMaskDistinction(unittest.TestCase): + """Verify that top_left and bottom_right are distinct after fix.""" + + def test_mask_canonical_distinguishes(self): + from fmha.symbol_map import canonical_mask, MASK_TO_INT + + self.assertEqual(canonical_mask("top_left"), "top_left") + self.assertEqual(canonical_mask("bottom_right"), "bottom_right") + self.assertNotEqual(MASK_TO_INT["top_left"], MASK_TO_INT["bottom_right"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tile_engine/CMakeLists.txt b/tile_engine/CMakeLists.txt index b713587346..6f4598ad0f 100644 --- a/tile_engine/CMakeLists.txt +++ b/tile_engine/CMakeLists.txt @@ -6,6 +6,7 @@ include_directories(BEFORE ${CMAKE_CURRENT_LIST_DIR}/ops ) +add_subdirectory(ops/fmha EXCLUDE_FROM_ALL) add_subdirectory(ops/gemm EXCLUDE_FROM_ALL) add_subdirectory(ops/gemm_streamk EXCLUDE_FROM_ALL) add_subdirectory(ops/pooling EXCLUDE_FROM_ALL) diff --git a/tile_engine/operation_support_matrix.md b/tile_engine/operation_support_matrix.md index fe852dd1c0..697c829bd3 100644 --- a/tile_engine/operation_support_matrix.md +++ b/tile_engine/operation_support_matrix.md @@ -16,7 +16,7 @@ | GEMM | grouped_gemm_quant | | ❌ | | ❌ | | | | ❌ | | | | ❌ | ❌ | ❌ | ❌ | | Reduce | multi_reduce2d [8]
engine: reduce/
example: 05_reduce/ | ✅ | | ❌ | | | | | | | | | ❌ | ✅ | ✅ | ❌ | | Reduce | reduce2d
example: 05_reduce/ | ❌ | | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ | -| Attention | fmha
example: 01_fmha/ | ❌ | ❌ | ❌ | ❌ | | | | | | | | ❌ | ❌ | ❌ | ❌ | +| Attention | fmha
engine: fmha/
example: 01_fmha/ | ✅ | ✅ | ✅ | ❌ | | | | | | | | ✅ | ✅ | ✅ | ❌ | | Attention | sparse_attn
example: 50_sparse_attn/ | ❌ | | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ | | Activation | softmax | ❌ | | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ | | Activation | topk_softmax
example: 09_topk_softmax/ | ❌ | | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ | diff --git a/tile_engine/ops/common/parallel_runner.py b/tile_engine/ops/common/parallel_runner.py new file mode 100644 index 0000000000..e4ead184ac --- /dev/null +++ b/tile_engine/ops/common/parallel_runner.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +"""Generic multi-GPU parallel job runner for tile engine benchmarks. + +Op-agnostic: takes opaque jobs, distributes them across GPUs with one +job per GPU at a time, and yields results in completion order. Used by +fmha_benchmark.py and reusable for gemm/reduce/pooling benchmarks. +""" + +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Any, Callable, Iterator, List, Optional, Tuple + + +def run_parallel_on_gpus( + jobs: List[Any], + gpu_ids: List[int], + run_one: Callable[[Any, int], Any], + max_workers: Optional[int] = None, +) -> Iterator[Tuple[int, Any]]: + """Dispatch jobs across GPUs, one job per GPU at a time. + + Args: + jobs: Opaque job objects passed to run_one. + gpu_ids: GPU IDs to use (e.g. [0,1,2,3]). At most one job per GPU runs concurrently. + run_one: Callable run_one(job, gpu_id) -> result. Caller is responsible + for any subprocess isolation, environment setup, etc. + max_workers: Thread pool size. Defaults to len(gpu_ids). + + Yields: + (job_index, result) tuples in completion order. Caller can sort by + job_index to restore submission order if needed. + """ + if not jobs: + return + if max_workers is None: + max_workers = len(gpu_ids) + + # One job per GPU at a time + gpu_semas = {gid: threading.Semaphore(1) for gid in gpu_ids} + cycle = [0] + cycle_lock = threading.Lock() + + def _pick_gpu() -> int: + with cycle_lock: + gid = gpu_ids[cycle[0] % len(gpu_ids)] + cycle[0] += 1 + return gid + + def _wrapper(job_idx: int, job: Any) -> Tuple[int, Any]: + gid = _pick_gpu() + gpu_semas[gid].acquire() + try: + return job_idx, run_one(job, gid) + finally: + gpu_semas[gid].release() + + with ThreadPoolExecutor(max_workers=max_workers) as pool: + futures = [pool.submit(_wrapper, i, j) for i, j in enumerate(jobs)] + for fut in as_completed(futures): + yield fut.result() diff --git a/tile_engine/ops/fmha/.gitignore b/tile_engine/ops/fmha/.gitignore new file mode 100644 index 0000000000..8974bbf780 --- /dev/null +++ b/tile_engine/ops/fmha/.gitignore @@ -0,0 +1,3 @@ +*.log +build/ +*_build*/ diff --git a/tile_engine/ops/fmha/CMakeLists.txt b/tile_engine/ops/fmha/CMakeLists.txt new file mode 100644 index 0000000000..b064fea0b9 --- /dev/null +++ b/tile_engine/ops/fmha/CMakeLists.txt @@ -0,0 +1,94 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# FMHA Tile Engine -- Pure Python benchmarking via the CK dispatcher. +# No C++ per-kernel targets; all compilation is JIT via the dispatcher. + +set(FMHA_TE_DIR ${CMAKE_CURRENT_SOURCE_DIR}) +set(FMHA_TE_CONFIGS ${FMHA_TE_DIR}/configs) + +include(ProcessorCount) +ProcessorCount(NPROC) +if(NPROC EQUAL 0) + set(NPROC 8) +endif() + +# Use first arch from SUPPORTED_GPU_TARGETS, or fallback to gfx950 +set(FMHA_BENCH_ARCH "gfx950") +if(SUPPORTED_GPU_TARGETS) + list(GET SUPPORTED_GPU_TARGETS 0 FMHA_BENCH_ARCH) +endif() + +# Main benchmark target (runs forward sweep by default) +add_custom_target(benchmark_fmha + COMMAND ${Python3_EXECUTABLE} ${FMHA_TE_DIR}/fmha_benchmark.py + ${FMHA_TE_CONFIGS}/fwd.json + --arch ${FMHA_BENCH_ARCH} + --workers ${NPROC} + --best + --json ${CMAKE_CURRENT_BINARY_DIR}/fmha_fwd_results.json + WORKING_DIRECTORY ${FMHA_TE_DIR} + COMMENT "FMHA tile engine benchmark (forward)" +) + +if(TARGET ck_tile_dispatcher) + add_dependencies(benchmark_fmha ck_tile_dispatcher) +endif() + +# Per-variant convenience targets +foreach(variant fwd bwd splitkv appendkv pagedkv batch_prefill) + if(EXISTS ${FMHA_TE_CONFIGS}/${variant}.json) + add_custom_target(benchmark_fmha_${variant} + COMMAND ${Python3_EXECUTABLE} ${FMHA_TE_DIR}/fmha_benchmark.py + ${FMHA_TE_CONFIGS}/${variant}.json + --arch ${FMHA_BENCH_ARCH} + --workers ${NPROC} + --best + --json ${CMAKE_CURRENT_BINARY_DIR}/fmha_${variant}_results.json + WORKING_DIRECTORY ${FMHA_TE_DIR} + COMMENT "FMHA tile engine benchmark (${variant})" + ) + if(TARGET ck_tile_dispatcher) + add_dependencies(benchmark_fmha_${variant} ck_tile_dispatcher) + endif() + endif() +endforeach() + +# CI target (minimal sweep for quick validation) +if(EXISTS ${FMHA_TE_CONFIGS}/fwd_ci.json) + add_custom_target(benchmark_fmha_ci + COMMAND ${Python3_EXECUTABLE} ${FMHA_TE_DIR}/fmha_benchmark.py + ${FMHA_TE_CONFIGS}/fwd_ci.json + --arch ${FMHA_BENCH_ARCH} + --workers 8 + --verify + WORKING_DIRECTORY ${FMHA_TE_DIR} + COMMENT "FMHA tile engine CI benchmark" + ) + if(TARGET ck_tile_dispatcher) + add_dependencies(benchmark_fmha_ci ck_tile_dispatcher) + endif() +endif() + +# All-variants target +set(FMHA_ALL_CONFIGS "") +foreach(cfg fwd bwd splitkv appendkv pagedkv batch_prefill) + if(EXISTS ${FMHA_TE_CONFIGS}/${cfg}.json) + list(APPEND FMHA_ALL_CONFIGS ${FMHA_TE_CONFIGS}/${cfg}.json) + endif() +endforeach() + +add_custom_target(benchmark_fmha_all + COMMAND ${Python3_EXECUTABLE} ${FMHA_TE_DIR}/fmha_benchmark.py + ${FMHA_ALL_CONFIGS} + --arch ${FMHA_BENCH_ARCH} + --workers ${NPROC} + --best + --json ${CMAKE_CURRENT_BINARY_DIR}/fmha_all_results.json + WORKING_DIRECTORY ${FMHA_TE_DIR} + COMMENT "FMHA tile engine benchmark (all variants)" +) + +if(TARGET ck_tile_dispatcher) + add_dependencies(benchmark_fmha_all ck_tile_dispatcher) +endif() diff --git a/tile_engine/ops/fmha/README.md b/tile_engine/ops/fmha/README.md new file mode 100644 index 0000000000..881b2b2ef8 --- /dev/null +++ b/tile_engine/ops/fmha/README.md @@ -0,0 +1,192 @@ +# FMHA Tile Engine + +Benchmarking and kernel enumeration for Fused Multi-Head Attention (FMHA) via the CK dispatcher's pipelined JIT compilation. + +Covers all 9 FMHA kernel families: Forward, Split-KV (main + combine), Paged-KV, Append-KV, Batch Prefill, and Backward (dot\_do\_o, dq\_dk\_dv, convert\_dq) -- totaling 33,541 unique kernel specializations on gfx950. + +## Directory Layout + +``` +fmha/ + fmha_instance_builder.py Kernel enumeration from JSON config + pipeline rules + fmha_benchmark.py Single-config JIT compile and GPU benchmark + fmha_full_benchmark.py Full sweep: compile all kernels, benchmark across test shapes + ck_fmha_testing_matrix.yaml Test shapes (smoke / full / nightly) + CMakeLists.txt CMake targets + README.md This file + configs/ Sweep definitions (JSON) + receipt0_fwd.json Full receipt-0 forward: ~12K kernels + fwd.json Forward variants + fwd_ci.json Minimal CI subset + bwd.json Backward variants + splitkv.json Split-KV + appendkv.json Append-KV + pagedkv.json Paged-KV + batch_prefill.json Batch prefill + filters/ Sample Python filter scripts + h128_no_dropout.py Keep only h128 without dropout +``` + +## Quick Start + +```bash +# Count kernels without compiling +python fmha_instance_builder.py configs/receipt0_fwd.json --count-only + +# Minimal CI build + run (~16 kernels, <1 min) +python fmha_benchmark.py configs/fwd_ci.json --workers 128 --verify + +# Full forward receipt-0 compile-only (12K kernels, ~10 min with 256 workers) +python fmha_benchmark.py configs/receipt0_fwd.json --workers 256 --compile-only + +# Full sweep: compile every fwd kernel, benchmark against all smoke shapes +python fmha_full_benchmark.py --category smoke --variant fwd --workers 256 + +# Quick end-to-end test (2 kernels, 1 shape) +python fmha_full_benchmark.py --category smoke --variant fwd --max-kernels 2 --workers 4 +``` + +## How It Works + +### Kernel Enumeration + +``` +JSON config (variant + trait_config allow-list) + --> fmha_instance_builder.py + --> fmha_pipeline_rules.py (self-contained CK parity logic) + --> fmha_arch_specs.json (tile tables per arch / dtype / hdim) + --> list of FmhaKernelConfig (33,541 total on gfx950) + --> optional --filter / --filter-file +``` + +The pipeline rules in `dispatcher/codegen/fmha_pipeline_rules.py` reproduce the exact kernel enumeration from CK Tile's `01_fmha/codegen/`, including per-arch tile constraints, pipeline selection, padding variants, and feature products. Parity is verified by `dispatcher/tests/validate_arch_specs_parity.py`. + +### Benchmark Tools + +**`fmha_benchmark.py`** -- single-config benchmark. Input: one JSON config (kernel definitions). JIT-compiles all matching kernels, runs each on a given problem size, reports per-kernel timing and optional CPU validation. Optionally writes `--csv` output. + +**`fmha_full_benchmark.py`** -- full sweep benchmark. Input: `ck_fmha_testing_matrix.yaml` (test shapes) + JSON configs (kernel definitions). Compiles all kernel variants for selected families, then iterates over test shapes, matching each shape to compatible compiled kernels and benchmarking every match. Writes `--csv` and `--json` output. + +### JIT Compilation Pipeline + +Both tools use the dispatcher's `setup_multiple_fmha_dispatchers()` which implements a 3-stage pipelined build: + +1. **Codegen** (parallel) -- generate C++ kernel specializations and ctypes wrappers +2. **Compile** (parallel) -- `hipcc` compile each kernel and ctypes lib +3. **Link + Load** (parallel) -- produce `.so` libraries, load via ctypes + +With 256 workers, throughput is roughly 5-10 kernels/sec depending on kernel complexity. + +## JSON Config Format + +Each config specifies a `variant` and an optional `trait_config` that acts as an allow-list filter: + +```json +{ + "variant": "fwd", + "trait_config": { + "data_type": {"values": ["fp16", "bf16"]}, + "pipeline": {"values": ["qr_async"]}, + "mode": {"values": ["batch"]}, + "mask": {"values": ["no"]}, + "bias": {"values": ["no"]}, + "lse": {"values": [false]}, + "dropout": {"values": [false]}, + "logits": {"values": [false]}, + "sink": {"values": [false]} + } +} +``` + +If a trait key is absent, all values pass. The `receipt0_fwd.json` config only restricts `data_type` to exclude fp32, giving the full ~12K forward kernel set. + +## Filtering + +### CLI expression + +```bash +python fmha_benchmark.py configs/receipt0_fwd.json \ + --filter "c.hdim_q == 128 and c.pipeline == 'qr_async'" + +python fmha_full_benchmark.py --variant fwd \ + --filter "c.hdim_q == 128 and c.hdim_v == 128 and c.data_type == 'fp16'" +``` + +The expression accesses `c` (an `FmhaKernelConfig` dataclass) with fields: `data_type`, `mode`, `hdim_q`, `hdim_v`, `pipeline`, `tile_m0`, `tile_n0`, `tile_k0`, `pad_s`, `pad_sk`, `pad_d`, `pad_dv`, `mask`, `bias`, `lse`, `dropout`, `logits`, `sink`, `skip_min_seqlen_q`, `qscale`, `paged_kv`, `rope`, `deterministic`, `dbias`, `dropout_variant`. + +### Python file filter + +```bash +python fmha_benchmark.py configs/receipt0_fwd.json --filter-file filters/h128_no_dropout.py +``` + +The file must define `filter_config(c) -> bool`. Both `--filter` and `--filter-file` combine with AND logic. + +## Test Shape Matrix + +`ck_fmha_testing_matrix.yaml` defines test problems in three tiers: + +| Category | Purpose | Shapes | +|----------|---------|--------| +| `smoke` | Pre-submit sanity, <5 min | ~365 | +| `full` | Post-submit validation | smoke + ~1,500 | +| `nightly`| Exhaustive sweep | all | + +Shapes cover representative configurations: GQA ratios, asymmetric head dims, non-power-of-2 sequences, FP8 variants, long sequences, and cross-attention patterns. + +## Output Format + +### CSV + +``` +problem_name,batch,seqlen_q,seqlen_k,nhead_q,nhead_k,hdim_q,hdim_v,dtype, +kernel,family,mode,pipeline,tile_m0,tile_n0,tile_k0,..., +latency_ms,tflops,bandwidth_gb_s +``` + +Every column needed to fully reconstruct the kernel identity is included. TFLOPS and latency come directly from CK's internal HIP event timing. + +### JSON + +```json +{ + "metadata": { + "arch": "gfx950", + "category": "smoke", + "total_kernels": 600, + "shapes_benchmarked": 42, + "total_measurements": 12600 + }, + "results": [...] +} +``` + +## CMake Targets + +```bash +make benchmark_fmha # Forward sweep +make benchmark_fmha_ci # Quick CI validation +make benchmark_fmha_bwd # Backward sweep +make benchmark_fmha_all # All variants +make benchmark_fmha_splitkv # Split-KV only +``` + +## Parity Verification + +```bash +python dispatcher/tests/validate_arch_specs_parity.py --arch gfx950 --receipt 0 +# PASS: 33,541 kernels across all 9 families +``` + +This confirms the dispatcher's self-contained enumeration exactly matches CK Tile's upstream codegen. + +## Example: Single-Shape All-Kernel Benchmark + +Run every compiled fwd fp16 h128 kernel against one shape: + +```bash +python fmha_full_benchmark.py \ + --category smoke --variant fwd --workers 256 \ + --filter "c.hdim_q == 128 and c.hdim_v == 128 and c.data_type == 'fp16'" \ + --csv results.csv +``` diff --git a/tile_engine/ops/fmha/ck_fmha_testing_matrix.yaml b/tile_engine/ops/fmha/ck_fmha_testing_matrix.yaml new file mode 100644 index 0000000000..a97a4bb59a --- /dev/null +++ b/tile_engine/ops/fmha/ck_fmha_testing_matrix.yaml @@ -0,0 +1,788 @@ +test_categories: + Smoke: + description: "Pre-submit sanity checks. Fast execution, covering basic functionality and edge cases." + test_patterns: + - "*/Smoke.*" + labels: ["Smoke"] + + Full: + description: "Post-submit validation. Comprehensive coverage of modern LLM architectures and CK operational constraints." + test_patterns: + - "*/Smoke.*" + - "*/Full.*" + labels: ["Full"] + + Nightly: + description: "Nightly exhaustive coverage. Sweeps all combinations of precision, layout, masking, and padding." + test_patterns: + - "*" + labels: ["Nightly"] + +execution_settings: + default_timeout: 60 + category_timeouts: + Smoke: 60 # 1 min per test + Full: 300 # 5 min per test + Nightly: 600 # 10 min per test + +# ============================================================================= +# Forward Pass (Prefill) & Stochastic Execution (Dropout) +# ============================================================================= +forward_tests: + # --------------------------------------------------------------------------- + # Smoke Tests (Fast, representative subset) + # --------------------------------------------------------------------------- + smoke: + - name: "GQA_4to1_Prefill_Basic" + description: "Baseline GQA prefill; primary optimization target." + batch: [1, 4] + seqlen_q: [2048] + seqlen_k: [2048] + nhead_q: [32] + nhead_k: [8] + hdim_q: [128] + hdim_v: [128] + dtype: ["fp16", "bf16"] + layout: ["BHSD"] + mask: ["no_mask", "top_left"] + bias: ["none"] + dropout: [0.0] + lse: [false, true] + + - name: "Small_GQA_7to1_SubWarp" + description: "Sub-warp vectorized loads; low LDS utilization bounds." + batch: [1] + seqlen_q: [1024] + seqlen_k: [1024] + nhead_q: [14] + nhead_k: [2] + hdim_q: [64] + hdim_v: [64] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["no_mask"] + bias: ["none"] + dropout: [0.0] + lse: [false] + + - name: "MHA_H96_Irregular_Dim" + description: "Non-power-of-2 hdim; forces complex padding/striding in LDS." + batch: [2] + seqlen_q: [1024] + seqlen_k: [1024] + nhead_q: [32] + nhead_k: [32] + hdim_q: [96] + hdim_v: [96] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["top_left"] + bias: ["none"] + dropout: [0.0] + lse: [false] + + # CK smoke test edge cases (from example/ck_tile/01_fmha/script/smoke_test_fwd.sh) + - name: "CK_Asymmetric_Hdim_Small" + description: "Asymmetric hdim_q != hdim_v; tests vectorized load widths." + batch: [2] + seqlen_q: [55] + seqlen_k: [256] + nhead_q: [2] + nhead_k: [1] + hdim_q: [16] + hdim_v: [32, 64, 128] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["no_mask"] + bias: ["none"] + dropout: [0.0] + lse: [false] + + - name: "CK_Tiny_Sequences" + description: "Edge cases: sq=1, sq=3, very short sequences." + batch: [1, 2] + seqlen_q: [1, 3, 33] + seqlen_k: [10, 99, 33] + nhead_q: [2] + nhead_k: [1] + hdim_q: [128] + hdim_v: [128] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["no_mask", "top_left"] + bias: ["none"] + dropout: [0.0] + lse: [false] + + - name: "CK_Asymmetric_Seqlen" + description: "Asymmetric seqlen_q != seqlen_k from CK smoke tests." + batch: [1, 2] + seqlen_q: [100, 99, 1024] + seqlen_k: [51, 256, 256] + nhead_q: [3] + nhead_k: [3] + hdim_q: [64, 128] + hdim_v: [64, 128] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["no_mask", "top_left"] + bias: ["none"] + dropout: [0.0] + lse: [false] + + # Hdim sweep covering all supported (hdim_q, hdim_v) pairs. + # YAML cartesian product creates some orphan combos (hdim_q != hdim_v pairs + # without kernels). The benchmark silently skips these. Use --validate to list them. + # Supported pairs: h32, h64, h80x96, h96, h96x128, h128, h160, h192x128, h192, h256 + - name: "CK_All_Hdim_Sweep" + description: "Sweep all supported hdim combos. Orphan pairs are skipped at runtime." + batch: [2] + seqlen_q: [1024] + seqlen_k: [1024] + nhead_q: [8] + nhead_k: [4] + hdim_q: [32, 64, 80, 96, 128, 160, 192, 256] + hdim_v: [32, 64, 96, 128, 160, 192, 256] + dtype: ["fp16", "bf16"] + layout: ["BHSD"] + mask: ["no_mask"] + bias: ["none"] + dropout: [0.0] + lse: [false] + + - name: "CK_FP8_Basic" + description: "FP8 basic forward test." + batch: [1, 2] + seqlen_q: [128] + seqlen_k: [128] + nhead_q: [1] + nhead_k: [1] + hdim_q: [64, 128, 192, 256] + hdim_v: [64, 128, 128, 256] + dtype: ["fp8bf16", "fp8fp32"] + layout: ["BHSD"] + mask: ["no_mask"] + bias: ["none"] + dropout: [0.0] + lse: [false] + + # Production model configs (from aiter model_shapes.json) + - name: "GQA_16to1_Large" + description: "16:1 GQA ratio (70B-class models)." + batch: [1, 4] + seqlen_q: [2048] + seqlen_k: [2048] + nhead_q: [64] + nhead_k: [4] + hdim_q: [128] + hdim_v: [128] + dtype: ["fp16", "bf16"] + layout: ["BHSD"] + mask: ["no_mask", "top_left"] + bias: ["none"] + dropout: [0.0] + lse: [false] + + - name: "MQA_128to8_Decode" + description: "405B-class decode: 128 Q heads, 8 KV heads, single token query." + batch: [1, 8, 64] + seqlen_q: [1] + seqlen_k: [1024, 4096] + nhead_q: [128] + nhead_k: [8] + hdim_q: [128] + hdim_v: [128] + dtype: ["fp16", "bf16"] + layout: ["BHSD"] + mask: ["no_mask"] + bias: ["none"] + dropout: [0.0] + lse: [false] + + - name: "MLA_Sparse_Decode" + description: "Multi-latent attention decode (R1-class): asymmetric h192x128." + batch: [1, 4] + seqlen_q: [1] + seqlen_k: [1024, 4096] + nhead_q: [128] + nhead_k: [128] + hdim_q: [192] + hdim_v: [128] + dtype: ["fp16", "bf16"] + layout: ["BHSD"] + mask: ["no_mask"] + bias: ["none"] + dropout: [0.0] + lse: [false] + + - name: "Vision_Transformer_Shapes" + description: "Vision-text hybrid (Maverick-class): h88 and h128 mixed." + batch: [1, 4] + seqlen_q: [256, 1024] + seqlen_k: [256, 1024] + nhead_q: [16, 40] + nhead_k: [8, 16] + hdim_q: [88, 128] + hdim_v: [88, 128] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["no_mask"] + bias: ["none"] + dropout: [0.0] + lse: [false] + + - name: "FP8_Varlen_Realistic" + description: "FP8 with realistic GQA and variable lengths (from aiter tests)." + batch: [1, 8] + seqlen_q: [113, 256, 1024] + seqlen_k: [203, 512, 1024] + nhead_q: [8, 32, 40] + nhead_k: [1, 8] + hdim_q: [128] + hdim_v: [128] + dtype: ["fp8bf16"] + layout: ["BHSD"] + mask: ["no_mask"] + bias: ["none"] + dropout: [0.0] + lse: [false] + + - name: "Extreme_GQA_Ratios" + description: "Extreme GQA: 5:1, 10:1, 24:4, 48:8 from aiter test suite." + batch: [2] + seqlen_q: [1024] + seqlen_k: [1024] + nhead_q: [5, 10, 24, 48] + nhead_k: [1, 1, 4, 8] + hdim_q: [128] + hdim_v: [128] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["no_mask", "top_left"] + bias: ["none"] + dropout: [0.0] + lse: [false] + + - name: "Paged_Decode_Shapes" + description: "Paged attention decode patterns: single-token Q, long KV context." + batch: [4, 80, 128] + seqlen_q: [1, 4] + seqlen_k: [512, 4096] + nhead_q: [8, 16, 64] + nhead_k: [1, 4] + hdim_q: [128] + hdim_v: [128] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["no_mask"] + bias: ["none"] + dropout: [0.0] + lse: [false] + + - name: "Prefill_Odd_Lengths" + description: "Prefill with non-standard seq lengths from aiter test suite." + batch: [2] + seqlen_q: [113, 339, 799, 1023, 3131] + seqlen_k: [203, 339, 799, 1024, 3131] + nhead_q: [32] + nhead_k: [8] + hdim_q: [128] + hdim_v: [128] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["top_left"] + bias: ["none"] + dropout: [0.0] + lse: [false] + + # --------------------------------------------------------------------------- + # Full Tests (Modern LLM Architectures & CK Constraints) + # --------------------------------------------------------------------------- + full: + - name: "MHA_H256_High_LDS_Pressure" + description: "High LDS pressure; tests block partitioner limits with hdim=256." + batch: [1, 4] + seqlen_q: [4096] + seqlen_k: [4096] + nhead_q: [8] + nhead_k: [4] + hdim_q: [256] + hdim_v: [256] + dtype: ["bf16"] + layout: ["BHSD", "BSHD"] + mask: ["no_mask", "top_left"] + bias: ["none"] + dropout: [0.0] + lse: [true] + + - name: "MQA_64to1_Broadcast" + description: "Pure MQA; tests extreme KV to Q broadcast logic (64:1)." + batch: [2] + seqlen_q: [4096] + seqlen_k: [4096] + nhead_q: [64] + nhead_k: [1] + hdim_q: [128] + hdim_v: [128] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["top_left"] + bias: ["none"] + dropout: [0.0] + lse: [false] + + - name: "GQA_6to1_Irregular" + description: "Irregular 6:1 GQA ratio; tests tile distribution." + batch: [2] + seqlen_q: [4096] + seqlen_k: [4096] + nhead_q: [48] + nhead_k: [8] + hdim_q: [128] + hdim_v: [128] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["no_mask"] + bias: ["none"] + dropout: [0.0] + lse: [false] + + - name: "MLA_H128xH576_Asymmetric" + description: "Multi-latent attention fusion; asymmetric Q/KV (128 vs 576)." + batch: [1, 4] + seqlen_q: [4096] + seqlen_k: [4096] + nhead_q: [128] + nhead_k: [128] + hdim_q: [128] + hdim_v: [576] + dtype: ["fp16", "bf16"] + layout: ["BHSD"] + mask: ["top_left"] + bias: ["none"] + dropout: [0.0] + lse: [true] + + - name: "Asymmetric_Head_Dims_192_128" + description: "Test asymmetric head dimensions (192x128)." + batch: [2] + seqlen_q: [2048] + seqlen_k: [2048] + nhead_q: [16] + nhead_k: [16] + hdim_q: [192] + hdim_v: [128] + dtype: ["fp16", "bf16"] + layout: ["BHSD", "BSHD"] + mask: ["no_mask", "top_left"] + bias: ["none"] + dropout: [0.0] + lse: [false] + + - name: "Asymmetric_Head_Dims_128_192" + description: "Test asymmetric head dimensions (128x192)." + batch: [2] + seqlen_q: [2048] + seqlen_k: [2048] + nhead_q: [16] + nhead_k: [16] + hdim_q: [128] + hdim_v: [192] + dtype: ["fp16", "bf16"] + layout: ["BHSD"] + mask: ["top_left"] + bias: ["none"] + dropout: [0.0] + lse: [false] + + - name: "Diverse_Head_Dims_Sweep" + description: "Sweep across various head dimensions to ensure broad coverage." + batch: [2] + seqlen_q: [1024] + seqlen_k: [1024] + nhead_q: [16] + nhead_k: [16] + hdim_q: [48, 64, 72, 96, 128, 160, 256] + hdim_v: [48, 64, 72, 96, 128, 160, 256] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["no_mask"] + bias: ["none"] + dropout: [0.0] + lse: [false] + + - name: "Stochastic_Execution_Dropout_Sweep" + description: "PRNG state synchronization and warp alignment with stochastic masking across dims." + batch: [4] + seqlen_q: [1024] + seqlen_k: [1024] + nhead_q: [16] + nhead_k: [8] + hdim_q: [48, 64, 72, 96, 128, 160, 256] + hdim_v: [48, 64, 72, 96, 128, 160, 256] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["top_left"] + bias: ["none"] + dropout: [0.1, 0.2] + lse: [false, true] + + - name: "Padding_Boundary_Stress_Odd_Lengths" + description: "Test sequences that are not perfect multiples of the tile size to validate padding logic." + batch: [2] + seqlen_q: [259, 500, 987, 1023] + seqlen_k: [259, 500, 987, 1023] + nhead_q: [16] + nhead_k: [16] + hdim_q: [128] + hdim_v: [128] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["no_mask", "top_left"] + bias: ["none"] + dropout: [0.0] + lse: [false] + + - name: "Bias_Variants_Sweep" + description: "Test elementwise and alibi bias across different sequence lengths and batch sizes." + batch: [1, 4] + seqlen_q: [512, 1024] + seqlen_k: [512, 1024] + nhead_q: [16] + nhead_k: [16] + hdim_q: [64, 128] + hdim_v: [64, 128] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["no_mask", "top_left"] + bias: ["elementwise", "alibi"] + dropout: [0.0] + lse: [false] + + - name: "Extreme_Batch_Size_Stress" + description: "Test very large batch sizes to stress grid launch dimensions and scheduling." + batch: [64, 128, 256] + seqlen_q: [128] + seqlen_k: [128] + nhead_q: [8] + nhead_k: [8] + hdim_q: [64] + hdim_v: [64] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["no_mask"] + bias: ["none"] + dropout: [0.0] + lse: [false] + + - name: "Long_Sequence_Stress" + description: "Test very long sequences (approaching split-KV territory but forced dense)." + batch: [1] + seqlen_q: [8192, 16384] + seqlen_k: [8192, 16384] + nhead_q: [16] + nhead_k: [4] + hdim_q: [128] + hdim_v: [128] + dtype: ["bf16"] + layout: ["BHSD"] + mask: ["top_left"] + bias: ["none"] + dropout: [0.0] + lse: [true] + + - name: "Cross_Attention_Shapes" + description: "Test shapes typical of cross-attention where seqlen_q != seqlen_k." + batch: [2] + seqlen_q: [1, 32, 128] + seqlen_k: [1024, 4096] + nhead_q: [16] + nhead_k: [16] + hdim_q: [128] + hdim_v: [128] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["no_mask"] + bias: ["none"] + dropout: [0.0] + + - name: "CK_Benchmark_Standard" + description: "Standard CK benchmark sweep (from benchmark_fwd.sh)." + batch: [32, 16, 8, 4, 2, 1] + seqlen_q: [512, 1024, 2048, 4096, 8192, 16384] + seqlen_k: [512, 1024, 2048, 4096, 8192, 16384] + nhead_q: [32, 16, 8] + nhead_k: [32, 16, 8] + hdim_q: [64, 128, 256] + hdim_v: [64, 128, 256] + dtype: ["fp16", "bf16"] + layout: ["BHSD"] + mask: ["no_mask"] + bias: ["none"] + dropout: [0.0] + + - name: "CK_Benchmark_V3_Large" + description: "V3 pipeline benchmark with very long sequences (from benchmark_fwd_v3.sh)." + batch: [1] + seqlen_q: [16384, 37200, 65536] + seqlen_k: [16384, 37200, 65536] + nhead_q: [16, 40, 64] + nhead_k: [1, 16, 40, 64] + hdim_q: [128] + hdim_v: [128] + dtype: ["fp16", "bf16"] + layout: ["BHSD"] + mask: ["no_mask", "top_left"] + bias: ["none"] + dropout: [0.0] + lse: [false] + +# ============================================================================= +# Backward Pass (Gradient Computation) +# ============================================================================= +backward_tests: + # --------------------------------------------------------------------------- + # Smoke Tests + # --------------------------------------------------------------------------- + smoke: + - name: "Bwd_Basic_No_Features" + description: "Basic backward pass without optional features." + batch: [1, 2] + seqlen_q: [512] + seqlen_k: [512] + nhead_q: [16] + nhead_k: [16] + hdim_q: [128] + hdim_v: [128] + dtype: ["fp16", "bf16"] + layout: ["BHSD"] + mask: ["no_mask"] + bias: ["none"] + dropout: [0.0] + has_dbias: [false] + is_deterministic: [false] + + - name: "Bwd_GQA_Smoke" + description: "Backward GQA smoke test (4:1 and 8:1 ratios)." + batch: [2] + seqlen_q: [1024] + seqlen_k: [1024] + nhead_q: [32] + nhead_k: [8] + hdim_q: [128] + hdim_v: [128] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["no_mask", "top_left"] + bias: ["none"] + dropout: [0.0] + has_dbias: [false] + is_deterministic: [false] + + - name: "Bwd_Hdim_Sweep_Smoke" + description: "Backward across key head dimensions." + batch: [2] + seqlen_q: [512] + seqlen_k: [512] + nhead_q: [8] + nhead_k: [8] + hdim_q: [64, 96, 128, 256] + hdim_v: [64, 96, 128, 256] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["no_mask"] + bias: ["none"] + dropout: [0.0] + has_dbias: [false] + is_deterministic: [false] + + - name: "Bwd_With_Mask_Dropout" + description: "Backward with causal mask and dropout." + batch: [2] + seqlen_q: [512] + seqlen_k: [512] + nhead_q: [16] + nhead_k: [16] + hdim_q: [64, 128] + hdim_v: [64, 128] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["top_left"] + bias: ["none"] + dropout: [0.1] + has_dbias: [false] + is_deterministic: [false] + + - name: "Bwd_Asymmetric_Hdim_Smoke" + description: "Backward with asymmetric head dimensions." + batch: [2] + seqlen_q: [512] + seqlen_k: [512] + nhead_q: [16] + nhead_k: [16] + hdim_q: [192] + hdim_v: [128] + dtype: ["fp16", "bf16"] + layout: ["BHSD"] + mask: ["no_mask"] + bias: ["none"] + dropout: [0.0] + has_dbias: [false] + is_deterministic: [false] + + # --------------------------------------------------------------------------- + # Full Tests + # --------------------------------------------------------------------------- + full: + - name: "Bwd_GQA_Support" + description: "Backward pass with Grouped Query Attention." + batch: [2] + seqlen_q: [1024] + seqlen_k: [1024] + nhead_q: [32, 64] + nhead_k: [8] + hdim_q: [128] + hdim_v: [128] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["top_left"] + bias: ["none"] + dropout: [0.0] + has_dbias: [false] + is_deterministic: [false] + + - name: "Bwd_High_Capacity_H256" + description: "Backward pass with hdim=256; high LDS pressure." + batch: [1] + seqlen_q: [1024] + seqlen_k: [1024] + nhead_q: [8] + nhead_k: [4] + hdim_q: [256] + hdim_v: [256] + dtype: ["bf16"] + layout: ["BHSD"] + mask: ["no_mask", "top_left"] + bias: ["none"] + dropout: [0.0] + has_dbias: [false] + is_deterministic: [false] + + - name: "Bwd_Irregular_H96" + description: "Backward pass with non-power-of-2 hdim." + batch: [2] + seqlen_q: [1024] + seqlen_k: [1024] + nhead_q: [32] + nhead_k: [32] + hdim_q: [96] + hdim_v: [96] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["top_left"] + bias: ["none"] + dropout: [0.0] + has_dbias: [false] + is_deterministic: [false] + + - name: "Bwd_All_Features_Enabled" + description: "Backward pass with bias gradients, dropout, and deterministic accumulation." + batch: [2] + seqlen_q: [512] + seqlen_k: [512] + nhead_q: [16] + nhead_k: [16] + hdim_q: [48, 64, 72, 96, 128, 160, 256] + hdim_v: [48, 64, 72, 96, 128, 160, 256] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["top_left"] + bias: ["elementwise", "alibi"] + dropout: [0.1] + has_dbias: [true] + is_deterministic: [true] + + - name: "Bwd_Padding_Boundary_Stress" + description: "Test backward pass with sequences that are not perfect multiples of the tile size." + batch: [1] + seqlen_q: [259, 500, 1023] + seqlen_k: [259, 500, 1023] + nhead_q: [8] + nhead_k: [8] + hdim_q: [128] + hdim_v: [128] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["no_mask", "top_left"] + bias: ["none"] + dropout: [0.0] + has_dbias: [false] + is_deterministic: [false] + + - name: "Bwd_Asymmetric_Head_Dims_192_128" + description: "Test backward pass with asymmetric head dimensions (192x128)." + batch: [2] + seqlen_q: [1024] + seqlen_k: [1024] + nhead_q: [16] + nhead_k: [16] + hdim_q: [192] + hdim_v: [128] + dtype: ["fp16", "bf16"] + layout: ["BHSD"] + mask: ["top_left"] + bias: ["none"] + dropout: [0.0] + has_dbias: [false] + is_deterministic: [false] + + - name: "Bwd_Asymmetric_Head_Dims_128_192" + description: "Test backward pass with asymmetric head dimensions (128x192)." + batch: [2] + seqlen_q: [1024] + seqlen_k: [1024] + nhead_q: [16] + nhead_k: [16] + hdim_q: [128] + hdim_v: [192] + dtype: ["fp16", "bf16"] + layout: ["BHSD"] + mask: ["no_mask"] + bias: ["none"] + dropout: [0.0] + has_dbias: [false] + is_deterministic: [false] + + - name: "Bwd_Diverse_Head_Dims_Sweep" + description: "Sweep backward pass across various head dimensions." + batch: [2] + seqlen_q: [512] + seqlen_k: [512] + nhead_q: [16] + nhead_k: [16] + hdim_q: [48, 64, 72, 96, 128, 160, 256] + hdim_v: [48, 64, 72, 96, 128, 160, 256] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["no_mask"] + bias: ["none"] + dropout: [0.0] + has_dbias: [false] + is_deterministic: [false] + + - name: "Bwd_Cross_Attention_Shapes" + description: "Test shapes typical of cross-attention where seqlen_q != seqlen_k in backward." + batch: [2] + seqlen_q: [1, 32, 128] + seqlen_k: [1024, 4096] + nhead_q: [16] + nhead_k: [16] + hdim_q: [128] + hdim_v: [128] + dtype: ["fp16"] + layout: ["BHSD"] + mask: ["no_mask"] + bias: ["none"] + dropout: [0.0] + has_dbias: [false] + is_deterministic: [false] diff --git a/tile_engine/ops/fmha/configs/appendkv.json b/tile_engine/ops/fmha/configs/appendkv.json new file mode 100644 index 0000000000..21a8a53a4e --- /dev/null +++ b/tile_engine/ops/fmha/configs/appendkv.json @@ -0,0 +1,6 @@ +{ + "variant": "appendkv", + "trait_config": { + "data_type": {"values": ["fp16", "bf16", "fp8"]} + } +} diff --git a/tile_engine/ops/fmha/configs/batch_prefill.json b/tile_engine/ops/fmha/configs/batch_prefill.json new file mode 100644 index 0000000000..c8cf1899e3 --- /dev/null +++ b/tile_engine/ops/fmha/configs/batch_prefill.json @@ -0,0 +1,6 @@ +{ + "variant": "batch_prefill", + "trait_config": { + "data_type": {"values": ["fp16", "bf16", "fp8bf16"]} + } +} diff --git a/tile_engine/ops/fmha/configs/bwd.json b/tile_engine/ops/fmha/configs/bwd.json new file mode 100644 index 0000000000..af4b1a8beb --- /dev/null +++ b/tile_engine/ops/fmha/configs/bwd.json @@ -0,0 +1,6 @@ +{ + "variant": "bwd", + "trait_config": { + "data_type": {"values": ["fp16", "bf16"]} + } +} diff --git a/tile_engine/ops/fmha/configs/fwd.json b/tile_engine/ops/fmha/configs/fwd.json new file mode 100644 index 0000000000..0201a10571 --- /dev/null +++ b/tile_engine/ops/fmha/configs/fwd.json @@ -0,0 +1,9 @@ +{ + "variant": "fwd", + "trait_config": { + "data_type": {"values": ["fp16", "bf16"]}, + "pipeline": {"values": ["qr", "qr_async"]}, + "mask": {"values": ["no", "top_left"]}, + "bias": {"values": ["no"]} + } +} diff --git a/tile_engine/ops/fmha/configs/fwd_ci.json b/tile_engine/ops/fmha/configs/fwd_ci.json new file mode 100644 index 0000000000..435dca8d23 --- /dev/null +++ b/tile_engine/ops/fmha/configs/fwd_ci.json @@ -0,0 +1,14 @@ +{ + "variant": "fwd", + "trait_config": { + "data_type": {"values": ["fp16"]}, + "pipeline": {"values": ["qr_async"]}, + "mask": {"values": ["no"]}, + "bias": {"values": ["no"]}, + "mode": {"values": ["batch"]}, + "lse": {"values": [false]}, + "dropout": {"values": [false]}, + "logits": {"values": [false]}, + "sink": {"values": [false]} + } +} diff --git a/tile_engine/ops/fmha/configs/pagedkv.json b/tile_engine/ops/fmha/configs/pagedkv.json new file mode 100644 index 0000000000..7db1e45f4d --- /dev/null +++ b/tile_engine/ops/fmha/configs/pagedkv.json @@ -0,0 +1,6 @@ +{ + "variant": "pagedkv", + "trait_config": { + "data_type": {"values": ["fp16", "bf16", "fp8"]} + } +} diff --git a/tile_engine/ops/fmha/configs/receipt0_fwd.json b/tile_engine/ops/fmha/configs/receipt0_fwd.json new file mode 100644 index 0000000000..ff3fc59f48 --- /dev/null +++ b/tile_engine/ops/fmha/configs/receipt0_fwd.json @@ -0,0 +1,6 @@ +{ + "variant": "fwd", + "trait_config": { + "data_type": {"values": ["fp16", "bf16", "fp8bf16", "fp8fp32"]} + } +} diff --git a/tile_engine/ops/fmha/configs/splitkv.json b/tile_engine/ops/fmha/configs/splitkv.json new file mode 100644 index 0000000000..930121c9f6 --- /dev/null +++ b/tile_engine/ops/fmha/configs/splitkv.json @@ -0,0 +1,6 @@ +{ + "variant": "splitkv", + "trait_config": { + "data_type": {"values": ["fp16", "bf16", "fp8"]} + } +} diff --git a/tile_engine/ops/fmha/filters/h128_no_dropout.py b/tile_engine/ops/fmha/filters/h128_no_dropout.py new file mode 100644 index 0000000000..aa9b2d9ef3 --- /dev/null +++ b/tile_engine/ops/fmha/filters/h128_no_dropout.py @@ -0,0 +1,14 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +"""Sample filter: only h128 kernels without dropout. + +Usage: + python fmha_benchmark.py configs/receipt0_fwd.json --filter-file filters/h128_no_dropout.py + python fmha_instance_builder.py configs/receipt0_fwd.json --filter-file filters/h128_no_dropout.py --count-only +""" + + +def filter_config(c) -> bool: + """Keep only h128 kernels without dropout.""" + return c.hdim_q == 128 and not c.dropout diff --git a/tile_engine/ops/fmha/fmha_benchmark.py b/tile_engine/ops/fmha/fmha_benchmark.py new file mode 100644 index 0000000000..052ed232d9 --- /dev/null +++ b/tile_engine/ops/fmha/fmha_benchmark.py @@ -0,0 +1,939 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +FMHA tile engine benchmark runner. + +Uses the dispatcher's setup_multiple_fmha_dispatchers() for pipelined JIT +compilation, then runs GPU benchmarks and reports results. + +Usage: + python fmha_benchmark.py configs/fwd.json + python fmha_benchmark.py configs/fwd.json --workers 256 --build-dir /tmp/fmha_build + python fmha_benchmark.py configs/fwd.json --problems "2,8,1024,128" --verify +""" + +import argparse +import csv +import json +import os +import shutil +import sys +import time +from pathlib import Path +from typing import List + +import numpy as np + +_DISPATCHER_ROOT = Path(__file__).resolve().parents[3] / "dispatcher" +sys.path.insert(0, str(_DISPATCHER_ROOT / "python")) +sys.path.insert(0, str(_DISPATCHER_ROOT / "codegen")) + +from fmha_utils import ( # noqa: E402 + FmhaProblem, + FmhaRunner, + cpu_attention_fwd, + detect_gpu_arch, + setup_multiple_fmha_dispatchers, +) + +from fmha.instance_gen import expand_sweep, apply_filter # noqa: E402 + +# Reusable multi-GPU job dispatcher (op-agnostic) +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "common")) +from parallel_runner import run_parallel_on_gpus # noqa: E402 + + +def _compute_result( + config, + prob, + time_ms, + output, + ref, + is_causal, + ns, + api_family, + dtype_tol, + gpu_id=None, +): + """Compute tflops, max_err, status and build result dict + display line. + + Returns (result_dict, display_line) or None if time_ms is None/0. + """ + tflops = prob.num_ops / (time_ms * 1e-3) / 1e12 if time_ms > 0 else 0 + if is_causal and time_ms > 0: + sq, sk = prob.seqlen_q, prob.seqlen_k + causal_ratio = (min(sq, sk) + 1) / (2.0 * sk) + tflops = prob.num_ops * causal_ratio / (time_ms * 1e-3) / 1e12 + + max_err = 0.0 + status = "OK" + if ref is not None and output is not None: + max_err = float(np.abs(output.astype(np.float32) - ref).max()) + atol, rtol = dtype_tol + tol = atol + rtol * np.abs(ref).max() + status = "PASS" if max_err < tol else "FAIL" + + splits_tag = f" [ns={ns}]" if api_family == "splitkv" else "" + display_name = f"{config.name}{splits_tag}" + gpu_tag = f" [GPU{gpu_id}]" if gpu_id is not None else "" + display_line = ( + f" {display_name:<105} {time_ms:>10.3f}" + f" {tflops:>10.2f} {max_err:>10.2e} {status:>6}{gpu_tag}" + ) + + result_dict = { + "kernel": config.name, + "dtype": config.data_type, + "hdim_q": config.hdim_q, + "hdim_v": config.hdim_v, + "pipeline": config.pipeline, + "mode": config.mode, + "mask": config.mask, + "bias": config.bias, + "tile_m0": config.tile_m0, + "tile_n0": config.tile_n0, + "tile_k0": config.tile_k0, + "tile_n1": config.tile_n1, + "tile_k1": config.tile_k1, + "tile_k0max": config.tile_k0max, + "warp_m0": config.warp_m0, + "warp_n0": config.warp_n0, + "warp_k0": config.warp_k0, + "block_per_cu": config.block_per_cu, + "num_splits": ns if api_family == "splitkv" else None, + "problem": { + "batch": prob.batch, + "nhead_q": prob.nhead_q, + "nhead_k": prob.nhead_k, + "seqlen_q": prob.seqlen_q, + "seqlen_k": prob.seqlen_k, + "hdim_q": prob.hdim_q, + "hdim_v": prob.hdim_v, + }, + "latency_ms": time_ms, + "tflops": tflops, + "max_err": max_err, + "status": status, + } + return result_dict, display_line + + +def _run_kernel_isolated( + lib_path, arch, prob, run_kwargs, data_dir, gpu_id=0, timeout=120 +): + """Run a single kernel in a subprocess. Returns (time_ms, output_path) or (None, error_msg). + + Survives GPU faults — if the subprocess crashes, returns an error instead of killing main. + """ + import json as _json + import subprocess as sp + + # Write a small runner script that the subprocess will execute. + # Use json.dumps for string values to safely escape quotes/backslashes in paths. + _lib = _json.dumps(str(lib_path)) + _arch = _json.dumps(str(arch)) + _pydir = _json.dumps(str(_DISPATCHER_ROOT / "python")) + _ddir = _json.dumps(str(data_dir)) + script = f''' +import sys, os, numpy as np +os.environ["HIP_VISIBLE_DEVICES"] = "{gpu_id}" +sys.path.insert(0, {_pydir}) +from fmha_utils import FmhaRunner, FmhaProblem + +runner = FmhaRunner.from_library({_lib}, {_arch}) +_d = {_ddir} +Q = np.load(os.path.join(_d, "Q.npy")) +K = np.load(os.path.join(_d, "K.npy")) +V = np.load(os.path.join(_d, "V.npy")) +prob = FmhaProblem(batch={prob.batch}, nhead_q={prob.nhead_q}, nhead_k={prob.nhead_k}, + seqlen_q={prob.seqlen_q}, seqlen_k={prob.seqlen_k}, + hdim_q={prob.hdim_q}, hdim_v={prob.hdim_v}) +result = runner.run(Q, K, V, prob, **{run_kwargs!r}) +if result.success: + np.save(os.path.join(_d, "O.npy"), result.output) + print(f"TIME={{result.time_ms}}") +else: + print("FAIL") +runner.cleanup() +''' + script_path = os.path.join(data_dir, "run_kernel.py") + with open(script_path, "w") as f: + f.write(script) + + try: + r = sp.run( + [sys.executable, script_path], + capture_output=True, + text=True, + timeout=timeout, + env={**os.environ, "HIP_VISIBLE_DEVICES": str(gpu_id)}, + ) + if r.returncode != 0: + err = r.stderr[-200:] if r.stderr else f"exit code {r.returncode}" + return None, None, f"CRASH: {err.strip()}" + # Parse time from stdout + for line in r.stdout.strip().split("\n"): + if line.startswith("TIME="): + time_ms = float(line[5:]) + out_path = os.path.join(data_dir, "O.npy") + output = np.load(out_path) if os.path.exists(out_path) else None + return time_ms, output, None + return None, None, "No TIME output" + except sp.TimeoutExpired: + return None, None, "TIMEOUT" + except Exception as e: + return None, None, str(e) + + +def parse_problems(spec: str) -> List[FmhaProblem]: + """Parse problem specs: 'batch,nhead,seqlen,hdim;...'""" + problems = [] + for part in spec.split(";"): + vals = [int(x) for x in part.split(",")] + if len(vals) == 4: + b, h, s, d = vals + problems.append( + FmhaProblem( + batch=b, + nhead_q=h, + nhead_k=h, + seqlen_q=s, + seqlen_k=s, + hdim_q=d, + hdim_v=d, + ) + ) + elif len(vals) == 6: + b, hq, hk, sq, sk, d = vals + problems.append( + FmhaProblem( + batch=b, + nhead_q=hq, + nhead_k=hk, + seqlen_q=sq, + seqlen_k=sk, + hdim_q=d, + hdim_v=d, + ) + ) + return problems + + +def main(): + parser = argparse.ArgumentParser(description="FMHA Tile Engine Benchmark") + parser.add_argument( + "configs", nargs="*", help="Sweep config JSON(s) (optional for exhaustive)" + ) + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument( + "--workers", type=int, default=os.cpu_count() or 8, help="Parallel JIT workers" + ) + parser.add_argument( + "--problems", + default="2,8,1024,128", + help="Problem sizes: batch,nhead,seqlen,hdim", + ) + + parser.add_argument( + "--no-verify", action="store_true", help="Skip CPU reference verification" + ) + parser.add_argument( + "--best", action="store_true", help="Show best kernel per problem" + ) + parser.add_argument( + "--csv", + type=str, + default=None, + help="CSV output path (default: /results.csv). Use --no-csv to disable.", + ) + parser.add_argument("--no-csv", action="store_true", help="Disable CSV output") + parser.add_argument("--json", type=str, default=None) + parser.add_argument( + "--log", + type=str, + default=None, + help="Path to detailed log file (compilation status, failures, timings)", + ) + parser.add_argument( + "--build-dir", + type=str, + default=str(Path(__file__).resolve().parent / "build"), + help="JIT build output directory", + ) + parser.add_argument("--clean", action="store_true") + parser.add_argument("--compile-only", action="store_true") + parser.add_argument( + "--filter", + dest="filter_expr", + default="", + help='Python expr per config, e.g. "c.hdim_q == 128"', + ) + parser.add_argument( + "--filter-file", default="", help="Path to .py with filter_config(c) -> bool" + ) + parser.add_argument( + "--tiles", + choices=["rules", "exhaustive"], + default="rules", + help="Tile enumeration mode: 'rules' (default) uses constraint-based generation; " + "'exhaustive' brute-forces ALL compilable tiles (like the oracle)", + ) + parser.add_argument( + "--num-splits", + default="1,2,4,8", + help="Comma-separated num_splits values to sweep for splitkv (default: 1,2,4,8)", + ) + parser.add_argument( + "--isolate", + action="store_true", + help="Run each kernel in a subprocess to survive GPU faults (slower but fault-tolerant)", + ) + parser.add_argument( + "--gpus", + type=str, + default=None, + help="Comma-separated GPU IDs to use for parallel benchmarking (e.g. '0,1,2,3'). " + "Implies --isolate. Each GPU runs one kernel at a time.", + ) + args = parser.parse_args() + + # --gpus implies --isolate + if args.gpus: + args.isolate = True + gpu_ids = [int(x) for x in args.gpus.split(",")] if args.gpus else [0] + + problems = parse_problems(args.problems) + num_splits_list = [int(x) for x in args.num_splits.split(",")] + build_dir = Path(args.build_dir).resolve() + + if args.clean and build_dir.exists(): + print(f" Cleaning {build_dir} ...") + shutil.rmtree(build_dir) + + build_dir.mkdir(parents=True, exist_ok=True) + + # Phase 0: Expand configs + all_configs = [] + restrict_hdims = sorted({(p.hdim_q, p.hdim_v) for p in problems}) + if args.tiles == "exhaustive": + # Exhaustive mode: all tiles (no constraint filter) × full feature cross-product. + # JSON config is optional — if provided, its trait_config scopes the sweep. + cfg_path = args.configs[0] if args.configs else None + all_configs = expand_sweep( + cfg_path, + args.arch, + 0, + mode="exhaustive", + restrict_hdims=restrict_hdims, + ) + print( + f" Exhaustive: {len(all_configs)} total combos (all tiles × all features)" + ) + else: + if not args.configs: + parser.error( + "Config JSON(s) required for rules mode. Use --tiles exhaustive to run without." + ) + for cfg_path in args.configs: + configs = expand_sweep( + cfg_path, + args.arch, + 0, + mode="rules", + restrict_hdims=restrict_hdims, + ) + all_configs.extend(configs) + print(f" {cfg_path}: {len(configs)} kernel configs") + + if args.filter_expr or args.filter_file: + before = len(all_configs) + all_configs = apply_filter(all_configs, args.filter_expr, args.filter_file) + print(f" Filter: {before} -> {len(all_configs)} configs") + + # Remove standalone combine configs -- they are auto-paired during JIT + all_configs = [c for c in all_configs if c.family != "fwd_splitkv_combine"] + + print(f"\n{'=' * 70}") + print("FMHA Tile Engine Benchmark") + print(f"{'=' * 70}") + print(f" Arch: {args.arch}") + print(f" Kernels: {len(all_configs)}") + print(f" Problems: {len(problems)}") + print(f" Workers: {args.workers}") + print(f" Build: {build_dir}") + + # Phase 1: Pipelined JIT via the dispatcher + print( + f"\n--- Phase 1: JIT compile ({len(all_configs)} kernels," + f" {args.workers} workers) ---" + ) + jit_t0 = time.perf_counter() + + def _progress(stage, done, total): + elapsed = time.perf_counter() - jit_t0 + pct = done * 100 // total + print( + f"\r [{stage}] {done}/{total} ({pct}%) - {elapsed:.0f}s", + end="", + flush=True, + ) + if done == total: + print() + + setups = setup_multiple_fmha_dispatchers( + all_configs, + output_dir=build_dir, + verbose=True, + max_workers=args.workers, + progress_callback=_progress, + ) + + jit_time = time.perf_counter() - jit_t0 + built = sum(1 for s in setups if s.success) + failed = len(all_configs) - built + print(f"\n Built {built}/{len(all_configs)} in {jit_time:.0f}s ({failed} failed)") + + # Load runners for successfully compiled kernels + for setup in setups: + if setup.success and setup.library_path and setup.runner is None: + try: + setup.runner = FmhaRunner.from_library(setup.library_path, args.arch) + except Exception as e: + print(f" Warning: Failed to load runner: {e}") + setup.success = False + + if args.compile_only: + print(f"\n{'=' * 70}") + print(f" Compile-only mode. {built}/{len(all_configs)} kernels compiled.") + if failed > 0: + print("\n Failed kernels:") + for cfg, s in zip(all_configs, setups): + if not s.success: + err = (s.error or "unknown")[:80] + print(f" {cfg.name}: {err}") + if args.tiles == "exhaustive": + # Oracle-style analysis: find tiles missed by rules vs compilable + from fmha.instance_gen import validate_tile, FmhaTileConfig # noqa: E402 + + missed = [] + for cfg, s in zip(all_configs, setups): + if s.success: + tile = FmhaTileConfig( + bm0=cfg.tile_m0, + bn0=cfg.tile_n0, + bk0=cfg.tile_k0, + bn1=cfg.tile_n1, + bk1=cfg.tile_k1, + bk0max=cfg.tile_k0max, + rm0=cfg.wave_m0, + rn0=1, + rk0=1, + rm1=cfg.wave_m1, + rn1=1, + rk1=1, + wm0=cfg.warp_m0, + wn0=cfg.warp_n0, + wk0=cfg.warp_k0, + wm1=cfg.warp_m1, + wn1=cfg.warp_n1, + wk1=cfg.warp_k1, + ) + if not validate_tile( + tile, + args.arch, + cfg.data_type, + cfg.hdim_q, + cfg.hdim_v, + cfg.pipeline, + ): + missed.append(cfg) + if missed: + print( + f"\n MISSED by rules ({len(missed)} tiles compile but rules reject):" + ) + seen = set() + for cfg in missed: + key = (cfg.tile_m0, cfg.tile_n0, cfg.tile_k0) + if key not in seen: + seen.add(key) + print( + f" ({cfg.tile_m0:>3}, {cfg.tile_n0:>3}, {cfg.tile_k0:>3})" + ) + else: + print( + "\n Rules are COMPLETE — all compilable tiles are generated by rules." + ) + print(f"{'=' * 70}") + return + + # Phase 2: Benchmark + print(f"\n--- Phase 2: Benchmark ({built} kernels x {len(problems)} problems) ---") + + dtype_map = { + "fp16": np.float16, + "bf16": np.float32, + "fp32": np.float32, + "fp8": np.float16, + "fp8bf16": np.float16, + "fp8fp32": np.float16, + "bf8": np.float16, + "mxfp8": np.float16, + "mxfp4": np.float16, + } + # Tolerance per dtype: (atol, rtol) + _DTYPE_TOL = { + "fp16": (1e-3, 1e-3), + "bf16": (1e-2, 1e-2), + "fp32": (1e-5, 1e-5), + "fp8": (16.0, 0.0), + "fp8bf16": (16.0, 0.0), + "fp8fp32": (16.0, 0.0), + "bf8": (16.0, 0.0), + "mxfp8": (16.0, 0.0), + "mxfp4": (32.0, 0.0), + } + np.random.seed(42) + all_results = [] + bench_t0 = time.perf_counter() + + for prob_idx, prob in enumerate(problems): + first_dtype = all_configs[0].data_type if all_configs else "fp16" + first_mask = all_configs[0].mask if all_configs else "no" + np_dtype = dtype_map.get(first_dtype, np.float16) + dtype_tol = _DTYPE_TOL.get(first_dtype, (1e-2, 1e-2)) + # Use uniform [0, 1] like CK example (default 'uf' mode) -- produces + # peaked softmax distributions that actually test kernel correctness. + # randn*0.1 makes softmax nearly uniform for large hdim, hiding bugs. + Q = np.random.uniform(0, 1, prob.q_shape()).astype(np_dtype) + K = np.random.uniform(0, 1, prob.k_shape()).astype(np_dtype) + V = np.random.uniform(0, 1, prob.v_shape()).astype(np_dtype) + + _MASK_INT = {"no": 0, "top_left": 1, "bottom_right": 2, "generic": 3} + first_mask_int = _MASK_INT.get(first_mask, 0) + + ref = None + if not args.no_verify: + # For bf16: truncate inputs to bf16 precision before computing reference, + # so reference sees the SAME data the kernel sees (after bf16 encoding). + if first_dtype == "bf16": + from fmha_utils import _float32_to_bf16, _bf16_to_float32 + + Q_ref = _bf16_to_float32(_float32_to_bf16(Q.astype(np.float32))) + K_ref = _bf16_to_float32(_float32_to_bf16(K.astype(np.float32))) + V_ref = _bf16_to_float32(_float32_to_bf16(V.astype(np.float32))) + else: + Q_ref = Q.astype(np.float32) + K_ref = K.astype(np.float32) + V_ref = V.astype(np.float32) + ref = cpu_attention_fwd( + Q_ref, + K_ref, + V_ref, + prob.scale, + mask_type=first_mask_int, + ) + + h_str = ( + f"H={prob.nhead_q}" + if prob.nhead_q == prob.nhead_k + else f"Hq={prob.nhead_q} Hk={prob.nhead_k}" + ) + s_str = ( + f"S={prob.seqlen_q}" + if prob.seqlen_q == prob.seqlen_k + else f"Sq={prob.seqlen_q} Sk={prob.seqlen_k}" + ) + prob_str = f"B={prob.batch} {h_str} {s_str} D={prob.hdim_q}" + print(f"\n Problem [{prob_idx}]: {prob_str}") + print( + f" {'Kernel':<105} {'Time(ms)':>10} {'TFLOPS':>10}" + f" {'MaxErr':>10} {'Status':>6}" + ) + print(f" {'-' * 145}") + + _BIAS_INT = {"no": 0, "bias": 1, "alibi": 2} + + # Build list of (config, setup, run_kwargs, ns) jobs for benchmarking + bench_jobs = [] + for config, setup in zip(all_configs, setups): + if not setup.success: + continue + if not args.isolate and setup.runner is None: + continue + if config.hdim_q != prob.hdim_q or config.hdim_v != prob.hdim_v: + continue + + mask_int = _MASK_INT.get(config.mask, 0) + is_causal = config.mask in ("top_left", "bottom_right") + is_group = config.mode == "group" + + _FAMILY_TO_API = { + "fwd_splitkv": "splitkv", + "fwd_pagedkv": "pagedkv", + "fwd_appendkv": "appendkv", + } + api_family = _FAMILY_TO_API.get(config.family, config.family) + splits_to_try = num_splits_list if api_family == "splitkv" else [0] + + for ns in splits_to_try: + run_kwargs = dict( + mask_type=mask_int, + bias_type=_BIAS_INT.get(config.bias, 0), + has_lse=int(config.lse), + has_dropout=int(config.dropout), + has_logits=int(config.logits), + has_sink=int(config.sink), + data_type=config.data_type, + is_group_mode=int(is_group), + is_v_rowmajor=int(config.vlayout == "r"), + api_family=api_family, + window_left=-1, + window_right=0 if is_causal else -1, + ) + if api_family == "splitkv": + run_kwargs["num_splits"] = ns + bench_jobs.append( + (config, setup, run_kwargs, ns, api_family, is_causal) + ) + + if args.isolate and len(gpu_ids) > 1: + # ---- Multi-GPU parallel isolated execution ---- + import tempfile + + # Save input data once, shared by all subprocesses + shared_data_dir = tempfile.mkdtemp(prefix="fmha_shared_") + np.save(os.path.join(shared_data_dir, "Q.npy"), Q) + np.save(os.path.join(shared_data_dir, "K.npy"), K) + np.save(os.path.join(shared_data_dir, "V.npy"), V) + + def _run_one(job, gpu_id): + config, setup, run_kwargs, ns, api_family, is_causal = job + # Per-job output dir (unique per subprocess) + job_dir = tempfile.mkdtemp(prefix=f"fmha_gpu{gpu_id}_") + # Symlink shared inputs instead of copying + for fname in ("Q.npy", "K.npy", "V.npy"): + os.symlink( + os.path.join(shared_data_dir, fname), + os.path.join(job_dir, fname), + ) + time_ms, output, err = _run_kernel_isolated( + setup.library_path, args.arch, prob, run_kwargs, job_dir, gpu_id + ) + shutil.rmtree(job_dir, ignore_errors=True) + return (config, time_ms, output, err, ns, api_family, is_causal, gpu_id) + + print(f" Running {len(bench_jobs)} kernels across {len(gpu_ids)} GPUs ...") + for _, result in run_parallel_on_gpus(bench_jobs, gpu_ids, _run_one): + config, time_ms, output, err, ns, api_family, is_causal, gpu_id = result + if err: + splits_tag = f" [ns={ns}]" if api_family == "splitkv" else "" + print( + f" {config.name}{splits_tag:<105} {'---':>10} {'---':>10} {'---':>10} GPU{gpu_id} {err[:15]}" + ) + continue + + r, line = _compute_result( + config, + prob, + time_ms, + output, + ref, + is_causal, + ns, + api_family, + dtype_tol, + gpu_id, + ) + print(line) + all_results.append(r) + + shutil.rmtree(shared_data_dir, ignore_errors=True) + + else: + # ---- Sequential execution (in-process or single-GPU isolated) ---- + for config, setup, run_kwargs, ns, api_family, is_causal in bench_jobs: + time_ms = None + output = None + if args.isolate: + import tempfile + + data_dir = tempfile.mkdtemp(prefix="fmha_run_") + np.save(os.path.join(data_dir, "Q.npy"), Q) + np.save(os.path.join(data_dir, "K.npy"), K) + np.save(os.path.join(data_dir, "V.npy"), V) + time_ms, output, err = _run_kernel_isolated( + setup.library_path, + args.arch, + prob, + run_kwargs, + data_dir, + gpu_ids[0], + ) + shutil.rmtree(data_dir, ignore_errors=True) + if err: + print( + f" {config.name:<105} {'---':>10} {'---':>10} {'---':>10} {err[:20]:>6}" + ) + continue + else: + result = setup.runner.run(Q, K, V, prob, **run_kwargs) + if not result.success: + continue + time_ms = result.time_ms + output = result.output + + r, line = _compute_result( + config, + prob, + time_ms, + output, + ref, + is_causal, + ns, + api_family, + dtype_tol, + ) + print(line) + all_results.append(r) + + bench_time = time.perf_counter() - bench_t0 + + # Cleanup + for setup in setups: + if setup.success and setup.runner: + try: + setup.runner.cleanup() + except Exception: + pass + + # Report + print(f"\n{'=' * 70}") + print(f" JIT: {jit_time:.0f}s ({built} kernels)") + print(f" Benchmark: {bench_time:.1f}s") + print(f" Results: {len(all_results)} measurements") + + if all_results: + from collections import defaultdict + + by_problem = defaultdict(list) + for r in all_results: + key = json.dumps(r["problem"], sort_keys=True) + by_problem[key].append(r) + + print("\n Best kernel per problem:") + for key, results in by_problem.items(): + best = max(results, key=lambda x: x["tflops"]) + prob = json.loads(key) + ns_tag = f" [ns={best['num_splits']}]" if best.get("num_splits") else "" + h_str = ( + f"H={prob['nhead_q']}" + if prob["nhead_q"] == prob["nhead_k"] + else f"Hq={prob['nhead_q']} Hk={prob['nhead_k']}" + ) + s_str = ( + f"S={prob['seqlen_q']}" + if prob["seqlen_q"] == prob["seqlen_k"] + else f"Sq={prob['seqlen_q']} Sk={prob['seqlen_k']}" + ) + print( + f" B={prob['batch']} {h_str}" + f" {s_str} D={prob['hdim_q']}" + f" -> {best['kernel']}{ns_tag}" + f" ({best['tflops']:.2f} TFLOPS, {best['latency_ms']:.3f} ms)" + ) + + # CSV output: default to /results.csv; merge with existing file + # keeping the faster result (higher tflops) for duplicate kernel+problem keys. + _CSV_FIELDS = [ + "kernel", + "dtype", + "pipeline", + "mode", + "mask", + "bias", + "hdim_q", + "hdim_v", + "tile_m0", + "tile_n0", + "tile_k0", + "tile_n1", + "tile_k1", + "tile_k0max", + "warp_m0", + "warp_n0", + "warp_k0", + "block_per_cu", + "num_splits", + "batch", + "nhead_q", + "nhead_k", + "seqlen_q", + "seqlen_k", + "latency_ms", + "tflops", + "max_err", + "status", + ] + csv_path = args.csv if args.csv else str(build_dir / "results.csv") + if not args.no_csv and all_results: + # Build map of new results keyed by (kernel, problem-tuple) + def _csv_key(row): + p = row["problem"] if "problem" in row else row + return ( + row["kernel"], + row.get("num_splits", 0), + p.get("batch"), + p.get("nhead_q"), + p.get("nhead_k"), + p.get("seqlen_q"), + p.get("seqlen_k"), + p.get("hdim_q"), + p.get("hdim_v"), + ) + + # Load existing CSV if present + existing = {} + if os.path.isfile(csv_path): + with open(csv_path, "r", newline="") as f: + reader = csv.DictReader(f) + for row in reader: + # Convert numeric fields back from strings + for k in row: + if k in ("latency_ms", "tflops", "max_err"): + try: + row[k] = float(row[k]) + except (ValueError, TypeError): + pass + elif k in ( + "hdim_q", + "hdim_v", + "tile_m0", + "tile_n0", + "tile_k0", + "tile_n1", + "tile_k1", + "tile_k0max", + "warp_m0", + "warp_n0", + "warp_k0", + "block_per_cu", + "num_splits", + "batch", + "nhead_q", + "nhead_k", + "seqlen_q", + "seqlen_k", + ): + try: + row[k] = int(row[k]) + except (ValueError, TypeError): + pass + key = _csv_key(row) + existing[key] = row + + # Merge new results — keep whichever is faster + for r in all_results: + row = {**r, **r["problem"]} + del row["problem"] + key = _csv_key(r) + prev = existing.get(key) + if prev is None or float(row.get("tflops", 0)) > float( + prev.get("tflops", 0) + ): + existing[key] = row + + # Write merged + sorted CSV + merged = sorted( + existing.values(), key=lambda x: float(x.get("tflops", 0)), reverse=True + ) + with open(csv_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=_CSV_FIELDS, extrasaction="ignore") + writer.writeheader() + for row in merged: + writer.writerow(row) + print(f"\n CSV: {csv_path} ({len(merged)} rows, sorted by tflops)") + + if args.json: + report = { + "metadata": { + "arch": args.arch, + "jit_time_s": jit_time, + "bench_time_s": bench_time, + "num_kernels": len(all_configs), + "num_built": built, + "num_problems": len(problems), + }, + "results": all_results, + } + with open(args.json, "w") as f: + json.dump(report, f, indent=2) + print(f" JSON: {args.json}") + + if args.log: + from datetime import datetime + + with open(args.log, "w") as lf: + lf.write(f"FMHA Benchmark Log - {datetime.now().isoformat()}\n") + lf.write(f"{'=' * 80}\n\n") + lf.write(f"Command: {' '.join(sys.argv)}\n") + lf.write(f"Arch: {args.arch}\n") + lf.write(f"Tiles mode: {args.tiles}\n") + lf.write(f"Workers: {args.workers}\n") + lf.write(f"Build dir: {build_dir}\n") + lf.write(f"Total configs: {len(all_configs)}\n") + lf.write(f"Built: {built}\n") + lf.write(f"Failed: {failed}\n") + lf.write(f"JIT time: {jit_time:.1f}s\n") + lf.write(f"Bench time: {bench_time:.1f}s\n") + lf.write(f"Problems: {[str(p) for p in problems]}\n\n") + + # All configs attempted + lf.write(f"{'=' * 80}\n") + lf.write(f"ALL CONFIGS ({len(all_configs)})\n") + lf.write(f"{'=' * 80}\n\n") + for i, (cfg, setup) in enumerate(zip(all_configs, setups)): + status = "OK" if setup.success else "FAILED" + lf.write(f"[{i:4d}] {status:6s} {cfg.name}\n") + lf.write( + f" tile=({cfg.tile_m0},{cfg.tile_n0},{cfg.tile_k0},{cfg.tile_n1},{cfg.tile_k1},{cfg.tile_k0max})" + f" warp=({cfg.warp_m0},{cfg.warp_n0},{cfg.warp_k0})" + f" bpc={cfg.block_per_cu}\n" + ) + if not setup.success and setup.error: + lf.write(f" error: {setup.error}\n") + lf.write("\n") + + # Failed configs summary + lf.write(f"\n{'=' * 80}\n") + lf.write(f"FAILED CONFIGS ({failed})\n") + lf.write(f"{'=' * 80}\n\n") + for cfg, setup in zip(all_configs, setups): + if not setup.success: + lf.write(f" {cfg.name}\n") + if setup.error: + lf.write(f" {setup.error}\n") + + # Benchmark results + if all_results: + lf.write(f"\n{'=' * 80}\n") + lf.write(f"BENCHMARK RESULTS ({len(all_results)} measurements)\n") + lf.write(f"{'=' * 80}\n\n") + sorted_results = sorted(all_results, key=lambda x: -x["tflops"]) + for r in sorted_results: + p = r["problem"] + lf.write( + f" {r['tflops']:8.2f} TFLOPS {r['latency_ms']:8.3f} ms" + f" B={p['batch']} H={p['nhead_q']} S={p['seqlen_q']} D={p['hdim_q']}" + f" {r['kernel']}\n" + ) + + print(f" Log: {args.log}") + + print(f"{'=' * 70}") + + +if __name__ == "__main__": + main() diff --git a/tile_engine/ops/fmha/fmha_full_benchmark.py b/tile_engine/ops/fmha/fmha_full_benchmark.py new file mode 100644 index 0000000000..b6f6b2401c --- /dev/null +++ b/tile_engine/ops/fmha/fmha_full_benchmark.py @@ -0,0 +1,689 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Full FMHA benchmark sweep. + +JIT-compiles FMHA kernels, then for EACH test shape finds all matching +kernels and benchmarks them, streaming results incrementally to CSV/JSON. + +Results are printed live per-shape with the best kernel highlighted. +TFLOPS and latency come directly from CK's HIP event timing. + +Usage: + # Full sweep + python fmha_full_benchmark.py --workers 256 + + # Quick end-to-end test + python fmha_full_benchmark.py --category smoke --variant fwd --max-kernels 10 --workers 4 + + # Filter to h128 fp16 + python fmha_full_benchmark.py --filter "c.hdim_q == 128 and c.data_type == 'fp16'" +""" + +import argparse +import csv +import itertools +import json +import os +import subprocess +import sys +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional + +import yaml +import numpy as np + +_THIS_DIR = Path(__file__).resolve().parent +_DISPATCHER_ROOT = _THIS_DIR.parents[2] / "dispatcher" +sys.path.insert(0, str(_DISPATCHER_ROOT / "python")) +sys.path.insert(0, str(_DISPATCHER_ROOT / "codegen")) +sys.path.insert(0, str(_THIS_DIR)) + +from fmha_utils import ( # noqa: E402 + detect_gpu_arch, + setup_multiple_fmha_dispatchers, +) +from fmha.instance_gen import expand_sweep, apply_filter # noqa: E402 + +YAML_PATH = _THIS_DIR / "ck_fmha_testing_matrix.yaml" + +VARIANT_CONFIGS = { + "fwd": "configs/receipt0_fwd.json", + "splitkv": "configs/splitkv.json", + "pagedkv": "configs/pagedkv.json", + "appendkv": "configs/appendkv.json", + "batch_prefill": "configs/batch_prefill.json", + "bwd": "configs/bwd.json", +} + +# Variant -> YAML section mapping. KV-cache variants use forward_tests shapes. +VARIANT_YAML_SECTIONS = { + "fwd": ["forward_tests"], + "splitkv": ["forward_tests"], + "pagedkv": ["forward_tests"], + "appendkv": ["forward_tests"], + "batch_prefill": ["forward_tests"], + "bwd": ["backward_tests"], +} + +DTYPE_CK = {"fp16": "fp16", "bf16": "bf16", "fp8bf16": "fp8bf16", "fp8fp32": "fp8fp32"} +DTYPE_NP = { + "fp16": np.float16, + "bf16": np.float16, + "fp32": np.float32, + "fp8bf16": np.float16, + "fp8fp32": np.float16, +} +ELEM_BYTES = {"fp16": 2, "bf16": 2, "fp32": 4, "fp8bf16": 1, "fp8fp32": 1} + +MASK_INT = {"no": 0, "top_left": 1, "generic": 3} +BIAS_INT = {"no": 0, "bias": 1, "alibi": 2} +KV_LAYOUT_INT = {"vectorized": 0, "linear": 1} +KV_LOOKUP_INT = {"vllm": 0, "sglang": 1} + + +@dataclass +class TestShape: + name: str + category: str + variant: str + batch: int + seqlen_q: int + seqlen_k: int + nhead_q: int + nhead_k: int + hdim_q: int + hdim_v: int + dtype: str + mask: str = "no_mask" + bias: str = "none" + dropout: float = 0.0 + lse: bool = False + + +def parse_yaml( + yaml_path: Path, category: str = "smoke", sections: Optional[List[str]] = None +) -> List[TestShape]: + with open(yaml_path) as f: + data = yaml.safe_load(f) + shapes = [] + cats = ["smoke"] + if category in ("full", "nightly"): + cats.append("full") + if category == "nightly": + cats.append("nightly") + + section_variant_map = [("forward_tests", "fwd"), ("backward_tests", "bwd")] + if sections: + section_variant_map = [(s, v) for s, v in section_variant_map if s in sections] + + for section, variant in section_variant_map: + if section not in data: + continue + for cat in cats: + for test in data[section].get(cat, []): + for combo in itertools.product( + test.get("batch", [1]), + test.get("seqlen_q", [1024]), + test.get("seqlen_k", [1024]), + test.get("nhead_q", [16]), + test.get("nhead_k", [16]), + test.get("hdim_q", [128]), + test.get("hdim_v", [128]), + test.get("dtype", ["fp16"]), + test.get("mask", ["no_mask"]), + test.get("bias", ["none"]), + test.get("dropout", [0.0]), + test.get("lse", [False]), + ): + b, sq, sk, hq, hk, dq, dv, dt, m, bi, dr, ls = combo + shapes.append( + TestShape( + test["name"], + cat, + variant, + b, + sq, + sk, + hq, + hk, + dq, + dv, + dt, + mask=m, + bias=bi, + dropout=dr, + lse=ls, + ) + ) + return shapes + + +def bandwidth_gb_s(shape: TestShape, latency_ms: float) -> float: + if latency_ms <= 0: + return 0.0 + eb = ELEM_BYTES.get(shape.dtype, 2) + total = ( + shape.batch + * ( + shape.nhead_q * shape.seqlen_q * shape.hdim_q + + shape.nhead_k * shape.seqlen_k * shape.hdim_q + + shape.nhead_k * shape.seqlen_k * shape.hdim_v + + shape.nhead_q * shape.seqlen_q * shape.hdim_v + ) + * eb + ) + return total / (latency_ms * 1e6) + + +FAMILY_TO_API = { + "fwd": "fwd", + "fwd_splitkv": "splitkv", + "fwd_splitkv_combine": "splitkv", + "fwd_pagedkv": "pagedkv", + "fwd_appendkv": "appendkv", + "batch_prefill": "batch_prefill", + "bwd_dot_do_o": "bwd", + "bwd_dq_dk_dv": "bwd", + "bwd_convert_dq": "bwd", +} + + +def _config_to_serializable(config, so_path: str) -> dict: + """Convert FmhaKernelConfig + so_path to a picklable dict for subprocess.""" + return { + "so_path": so_path, + "api_family": FAMILY_TO_API.get(config.family, "fwd"), + "data_type": config.data_type, + "kernel": config.name, + "family": config.family, + "mode": config.mode, + "pipeline": config.pipeline, + "tile_m0": config.tile_m0, + "tile_n0": config.tile_n0, + "tile_k0": config.tile_k0, + "tile_n1": config.tile_n1, + "tile_k1": config.tile_k1, + "tile_k0max": config.tile_k0max, + "pad_s": config.pad_s, + "pad_sk": config.pad_sk, + "pad_d": config.pad_d, + "pad_dv": config.pad_dv, + "mask": config.mask, + "bias": config.bias, + "lse": config.lse, + "dropout": config.dropout, + "logits": config.logits, + "sink": config.sink, + "skip": config.skip_min_seqlen_q, + "qscale": config.qscale, + "paged_kv": config.paged_kv, + "rope": config.rope, + "deterministic": config.deterministic, + "dbias": config.dbias, + "mask_int": MASK_INT.get(config.mask, 0), + "bias_int": BIAS_INT.get(config.bias, 0), + "has_lse": int(config.lse), + "has_dropout": int(config.dropout not in (False, 0, "no", "False")), + "has_logits": int(config.logits), + "has_sink": int(config.sink), + "has_skip": int(config.skip_min_seqlen_q), + "has_dbias": int(getattr(config, "dbias", False)), + "is_store_randval": int(getattr(config, "store_randval", False)), + "page_size": getattr(config, "page_size", 16), + "kv_layout": KV_LAYOUT_INT.get( + getattr(config, "kv_memory_layout", "vectorized"), 0 + ), + "kv_lookup": KV_LOOKUP_INT.get(getattr(config, "kv_lookup_table", "sglang"), 1), + } + + +def _shape_to_dict(shape: TestShape) -> dict: + return { + "name": shape.name, + "category": shape.category, + "variant": shape.variant, + "batch": shape.batch, + "seqlen_q": shape.seqlen_q, + "seqlen_k": shape.seqlen_k, + "nhead_q": shape.nhead_q, + "nhead_k": shape.nhead_k, + "hdim_q": shape.hdim_q, + "hdim_v": shape.hdim_v, + "dtype": shape.dtype, + "mask": shape.mask, + "bias": shape.bias, + "dropout": shape.dropout, + "lse": shape.lse, + } + + +def main(): + p = argparse.ArgumentParser(description="Full FMHA Benchmark Sweep") + p.add_argument("--arch", default=detect_gpu_arch()) + p.add_argument("--category", default="smoke", choices=["smoke", "full", "nightly"]) + p.add_argument("--variant", default="all") + p.add_argument("--workers", type=int, default=8) + p.add_argument("--build-dir", default="/tmp/fmha_full_bench") + p.add_argument("--filter", dest="filter_expr", default="") + p.add_argument("--filter-file", default="") + p.add_argument("--csv", default="fmha_sweep_results.csv") + p.add_argument("--json", default="fmha_sweep_results.json") + p.add_argument("--compile-only", action="store_true") + p.add_argument("--max-kernels", type=int, default=0) + p.add_argument( + "--shape-timeout", + type=int, + default=600, + help="Per-shape timeout in seconds (0=none)", + ) + args = p.parse_args() + + build_dir = Path(args.build_dir) + build_dir.mkdir(parents=True, exist_ok=True) + + variants = list(VARIANT_CONFIGS.keys()) if args.variant == "all" else [args.variant] + + # ---- Phase 1: Parse shapes ---- + print(f"\n{'=' * 80}") + print("Phase 1: Parse test shapes") + print(f"{'=' * 80}") + + all_shapes: List[TestShape] = [] + for variant in variants: + sections = VARIANT_YAML_SECTIONS.get(variant, ["forward_tests"]) + vshapes = parse_yaml(YAML_PATH, args.category, sections=sections) + for s in vshapes: + s.variant = variant + all_shapes.extend(vshapes) + + print(f" Category: {args.category}") + print(f" Variants: {variants}") + print(f" Total shapes: {len(all_shapes)}") + + # ---- Phase 2: Compile ---- + print(f"\n{'=' * 80}") + print("Phase 2: Compile kernels") + print(f"{'=' * 80}") + + # kernel_index: (hdim_q, hdim_v, dtype, variant) -> list of (so_path, cfg_dict) + kernel_index: Dict[tuple, List[tuple]] = {} + + from concurrent.futures import ProcessPoolExecutor as _PPE + + _compile_pool = _PPE(max_workers=args.workers) + BATCH_SIZE = 200 + + for variant in variants: + cfg_path = str(_THIS_DIR / VARIANT_CONFIGS[variant]) + if not Path(cfg_path).exists(): + continue + configs = expand_sweep(cfg_path, args.arch) + if args.filter_expr or args.filter_file: + configs = apply_filter(configs, args.filter_expr, args.filter_file) + if args.max_kernels > 0: + configs = configs[: args.max_kernels] + if not configs: + continue + + n_batches = (len(configs) + BATCH_SIZE - 1) // BATCH_SIZE + print( + f"\n {variant}: {len(configs)} configs, {args.workers} workers, {n_batches} batches..." + ) + t0 = time.perf_counter() + setups = [] + total_ok = 0 + for bi in range(n_batches): + batch_cfgs = configs[bi * BATCH_SIZE : (bi + 1) * BATCH_SIZE] + batch_setups = setup_multiple_fmha_dispatchers( + batch_cfgs, + output_dir=build_dir, + max_workers=args.workers, + executor=_compile_pool, + ) + batch_ok = sum(1 for s in batch_setups if s.success) + batch_n = len(batch_cfgs) + total_ok += batch_ok + setups.extend(zip(batch_cfgs, batch_setups)) + del batch_setups, batch_cfgs + print( + f" Batch {bi + 1}/{n_batches}: {batch_ok}/{batch_n} " + f"(total {total_ok}, {time.perf_counter() - t0:.0f}s)", + flush=True, + ) + ok = total_ok + print(f" Built {ok}/{len(configs)} in {time.perf_counter() - t0:.0f}s") + + for config, setup in setups: + if not setup.success: + continue + so_path = getattr(setup, "library_path", "") or "" + if not so_path: + candidate = build_dir / f"libdispatcher_fmha_{config.name}.so" + if candidate.exists(): + so_path = str(candidate) + if not so_path: + continue + cfg_dict = _config_to_serializable(config, so_path) + key = (config.hdim_q, config.hdim_v, config.data_type, variant, config.mode) + kernel_index.setdefault(key, []).append((so_path, cfg_dict)) + + _compile_pool.shutdown(wait=True) + del _compile_pool + + total_built = sum(len(v) for v in kernel_index.values()) + print(f"\n Total compiled: {total_built}") + print(f" Unique (hdim,dtype,variant) groups: {len(kernel_index)}") + + if args.compile_only: + print(f"\n Compile-only. {total_built} kernels ready.") + return + + # ---- Phase 3: Benchmark (serial, one subprocess per kernel) ---- + print(f"\n{'=' * 80}") + print("Phase 3: Benchmark (one subprocess per kernel, serial GPU)") + print(f"{'=' * 80}") + + csv_path = Path(args.csv) if os.path.isabs(args.csv) else _THIS_DIR / args.csv + csv_fields = [ + "problem_name", + "batch", + "seqlen_q", + "seqlen_k", + "nhead_q", + "nhead_k", + "hdim_q", + "hdim_v", + "dtype", + "kernel", + "family", + "mode", + "pipeline", + "tile_m0", + "tile_n0", + "tile_k0", + "tile_n1", + "tile_k1", + "tile_k0max", + "pad_s", + "pad_sk", + "pad_d", + "pad_dv", + "mask", + "bias", + "lse", + "dropout", + "logits", + "sink", + "skip", + "qscale", + "paged_kv", + "rope", + "deterministic", + "dbias", + "latency_ms", + "tflops", + "bandwidth_gb_s", + ] + + # Resume: load already-completed measurements + completed: set = set() + if csv_path.exists() and csv_path.stat().st_size > 0: + with open(csv_path, newline="") as f: + for row in csv.DictReader(f): + completed.add( + ( + row.get("kernel", ""), + row.get("problem_name", ""), + str(row.get("batch", "")), + str(row.get("seqlen_q", "")), + row.get("dtype", ""), + ) + ) + csv_file = open(csv_path, "a", newline="") + writer = csv.DictWriter(csv_file, fieldnames=csv_fields) + print(f" Resuming: {len(completed)} measurements already in CSV") + else: + csv_file = open(csv_path, "w", newline="") + writer = csv.DictWriter(csv_file, fieldnames=csv_fields) + writer.writeheader() + + # Pre-filter: match shapes to kernels by (hdim, dtype, variant, mode). + # YAML shapes are batch-mode only. Group-mode kernels need seqstart arrays + # which batch shapes don't provide -- they all GPU fault. + runnable = [] + for shape in all_shapes: + ck_dtype = DTYPE_CK.get(shape.dtype, shape.dtype) + key = (shape.hdim_q, shape.hdim_v, ck_dtype, shape.variant, "batch") + entries = kernel_index.get(key, []) + if entries: + runnable.append((shape, entries)) + + # Flatten to work items, skip already-completed + def _resume_key(cfg, shape): + return ( + cfg.get("kernel", ""), + shape.name, + str(shape.batch), + str(shape.seqlen_q), + shape.dtype, + ) + + work_items = [] + skipped = 0 + for shape, kernel_entries in runnable: + for so_path, cfg in kernel_entries: + if _resume_key(cfg, shape) in completed: + skipped += 1 + else: + work_items.append((shape, so_path, cfg)) + + total_work = len(work_items) + skipped + total_measurements = len(completed) + total_gpu_faults = 0 + bench_t0 = time.perf_counter() + BENCH_BATCH = 50 + + worker_path = _THIS_DIR / "run_one_kernel.py" + worker_env = os.environ.copy() + worker_env["FMHA_PYPATH_1"] = str(_DISPATCHER_ROOT / "python") + worker_env["FMHA_PYPATH_2"] = str(_DISPATCHER_ROOT / "codegen") + + CFG_KEYS = [ + "kernel", + "family", + "mode", + "pipeline", + "tile_m0", + "tile_n0", + "tile_k0", + "tile_n1", + "tile_k1", + "tile_k0max", + "pad_s", + "pad_sk", + "pad_d", + "pad_dv", + "mask", + "bias", + "lse", + "dropout", + "logits", + "sink", + "skip", + "qscale", + "paged_kv", + "rope", + "deterministic", + "dbias", + ] + + print(f" Runnable shapes: {len(runnable)}") + print(f" Total kernel x shape pairs: {total_work}") + print(f" Already completed: {skipped}") + print(f" Pending: {len(work_items)}") + print(f" Batch size: {BENCH_BATCH} (retry individually on fault)") + print() + + def _run_subprocess(payload_bytes, timeout=10): + proc = subprocess.Popen( + [sys.executable, str(worker_path)], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + env=worker_env, + ) + timed_out = False + stdout_bytes = b"" + try: + stdout_bytes, _ = proc.communicate(input=payload_bytes, timeout=timeout) + except subprocess.TimeoutExpired: + proc.kill() + proc.communicate() + timed_out = True + finally: + pid = proc.pid + if proc.poll() is None: + proc.kill() + proc.wait() + for pipe in [proc.stdin, proc.stdout, proc.stderr]: + if pipe and not pipe.closed: + pipe.close() + gpucore = _THIS_DIR / f"gpucore.{pid}" + if gpucore.exists(): + gpucore.unlink(missing_ok=True) + rc = -1 if timed_out else proc.returncode + return stdout_bytes, rc + + def _record_result(r, shape, cfg, shape_dict): + nonlocal total_measurements + lat_ms, tflops = r["ms"], r["tflops"] + bw = bandwidth_gb_s(shape, lat_ms) + row = { + "problem_name": shape.name, + "batch": shape.batch, + "seqlen_q": shape.seqlen_q, + "seqlen_k": shape.seqlen_k, + "nhead_q": shape.nhead_q, + "nhead_k": shape.nhead_k, + "hdim_q": shape.hdim_q, + "hdim_v": shape.hdim_v, + "dtype": shape.dtype, + } + for k in CFG_KEYS: + row[k] = cfg.get(k, "") + row["latency_ms"] = round(lat_ms, 4) + row["tflops"] = round(tflops, 2) + row["bandwidth_gb_s"] = round(bw, 2) + writer.writerow(row) + csv_file.flush() + total_measurements += 1 + return tflops, lat_ms + + # Process in batches + n_batches = (len(work_items) + BENCH_BATCH - 1) // BENCH_BATCH + processed = 0 + for bi in range(n_batches): + batch = work_items[bi * BENCH_BATCH : (bi + 1) * BENCH_BATCH] + + items = [] + for shape, so_path, cfg in batch: + cfg["so_path"] = so_path + items.append( + {"so_path": so_path, "shape": _shape_to_dict(shape), "cfg": cfg} + ) + + batch_timeout = len(batch) * 2 + 5 + payload = json.dumps({"items": items}).encode() + stdout_bytes, rc = _run_subprocess(payload, timeout=batch_timeout) + + if rc == 0: + batch_ok = 0 + for line in stdout_bytes.decode().strip().split("\n"): + if not line: + continue + try: + r = json.loads(line) + except (json.JSONDecodeError, ValueError): + continue + idx = r.get("idx", -1) + if not r.get("ok") or idx < 0 or idx >= len(batch): + continue + shape, so_path, cfg = batch[idx] + _record_result(r, shape, cfg, items[idx]["shape"]) + batch_ok += 1 + processed += len(batch) + print( + f" [batch {bi + 1}/{n_batches}] {batch_ok}/{len(batch)} ok " + f"({processed}/{len(work_items)} done, {total_measurements} total)", + flush=True, + ) + else: + # Collect partial results flushed before the fault + partial_done = set() + for line in stdout_bytes.decode().strip().split("\n"): + if not line: + continue + try: + r = json.loads(line) + except (json.JSONDecodeError, ValueError): + continue + idx = r.get("idx", -1) + if r.get("ok") and 0 <= idx < len(batch): + shape, so_path, cfg = batch[idx] + _record_result(r, shape, cfg, items[idx]["shape"]) + partial_done.add(idx) + + # Retry the rest one by one + retry = [(i, b) for i, b in enumerate(batch) if i not in partial_done] + print( + f" [batch {bi + 1}/{n_batches}] FAULT after {len(partial_done)}/{len(batch)} ok, " + f"retrying {len(retry)} individually...", + flush=True, + ) + for idx, (shape, so_path, cfg) in retry: + cfg["so_path"] = so_path + p = json.dumps( + {"so_path": so_path, "shape": items[idx]["shape"], "cfg": cfg} + ).encode() + out, single_rc = _run_subprocess(p, timeout=10) + if single_rc != 0: + total_gpu_faults += 1 + continue + try: + r = json.loads(out.decode().strip().split("\n")[0]) + except (json.JSONDecodeError, ValueError): + continue + if r.get("ok"): + tflops, lat_ms = _record_result(r, shape, cfg, items[idx]["shape"]) + print( + f" {tflops:>7.1f} TFLOPS {lat_ms:.4f}ms " + f"{cfg.get('kernel', '?')[:45]} | {shape.name}", + flush=True, + ) + processed += len(batch) + print(f" ({processed}/{len(work_items)} done)", flush=True) + + csv_file.close() + bench_time = time.perf_counter() - bench_t0 + + # ---- Phase 4: Summary ---- + print(f"\n{'=' * 80}") + print("Results") + print(f"{'=' * 80}") + print(f" Total work items: {total_work}") + print(f" Skipped (resumed): {skipped}") + print(f" Measurements: {total_measurements}") + print(f" GPU faults: {total_gpu_faults}") + print(f" Benchmark time: {bench_time:.1f}s") + print(f" CSV: {csv_path}") + print(f"{'=' * 80}") + + +if __name__ == "__main__": + main() diff --git a/tile_engine/ops/fmha/run_full_sweep.py b/tile_engine/ops/fmha/run_full_sweep.py new file mode 100644 index 0000000000..d443d966e5 --- /dev/null +++ b/tile_engine/ops/fmha/run_full_sweep.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Full FMHA benchmark sweep, organized by variant and dtype. + +Compiles all kernels per variant (shared build dir for caching), +benchmarks against all smoke shapes, then splits results into: + + / + fwd/fp16/results.csv + fwd/bf16/results.csv + splitkv/fp16/results.csv + ... + bwd_dot_do_o/fp16/results.csv + bwd_dq_dk_dv/fp16/results.csv + bwd_convert_dq/fp16/results.csv + +Usage: + python run_full_sweep.py --workers 256 + python run_full_sweep.py --workers 256 --category full --output /tmp/fmha_sweep +""" + +import argparse +import csv +import os +import subprocess +import sys +import time +from collections import defaultdict +from pathlib import Path + +_THIS_DIR = Path(__file__).resolve().parent + +VARIANTS = ["fwd", "splitkv", "pagedkv", "appendkv", "batch_prefill", "bwd"] + +BWD_FAMILIES = ["bwd_dot_do_o", "bwd_dq_dk_dv", "bwd_convert_dq"] + + +def run_variant(variant, category, workers, build_dir, raw_csv, shape_timeout=600): + """Run fmha_full_benchmark.py for one variant, return path to raw CSV.""" + cmd = [ + sys.executable, + str(_THIS_DIR / "fmha_full_benchmark.py"), + "--category", + category, + "--variant", + variant, + "--workers", + str(workers), + "--build-dir", + str(build_dir), + "--csv", + str(raw_csv), + "--json", + str(raw_csv.with_suffix(".json")), + "--shape-timeout", + str(shape_timeout), + ] + print(f"\n{'=' * 80}") + print(f" Variant: {variant}") + print(f" Command: {' '.join(cmd)}") + print(f"{'=' * 80}", flush=True) + + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" + proc = subprocess.run(cmd, env=env) + return proc.returncode + + +def split_csv(raw_csv, output_dir): + """Split a raw CSV into per-family per-dtype subdirectories.""" + if not raw_csv.exists(): + return {} + + counts = defaultdict(int) + writers = {} + files = {} + + with open(raw_csv, newline="") as f: + reader = csv.DictReader(f) + fieldnames = reader.fieldnames + + for row in reader: + family = row.get("family", "unknown") + dtype = row.get("dtype", "unknown") + key = (family, dtype) + + if key not in writers: + d = output_dir / family / dtype + d.mkdir(parents=True, exist_ok=True) + fh = open(d / "results.csv", "w", newline="") + w = csv.DictWriter(fh, fieldnames=fieldnames) + w.writeheader() + writers[key] = w + files[key] = fh + + writers[key].writerow(row) + counts[key] += 1 + + for fh in files.values(): + fh.close() + + return counts + + +def main(): + p = argparse.ArgumentParser( + description="Full FMHA Sweep (organized by variant/dtype)" + ) + p.add_argument("--workers", type=int, default=256) + p.add_argument("--category", default="smoke", choices=["smoke", "full", "nightly"]) + p.add_argument("--output", default="/tmp/fmha_sweep") + p.add_argument("--build-dir", default="/tmp/fmha_sweep_build") + p.add_argument( + "--variants", + nargs="+", + default=VARIANTS, + choices=VARIANTS, + help="Which variants to run", + ) + p.add_argument( + "--shape-timeout", type=int, default=600, help="Per-shape timeout in seconds" + ) + args = p.parse_args() + + output_dir = Path(args.output) + build_dir = Path(args.build_dir) + output_dir.mkdir(parents=True, exist_ok=True) + build_dir.mkdir(parents=True, exist_ok=True) + + t0 = time.perf_counter() + grand_total = defaultdict(int) + + for variant in args.variants: + raw_csv = output_dir / f"_raw_{variant}.csv" + rc = run_variant( + variant, args.category, args.workers, build_dir, raw_csv, args.shape_timeout + ) + if rc != 0: + print(f"\n WARNING: {variant} exited with code {rc}", flush=True) + + counts = split_csv(raw_csv, output_dir) + for key, n in counts.items(): + grand_total[key] += n + family, dtype = key + print(f" {family}/{dtype}: {n} measurements") + + elapsed = time.perf_counter() - t0 + + print(f"\n{'=' * 80}") + print("SWEEP COMPLETE") + print(f"{'=' * 80}") + print(f" Total time: {elapsed / 60:.1f} min") + print(f" Output dir: {output_dir}") + print() + print(f" {'Family':<25} {'Dtype':<10} {'Measurements':>12}") + print(f" {'-' * 25} {'-' * 10} {'-' * 12}") + total = 0 + for (family, dtype), n in sorted(grand_total.items()): + print(f" {family:<25} {dtype:<10} {n:>12,}") + total += n + print(f" {'-' * 25} {'-' * 10} {'-' * 12}") + print(f" {'TOTAL':<25} {'':<10} {total:>12,}") + + print("\n Directory structure:") + for d in sorted(output_dir.rglob("results.csv")): + rel = d.relative_to(output_dir) + print(f" {rel}") + + +if __name__ == "__main__": + main() diff --git a/tile_engine/ops/fmha/run_one_kernel.py b/tile_engine/ops/fmha/run_one_kernel.py new file mode 100644 index 0000000000..5d4e8fa149 --- /dev/null +++ b/tile_engine/ops/fmha/run_one_kernel.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +"""Run FMHA kernel(s) on GPU and report timing. + +Single mode: stdin = {"so_path": ..., "shape": {...}, "cfg": {...}} +Batch mode: stdin = {"items": [{"so_path": ..., "shape": {...}, "cfg": {...}}, ...]} + +Each result prints one JSON line to stdout (flushed immediately): + {"idx": 0, "ok": true, "ms": 0.123, "tflops": 456.7} + {"idx": 1, "ok": false} + +Flushing per-line lets the parent recover partial results if a later +kernel causes a GPU fault that kills this process. +""" + +import json +import os +import sys + +import numpy as np + +for p in [os.environ.get("FMHA_PYPATH_1", ""), os.environ.get("FMHA_PYPATH_2", "")]: + if p and p not in sys.path: + sys.path.insert(0, p) + +from fmha_utils import FmhaProblem, FmhaRunner # noqa: E402 + +DTYPE_NP = { + "fp16": np.float16, + "bf16": np.float16, + "fp32": np.float32, + "fp8bf16": np.float16, + "fp8fp32": np.float16, +} + + +def _run_one(idx, so_path, s, cfg): + prob = FmhaProblem( + batch=s["batch"], + nhead_q=s["nhead_q"], + nhead_k=s["nhead_k"], + seqlen_q=s["seqlen_q"], + seqlen_k=s["seqlen_k"], + hdim_q=s["hdim_q"], + hdim_v=s["hdim_v"], + ) + dt = DTYPE_NP.get(s.get("dtype", "fp16"), np.float16) + np.random.seed(42) + q = (np.random.randn(*prob.q_shape()) * 0.1).astype(dt) + k = (np.random.randn(*prob.k_shape()) * 0.1).astype(dt) + v = (np.random.randn(*prob.v_shape()) * 0.1).astype(dt) + + runner = FmhaRunner.from_library(so_path) + api = cfg.get("api_family", "fwd") + + if api == "bwd": + out_buf = ( + np.random.randn(s["batch"], s["nhead_q"], s["seqlen_q"], s["hdim_v"]) * 0.1 + ).astype(dt) + lse = np.random.randn(s["batch"], s["nhead_q"], s["seqlen_q"]).astype( + np.float32 + ) + d_out = (np.random.randn(*out_buf.shape) * 0.1).astype(dt) + result = runner.run_bwd( + q, + k, + v, + out_buf, + lse, + d_out, + prob, + data_type=cfg.get("data_type", "fp16"), + mask_type=cfg.get("mask_int", 0), + bias_type=cfg.get("bias_int", 0), + has_dropout=cfg.get("has_dropout", 0), + has_dbias=cfg.get("has_dbias", 0), + is_deterministic=cfg.get("deterministic", 0), + is_group_mode=cfg.get("mode", "batch") == "group", + is_store_randval=cfg.get("is_store_randval", 0), + tile_n0=cfg.get("tile_n0", 128), + ) + else: + result = runner.run( + q, + k, + v, + prob, + mask_type=cfg.get("mask_int", 0), + bias_type=cfg.get("bias_int", 0), + has_lse=cfg.get("has_lse", 0), + has_dropout=cfg.get("has_dropout", 0), + has_logits=cfg.get("has_logits", 0), + has_sink=cfg.get("has_sink", 0), + has_skip=cfg.get("has_skip", 0), + api_family=api, + data_type=cfg.get("data_type", "fp16"), + page_size=cfg.get("page_size", 16), + kv_layout=cfg.get("kv_layout", 0), + kv_lookup=cfg.get("kv_lookup", 1), + is_group_mode=cfg.get("mode", "batch") == "group", + ) + + if result.success: + print( + json.dumps( + {"idx": idx, "ok": True, "ms": result.time_ms, "tflops": result.tflops} + ), + flush=True, + ) + else: + print(json.dumps({"idx": idx, "ok": False}), flush=True) + + +def main(): + d = json.loads(sys.stdin.buffer.read()) + + if "items" in d: + for i, item in enumerate(d["items"]): + _run_one(i, item["so_path"], item["shape"], item["cfg"]) + else: + _run_one(0, d["cfg"]["so_path"], d["shape"], d["cfg"]) + + +if __name__ == "__main__": + main()