From 96c39b331e2b2150a43b45ce1bcf2bec8ba85a02 Mon Sep 17 00:00:00 2001
From: John Afaganis <86133604+afagaj@users.noreply.github.com>
Date: Thu, 4 Jun 2026 15:00:17 +0000
Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#7829 (commit 13af7da)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

[ck] Enforce ASCII-only C/C++ sources for hipRTC compatibility (#7829)

## Summary

CK source files must be compilable via **hipRTC (HIP runtime compilation)**, whose preprocessor does not accept non-ASCII bytes anywhere in a translation unit, **including in comments**. Bytes that are harmless under `hipcc` (em-dashes, smart quotes, multiplication signs, Greek letters, box-drawing glyphs, etc.) cause hipRTC to fail at preprocessing time. These regularly leak in via LLM-assisted authoring or copy/paste from formatted documents and silently break hipRTC paths that are not exercised by the default `hipcc`-based build matrix.

This PR (a) cleans every existing violation (53 files) and (b) adds a pre-checkin gate so new violations are rejected before merge.

## File extensions covered

Both the cleanup scan and the new Jenkins enforcement stage use the same predicate:

```
*.h *.hpp *.cpp *.h.in *.hpp.in *.cpp.in *.inc *.cl
```

(excluding `*/build/*` and `*/include/rapidjson/*`). This is a strict superset of the existing `Clang Format` stage's predicate: `*.inc` is added so test-fixture include files are also gated. The local pre-commit hook's `c++/inc` type filter covers the same set.

## Why no enforcement today

CK is opted out of the rocm-libraries root `.pre-commit-config.yaml`, so the existing `pre-commit` workflow doesn't touch CK. The local CK `.pre-commit-config.yaml` only runs for developers who installed hooks. The **authoritative gate is therefore the new Jenkins stage** in this PR; the local hook is a convenience.

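
The byte-level check described above can be sketched in a few lines of POSIX shell. This is only an illustration under stated assumptions, not the `check_ascii_only.sh` added by this PR: the function names `count_non_ascii` and `check_ascii_only` and the error wording are hypothetical.

```shell
#!/bin/sh
# Hypothetical sketch only: NOT the check_ascii_only.sh shipped in this PR.
# Function names and message text are assumptions for illustration.

# Count bytes outside 0x00-0x7F by deleting every ASCII byte and
# counting what is left. LC_ALL=C makes tr operate on raw bytes.
count_non_ascii() {
    LC_ALL=C tr -d '\000-\177' < "$1" | wc -c
}

# Return 1 if any argument contains a non-ASCII byte, 0 otherwise.
check_ascii_only() {
    rc=0
    for f in "$@"; do
        # $(( ... )) strips the padding some wc implementations emit.
        n=$(( $(count_non_ascii "$f") ))
        if [ "$n" -gt 0 ]; then
            echo "ERROR: $f contains $n non-ASCII byte(s)" >&2
            rc=1
        fi
    done
    return "$rc"
}
```

A script that ends with `check_ascii_only "$@"` would exit non-zero on any violation, which is the property a `find ... | xargs` gate relies on.
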
## Commit layout (bisect-friendly)

1. `79798aa6261` - **`[ck] Convert reflect/ rendering to ASCII for hipRTC compatibility`**
   Behavior change, isolated. `TreeFormatter` swaps `├─ / └─ / │ ` for `|- / +- / | ` (3-col width preserved so alignment is unchanged). `conv_description.hpp` swaps `×` for `x` as the dimension separator. `test_conv_description.cpp` expected strings updated in lockstep so the snapshot test stays green. This is the only commit in the series with observable runtime impact.
2. `738fdb0d81c` - **`[ck] Strip non-ASCII bytes from C++ sources for hipRTC compatibility`**
   Mechanical text cleanup across 53 files. Replacements happen in comments or in `std::cout` strings that are not asserted on by any test. None of the 174 `.inc` files in the tree required edits, but they were in the scan's predicate so the enforcement stage's predicate is a superset of what was scanned. Full replacement table in the commit message.
3. `1d7cd8ba235` - **`[ck] Enforce ASCII-only C/C++ sources for hipRTC compatibility`**
   - New `projects/composablekernel/script/check_ascii_only.sh` (modeled on `check_copyright_year.sh`).
   - New entry in `projects/composablekernel/.pre-commit-config.yaml` under the local-hooks block (`types_or: [c++, inc]`).
   - New `ASCII Only Check` parallel stage in `projects/composablekernel/Jenkinsfile`'s `Static checks` block, mirroring the existing `Clang Format` stage but with `*.inc` added to the find predicate. Always-on, no `RUN_CPPCHECK` gate.

The tree is buildable at every commit boundary. Commit 1 leaves 50 known violations; commit 2 leaves 0; commit 3 wires the gate.

## Demo

Script output on a synthesized violation:

```
$ printf '// em-dash test \xe2\x80\x94 here\n' > /tmp/bad.cpp
$ projects/composablekernel/script/check_ascii_only.sh /tmp/bad.cpp
ERROR: /tmp/bad.cpp contains non-ASCII bytes:
1:// em-dash test — here
Fix: replace with ASCII (em-dash -> --, smart quotes -> ", arrows -> ->, etc.)
$ echo $?
1
```

Full repo scan after the cleanup commits (note the `-name '*.inc'` clause):

```
$ cd projects/composablekernel && find . -type f \( -name '*.h' -o -name '*.hpp' -o -name '*.cpp' \
    -o -name '*.h.in' -o -name '*.hpp.in' -o -name '*.cpp.in' -o -name '*.inc' -o -name '*.cl' \) \
    -not -path '*/build/*' -not -path '*/include/rapidjson/*' -print0 \
    | xargs -0 -P 8 -n 64 script/check_ascii_only.sh
$ echo $?
0
```

## Test plan

- [ ] Jenkins PR build: confirm new `Static checks -> ASCII Only Check` stage runs green over the full predicate (incl. `*.inc`) and existing `Clang Format` stage is unaffected.
- [ ] `test_conv_description` passes against the ASCII tree-formatter output (touched in commit 1).
- [ ] Local: `pre-commit run ascii-only-checker --all-files` runs cleanly after installing CK pre-commit hooks via `script/install_precommit.sh`.
- [ ] Manually inject a non-ASCII byte in any `.cpp/.hpp/.inc` file, push: confirm Jenkins fails the new stage with a clear error.
- [ ] Spot-check a representative subset of touched files under hipRTC compilation to confirm no remaining hipRTC-blocking content (optional, since the static byte check is a sufficient condition for hipRTC preprocessor acceptance on this dimension).

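
For contributors fixing a violation locally, the kind of replacement this series applies (em-dash to `--`, multiplication sign to `x`, right arrow to `->`, right single quote to `'`) can be approximated with a small filter. This is a hedged sketch, not tooling from the PR: the name `ascii_fold` and the variable names are hypothetical, and only four of the substitutions visible in the diff are covered.

```shell
#!/bin/sh
# Hypothetical helper, not part of this PR. The UTF-8 byte sequences are
# built with printf octal escapes so this sketch itself stays ASCII-only.
EM_DASH=$(printf '\342\200\224')   # U+2014 em-dash
TIMES=$(printf '\303\227')         # U+00D7 multiplication sign
RARROW=$(printf '\342\206\222')    # U+2192 rightwards arrow
RSQUOTE=$(printf '\342\200\231')   # U+2019 right single quotation mark

# Read text on stdin and write it to stdout with the four sequences
# above replaced by ASCII. LC_ALL=C makes sed match raw bytes.
ascii_fold() {
    LC_ALL=C sed \
        -e "s/$EM_DASH/--/g" \
        -e "s/$TIMES/x/g" \
        -e "s/$RARROW/->/g" \
        -e "s/$RSQUOTE/'/g"
}
```

Usage would look like `ascii_fold < file.hpp > file.hpp.new`; anything the filter does not know about (Greek letters, box-drawing glyphs, emoji) is left untouched and still needs a manual fix before the check passes.
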
🤖 Generated with [Claude Code](https://claude.com/claude-code) --- .pre-commit-config.yaml | 5 + Jenkinsfile | 18 + .../backends/generated_conv_backend.hpp | 4 +- ...uped_query_attention_forward_wmma_fp16.cpp | 2 +- .../64_fpAintB_gemm/fp16int8_gemm_wmma.cpp | 2 +- .../05_reduce/multiple_reduce_multiblock.cpp | 2 +- ...nvolution_forward_large_tensor_invoker.hpp | 8 +- .../contraction_utils.hpp | 4 +- .../builder/reflect/conv_description.hpp | 48 +- .../builder/reflect/tree_formatter.hpp | 16 +- .../builder/test/test_conv_description.cpp | 262 +++++------ include/ck/utility/sequence_helper.hpp | 5 +- .../arch/amd_buffer_addressing_builtins.hpp | 4 +- include/ck_tile/core/arch/inst_prefetch.hpp | 6 +- .../arch/mma/sparse/sparse_mma_pipeline.hpp | 2 +- .../arch/mma/sparse/sparse_transforms.hpp | 18 +- include/ck_tile/core/numeric/pk_f6.hpp | 17 +- include/ck_tile/core/tensor/store_tile.hpp | 6 +- .../ck_tile/host/reference/reference_topk.hpp | 10 +- .../kernel/batched_contraction_kernel.hpp | 12 +- .../ck_tile/ops/fmha/block/block_dropout.hpp | 4 +- .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 20 +- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 2 +- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 2 +- .../fused_moe/kernel/fused_moegemm_kernel.hpp | 4 +- .../fused_moe/kernel/moe_sorting_kernel.hpp | 4 +- .../ops/gemm/kernel/gemm_tile_partitioner.hpp | 90 ++-- .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 2 +- .../gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp | 2 +- .../gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp | 2 +- .../utils/transform_conv_fwd_to_gemm.hpp | 8 +- ...ice_grouped_conv_bwd_data_xdl_instance.hpp | 14 +- ...ta_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 8 +- ...ata_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 8 +- ...ata_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp | 8 +- rocm_ck/include/rocm_ck/arch_properties.hpp | 40 +- rocm_ck/include/rocm_ck/gemm_spec.hpp | 80 ++-- rocm_ck/include/rocm_ck/gpu_target.hpp | 2 +- rocm_ck/include/rocm_ck/ops.hpp | 2 +- 
rocm_ck/include/rocm_ck/resolve.hpp | 44 +- rocm_ck/include/rocm_ck/signature.hpp | 6 +- rocm_ck/include/rocm_ck/spec_json.hpp | 4 +- rocm_ck/include/rocm_ck/validate.hpp | 2 +- .../tests/compile_fail/conflicting_layout.cpp | 4 +- rocm_ck/tests/unit/unit_arch_properties.cpp | 20 +- rocm_ck/tests/unit/unit_gemm_spec.cpp | 20 +- rocm_ck/tests/unit/unit_resolve.cpp | 18 +- .../tests/unit/unit_schema_compatibility.cpp | 2 +- rocm_ck/tests/unit/unit_validate.cpp | 6 +- script/check_ascii_only.sh | 23 + .../mma/pipeline/test_amdgcn_sparse_mma.cpp | 2 +- test/ck_tile/data_type/test_bf16.cpp | 6 +- test/ck_tile/data_type/test_mx_scale.cpp | 2 +- test/ck_tile/data_type/test_pk_fp4.cpp | 2 +- test/ck_tile/data_type/test_pk_fp6.cpp | 4 +- test/ck_tile/fmha/test_fmha_bwd.cpp | 2 +- .../test_ck_tile_grouped_conv_fwd.cpp | 4 +- .../test_cluster_load_async_to_lds.cpp | 16 +- test/data_type/test_bhalf.cpp | 2 +- .../s_prefetch_inst_op_util.hpp | 434 +++++++++--------- .../tile_distribution/tile_distribution_1.cpp | 108 ++--- .../tile_distribution/tile_distribution_2.cpp | 62 +-- .../tile_distribution/tile_distribution_3.cpp | 136 +++--- 63 files changed, 865 insertions(+), 817 deletions(-) create mode 100755 script/check_ascii_only.sh diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index daf3c258d9..bfa77b8445 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,6 +26,11 @@ repos: verbose: false language: script types_or: [c++, python, shell, cmake] + - id: ascii-only-checker + name: Check for non-ASCII characters in C/C++ sources + entry: projects/composablekernel/script/check_ascii_only.sh + language: script + types_or: [c++, inc] - id: remove-exec-bit name: Remove executable bit from non-executable files entry: projects/composablekernel/script/remove_exec_bit.sh diff --git a/Jenkinsfile b/Jenkinsfile index dfa904fcf8..d0308b7d07 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -330,6 +330,24 @@ pipeline { cleanWs() } } + stage('ASCII Only 
Check') { + agent{ label rocmnode("nogpu") } + environment{ + setup_args = "NO_CK_BUILD" + execute_cmd = """cd .. && \ + find . -type f \\( -name '*.h' -o -name '*.hpp' -o -name '*.cpp' -o -name '*.h.in' -o -name '*.hpp.in' -o -name '*.cpp.in' -o -name '*.inc' -o -name '*.cl' \\) \ + -not -path '*/build/*' -not -path '*/include/rapidjson/*' \ + -print0 | xargs -0 -P 8 -n 64 script/check_ascii_only.sh""" + } + steps{ + deleteDir() + script { + loadCk(); + ck.buildHipClangJobAndReboot(setup_args:setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd) + } + cleanWs() + } + } } } stage("Run Downstream Tests") diff --git a/dispatcher/include/ck_tile/dispatcher/backends/generated_conv_backend.hpp b/dispatcher/include/ck_tile/dispatcher/backends/generated_conv_backend.hpp index b8e4964b13..75a71777b4 100644 --- a/dispatcher/include/ck_tile/dispatcher/backends/generated_conv_backend.hpp +++ b/dispatcher/include/ck_tile/dispatcher/backends/generated_conv_backend.hpp @@ -148,7 +148,7 @@ inline GroupedConvKernelInstance::RunFn make_conv_bwd_weight_run_fn() } // ------------------------------------------------------------------------- -// IsSupportedFn factories — check kernel applicability without launching +// IsSupportedFn factories -- check kernel applicability without launching // ------------------------------------------------------------------------- template @@ -181,7 +181,7 @@ inline GroupedConvKernelInstance::IsSupportedFn make_conv_bwd_data_is_supported_ } // ------------------------------------------------------------------------- -// Instance string extraction — get CK Tile GetInstanceString() representation +// Instance string extraction -- get CK Tile GetInstanceString() representation // ------------------------------------------------------------------------- #ifdef CK_EXPERIMENTAL_BUILDER diff --git a/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp 
b/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp index 66b2aa8508..4b714a5f9e 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp @@ -3,7 +3,7 @@ /* Grouped Query Attention, -Ainslie, Joshua, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, and Sumit +Ainslie, Joshua, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebron, and Sumit Sanghai. "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." arXiv, May 22, 2023. https://doi.org/10.48550/arXiv.2305.13245. diff --git a/example/64_fpAintB_gemm/fp16int8_gemm_wmma.cpp b/example/64_fpAintB_gemm/fp16int8_gemm_wmma.cpp index 450d1b643f..dd7b43fff6 100644 --- a/example/64_fpAintB_gemm/fp16int8_gemm_wmma.cpp +++ b/example/64_fpAintB_gemm/fp16int8_gemm_wmma.cpp @@ -6,7 +6,7 @@ #include "ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp" // Implementation follows the paper: -// Kim, Young Jin, Rawn Henry, Raffy Fahim, and Hany Hassan Awadalla. "Who Says Elephants Can’t Run: +// Kim, Young Jin, Rawn Henry, Raffy Fahim, and Hany Hassan Awadalla. "Who Says Elephants Can't Run: // Bringing Large Scale MoE Models into Cloud Scale Production." arXiv, November 17, 2022. // https://doi.org/10.48550/arXiv.2211.10017. Assume weight (Matrix B) is add preprocess to // unsigned. 
diff --git a/example/ck_tile/05_reduce/multiple_reduce_multiblock.cpp b/example/ck_tile/05_reduce/multiple_reduce_multiblock.cpp index 2384dc2aa5..29eedb57db 100644 --- a/example/ck_tile/05_reduce/multiple_reduce_multiblock.cpp +++ b/example/ck_tile/05_reduce/multiple_reduce_multiblock.cpp @@ -245,7 +245,7 @@ bool run(const ck_tile::ArgParser& arg_parser) if(pass_op) { - std::cout << "✅ valid results for this operation" << std::endl; + std::cout << "[OK] valid results for this operation" << std::endl; } pass &= pass_op; }); diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp index fbe8ffddc6..3a4504f45d 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp @@ -122,7 +122,7 @@ struct GroupedConvolutionForwardInvoker } else if(!split_info.should_split) { - std::cout << "[INVOKER] Image is small (" << total_h << "×" << total_w + std::cout << "[INVOKER] Image is small (" << total_h << "x" << total_w << "), split-image not necessary.\n"; std::cout << "[INVOKER] Using regular kernel (Kernel).\n"; } @@ -157,8 +157,8 @@ struct GroupedConvolutionForwardInvoker { std::cout << "Total dimensions: D=" << total_d << " H=" << total_h << " W=" << total_w << "\n"; - std::cout << "Split into pieces: D=" << num_d_pieces << " × H=" << num_h_pieces - << " × W=" << num_w_pieces << " = " << total_pieces + std::cout << "Split into pieces: D=" << num_d_pieces << " x H=" << num_h_pieces + << " x W=" << num_w_pieces << " = " << total_pieces << " total pieces\n"; std::cout << "Base piece size: D=" << (total_d / num_d_pieces) << " H=" << (total_h / num_h_pieces) @@ -167,7 +167,7 @@ struct GroupedConvolutionForwardInvoker else if(NDimSpatial == 2) { std::cout << "Total dimensions: H=" << total_h << " W=" 
<< total_w << "\n"; - std::cout << "Split into pieces: H=" << num_h_pieces << " × W=" << num_w_pieces + std::cout << "Split into pieces: H=" << num_h_pieces << " x W=" << num_w_pieces << " = " << total_pieces << " total pieces\n"; std::cout << "Base piece size: H=" << (total_h / num_h_pieces) << " W=" << (total_w / num_w_pieces) << "\n"; diff --git a/example/ck_tile/41_batched_contraction/contraction_utils.hpp b/example/ck_tile/41_batched_contraction/contraction_utils.hpp index ca327f113c..4fdc57876e 100644 --- a/example/ck_tile/41_batched_contraction/contraction_utils.hpp +++ b/example/ck_tile/41_batched_contraction/contraction_utils.hpp @@ -72,11 +72,11 @@ void print_help(const char* program_name) std::cout << " -e_layout= E tensor layout (default: \"R\")\n\n"; std::cout << "Examples:\n"; - std::cout << " Single batch (12 batches of 256×128):\n"; + std::cout << " Single batch (12 batches of 256x128):\n"; std::cout << " " << program_name << " -g_dims=\"12\" -m_dims=\"256\" -n_dims=\"128\" -k_dims=\"64\"\n\n"; - std::cout << " 2D batch grid (2×3=6 batches):\n"; + std::cout << " 2D batch grid (2x3=6 batches):\n"; std::cout << " " << program_name << " -g_dims=\"2,3\" -m_dims=\"128\" -n_dims=\"128\" -k_dims=\"64\"\n\n"; diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp index 3595a6bd98..63b96ec37b 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp @@ -80,9 +80,9 @@ class ConvDescription : public Description algo.add("Thread block size: ", traits_.thread_block_size); algo.add("Data tile size: ", traits_.tile_dims.m, - "×", + "x", traits_.tile_dims.n, - "×", + "x", traits_.tile_dims.k); if(traits_.gemm_padding) algo.add("Gemm padding: ", *traits_.gemm_padding); @@ -95,10 +95,10 @@ class ConvDescription : public Description algo.add("Pipeline 
version: ", traits_.pipeline_version); algo.add("Pipeline scheduler: ", traits_.pipeline_scheduler); auto& warpGemm = algo.add("Warp Gemm parameters:"); - warpGemm.add("subtile size: ", traits_.warp_gemm.gemm_m, "×", traits_.warp_gemm.gemm_n); + warpGemm.add("subtile size: ", traits_.warp_gemm.gemm_m, "x", traits_.warp_gemm.gemm_n); warpGemm.add("Number of warp gemm iterations: ", traits_.warp_gemm.m_iter, - "×", + "x", traits_.warp_gemm.n_iter); // Memory Access section @@ -107,29 +107,29 @@ class ConvDescription : public Description auto& aTile = memAccess.add("A Tile transfer:"); aTile.add("Tile dimensions: ", traits_.a_tile_transfer.tile_dimensions.k0, - "×", + "x", traits_.a_tile_transfer.tile_dimensions.m_or_n, - "×", + "x", traits_.a_tile_transfer.tile_dimensions.k1); aTile.add("The innermost K subdimension size: ", traits_.a_tile_transfer.transfer_params.k1); aTile.add("Thread cluster lengths (threads per axis): ", traits_.a_tile_transfer.transfer_params.thread_cluster_dims[0], - "×", + "x", traits_.a_tile_transfer.transfer_params.thread_cluster_dims[1], - "×", + "x", traits_.a_tile_transfer.transfer_params.thread_cluster_dims[2]); aTile.add("Spatial thread distribution over the data tile: ", traits_.a_tile_transfer.transfer_params.thread_cluster_order[0], - "×", + "x", traits_.a_tile_transfer.transfer_params.thread_cluster_order[1], - "×", + "x", traits_.a_tile_transfer.transfer_params.thread_cluster_order[2]); aTile.add("The order of accessing data tile axes: ", traits_.a_tile_transfer.transfer_params.src_access_order[0], - "×", + "x", traits_.a_tile_transfer.transfer_params.src_access_order[1], - "×", + "x", traits_.a_tile_transfer.transfer_params.src_access_order[2]); aTile.add("Vectorized memory access axis index (with contiguous memory): ", traits_.a_tile_transfer.transfer_params.src_vector_dim); @@ -143,29 +143,29 @@ class ConvDescription : public Description auto& bTile = memAccess.add("B Tile transfer:"); bTile.add("Tile dimensions: ", 
traits_.b_tile_transfer.tile_dimensions.k0, - "×", + "x", traits_.b_tile_transfer.tile_dimensions.m_or_n, - "×", + "x", traits_.b_tile_transfer.tile_dimensions.k1); bTile.add("The innermost K subdimension size: ", traits_.b_tile_transfer.transfer_params.k1); bTile.add("Thread cluster lengths (threads per axis): ", traits_.b_tile_transfer.transfer_params.thread_cluster_dims[0], - "×", + "x", traits_.b_tile_transfer.transfer_params.thread_cluster_dims[1], - "×", + "x", traits_.b_tile_transfer.transfer_params.thread_cluster_dims[2]); bTile.add("Spatial thread distribution over the data tile: ", traits_.b_tile_transfer.transfer_params.thread_cluster_order[0], - "×", + "x", traits_.b_tile_transfer.transfer_params.thread_cluster_order[1], - "×", + "x", traits_.b_tile_transfer.transfer_params.thread_cluster_order[2]); bTile.add("The order of accessing data tile axes: ", traits_.b_tile_transfer.transfer_params.src_access_order[0], - "×", + "x", traits_.b_tile_transfer.transfer_params.src_access_order[1], - "×", + "x", traits_.b_tile_transfer.transfer_params.src_access_order[2]); bTile.add("Vectorized memory access axis index (with contiguous memory): ", traits_.b_tile_transfer.transfer_params.src_vector_dim); @@ -179,15 +179,15 @@ class ConvDescription : public Description auto& cTile = memAccess.add("C Tile transfer:"); cTile.add("Data shuffle (number of gemm instructions per iteration): ", traits_.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, - "×", + "x", traits_.c_tile_transfer.shuffle_params.n_gemms_per_shuffle); cTile.add("Spatial thread distribution used to store data: ", traits_.c_tile_transfer.thread_cluster_dims[0], - "×", + "x", traits_.c_tile_transfer.thread_cluster_dims[1], - "×", + "x", traits_.c_tile_transfer.thread_cluster_dims[2], - "×", + "x", traits_.c_tile_transfer.thread_cluster_dims[3]); cTile.add("Vector access (GMEM write) instruction size: ", traits_.c_tile_transfer.scalar_per_vector); diff --git 
a/experimental/builder/include/ck_tile/builder/reflect/tree_formatter.hpp b/experimental/builder/include/ck_tile/builder/reflect/tree_formatter.hpp index ee18d407c1..25d4945fd6 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/tree_formatter.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/tree_formatter.hpp @@ -10,7 +10,7 @@ namespace ck_tile::reflect { // Tree-node class for building hierarchical tree structures, then rendering them -// with proper indentation and tree-drawing characters (├─, └─, │, etc.) +// with ASCII tree-drawing characters (|-, +-, |, etc.) // // Unlike a streaming API, the tree is built first and rendered afterwards, // so last-child status is determined automatically. @@ -28,11 +28,11 @@ namespace ck_tile::reflect { // Generated Output: // // Root -// ├─ Branch 1 -// │ ├─ Item 1a -// │ └─ Item 1b -// └─ Branch 2 -// └─ Item 2a +// |- Branch 1 +// | |- Item 1a +// | +- Item 1b +// +- Branch 2 +// +- Item 2a class TreeFormatter { public: @@ -76,8 +76,8 @@ class TreeFormatter // Recursive render helper void renderChild(std::ostringstream& oss, const std::string& prefix, bool is_last) const { - oss << prefix << (is_last ? "└─ " : "├─ ") << content_; - std::string child_prefix = prefix + (is_last ? " " : "│ "); + oss << prefix << (is_last ? "+- " : "|- ") << content_; + std::string child_prefix = prefix + (is_last ? 
" " : "| "); for(size_t i = 0; i < children_.size(); ++i) { oss << '\n'; diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index 8d943c7a6d..e0d933477b 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -263,49 +263,49 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription) EXPECT_THAT(ckr::describe().detailed(), ckt::StringEqWithDiff( // "2D Forward Convolution Kernel\n" - "├─ Signature\n" - "│ ├─ Tensor Type: FP16\n" - "│ ├─ Input Layout: GNHWC\n" - "│ ├─ Weight Layout: GKYXC\n" - "│ ├─ Output Layout: GNHWK\n" - "│ ├─ Input elementwise operation: PASS_THROUGH\n" - "│ ├─ Weights elementwise operation: PASS_THROUGH\n" - "│ └─ Output elementwise operation: PASS_THROUGH\n" - "└─ Algorithm\n" - " ├─ Thread block size: 256\n" - " ├─ Data tile size: 256×256×32\n" - " ├─ Gemm padding: DEFAULT\n" - " ├─ Convolution specialization: DEFAULT\n" - " ├─ Pipeline version: V4\n" - " ├─ Pipeline scheduler: INTRAWAVE\n" - " ├─ Warp Gemm parameters:\n" - " │ ├─ subtile size: 16×16\n" - " │ └─ Number of warp gemm iterations: 8×8\n" - " └─ Memory access:\n" - " ├─ A Tile transfer:\n" - " │ ├─ Tile dimensions: 4×256×8\n" - " │ ├─ The innermost K subdimension size: 8\n" - " │ ├─ Thread cluster lengths (threads per axis): 1×128×2\n" - " │ ├─ Spatial thread distribution over the data tile: 0×1×2\n" - " │ ├─ The order of accessing data tile axes: 0×1×2\n" - " │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" - " │ ├─ Vector access (GMEM read) instruction size: 2\n" - " │ ├─ Vector access (LDS write) instruction size: 2\n" - " │ └─ LDS data layout padding (to prevent bank conflicts): 0\n" - " ├─ B Tile transfer:\n" - " │ ├─ Tile dimensions: 4×256×8\n" - " │ ├─ The innermost K subdimension size: 8\n" - " │ ├─ Thread cluster lengths (threads per axis): 1×128×2\n" - " │ ├─ Spatial thread distribution over the data 
tile: 0×1×2\n" - " │ ├─ The order of accessing data tile axes: 0×1×2\n" - " │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" - " │ ├─ Vector access (GMEM read) instruction size: 2\n" - " │ ├─ Vector access (LDS write) instruction size: 2\n" - " │ └─ LDS data layout padding (to prevent bank conflicts): 0\n" - " └─ C Tile transfer:\n" - " ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n" - " ├─ Spatial thread distribution used to store data: 1×32×1×8\n" - " └─ Vector access (GMEM write) instruction size: 2")); + "|- Signature\n" + "| |- Tensor Type: FP16\n" + "| |- Input Layout: GNHWC\n" + "| |- Weight Layout: GKYXC\n" + "| |- Output Layout: GNHWK\n" + "| |- Input elementwise operation: PASS_THROUGH\n" + "| |- Weights elementwise operation: PASS_THROUGH\n" + "| +- Output elementwise operation: PASS_THROUGH\n" + "+- Algorithm\n" + " |- Thread block size: 256\n" + " |- Data tile size: 256x256x32\n" + " |- Gemm padding: DEFAULT\n" + " |- Convolution specialization: DEFAULT\n" + " |- Pipeline version: V4\n" + " |- Pipeline scheduler: INTRAWAVE\n" + " |- Warp Gemm parameters:\n" + " | |- subtile size: 16x16\n" + " | +- Number of warp gemm iterations: 8x8\n" + " +- Memory access:\n" + " |- A Tile transfer:\n" + " | |- Tile dimensions: 4x256x8\n" + " | |- The innermost K subdimension size: 8\n" + " | |- Thread cluster lengths (threads per axis): 1x128x2\n" + " | |- Spatial thread distribution over the data tile: 0x1x2\n" + " | |- The order of accessing data tile axes: 0x1x2\n" + " | |- Vectorized memory access axis index (with contiguous memory): 2\n" + " | |- Vector access (GMEM read) instruction size: 2\n" + " | |- Vector access (LDS write) instruction size: 2\n" + " | +- LDS data layout padding (to prevent bank conflicts): 0\n" + " |- B Tile transfer:\n" + " | |- Tile dimensions: 4x256x8\n" + " | |- The innermost K subdimension size: 8\n" + " | |- Thread cluster lengths (threads per axis): 1x128x2\n" + " | |- Spatial thread 
distribution over the data tile: 0x1x2\n" + " | |- The order of accessing data tile axes: 0x1x2\n" + " | |- Vectorized memory access axis index (with contiguous memory): 2\n" + " | |- Vector access (GMEM read) instruction size: 2\n" + " | |- Vector access (LDS write) instruction size: 2\n" + " | +- LDS data layout padding (to prevent bank conflicts): 0\n" + " +- C Tile transfer:\n" + " |- Data shuffle (number of gemm instructions per iteration): 1x1\n" + " |- Spatial thread distribution used to store data: 1x32x1x8\n" + " +- Vector access (GMEM write) instruction size: 2")); } // Test printing of optional parameters num_groups_to_merge, @@ -368,51 +368,51 @@ TEST(ConvDescriptionTest, BwdWeightTwoStageWmmaV3DescriptionTest) EXPECT_THAT(ckr::describe().detailed(), ckt::StringEqWithDiff( // "2D Backward Weight Convolution Kernel\n" - "├─ Signature\n" - "│ ├─ Tensor Type: FP16\n" - "│ ├─ Input Layout: GNHWC\n" - "│ ├─ Weight Layout: GKYXC\n" - "│ ├─ Output Layout: GNHWK\n" - "│ ├─ Input elementwise operation: PASS_THROUGH\n" - "│ ├─ Weights elementwise operation: PASS_THROUGH\n" - "│ └─ Output elementwise operation: PASS_THROUGH\n" - "└─ Algorithm\n" - " ├─ Thread block size: 256\n" - " ├─ Data tile size: 128×128×16\n" - " ├─ Convolution specialization: DEFAULT\n" - " ├─ Pipeline version: V1\n" - " ├─ Pipeline scheduler: DEFAULT\n" - " ├─ Warp Gemm parameters:\n" - " │ ├─ subtile size: 32×32\n" - " │ └─ Number of warp gemm iterations: 4×4\n" - " ├─ Memory access:\n" - " │ ├─ A Tile transfer:\n" - " │ │ ├─ Tile dimensions: 2×128×8\n" - " │ │ ├─ The innermost K subdimension size: 8\n" - " │ │ ├─ Thread cluster lengths (threads per axis): 4×64×1\n" - " │ │ ├─ Spatial thread distribution over the data tile: 1×0×2\n" - " │ │ ├─ The order of accessing data tile axes: 1×0×2\n" - " │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" - " │ │ ├─ Vector access (GMEM read) instruction size: 8\n" - " │ │ ├─ Vector access (LDS write) instruction size: 8\n" - " 
│ │ └─ LDS data layout padding (to prevent bank conflicts): 1\n" - " │ ├─ B Tile transfer:\n" - " │ │ ├─ Tile dimensions: 2×128×8\n" - " │ │ ├─ The innermost K subdimension size: 8\n" - " │ │ ├─ Thread cluster lengths (threads per axis): 4×64×1\n" - " │ │ ├─ Spatial thread distribution over the data tile: 1×0×2\n" - " │ │ ├─ The order of accessing data tile axes: 1×0×2\n" - " │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" - " │ │ ├─ Vector access (GMEM read) instruction size: 8\n" - " │ │ ├─ Vector access (LDS write) instruction size: 8\n" - " │ │ └─ LDS data layout padding (to prevent bank conflicts): 1\n" - " │ └─ C Tile transfer:\n" - " │ ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n" - " │ ├─ Spatial thread distribution used to store data: 1×32×1×8\n" - " │ └─ Vector access (GMEM write) instruction size: 8\n" - " ├─ Max Transpose transfer src scalar per vector: 1\n" - " ├─ Max Transpose dst scalar per vector: 1\n" - " └─ Num groups to merge: 4")); + "|- Signature\n" + "| |- Tensor Type: FP16\n" + "| |- Input Layout: GNHWC\n" + "| |- Weight Layout: GKYXC\n" + "| |- Output Layout: GNHWK\n" + "| |- Input elementwise operation: PASS_THROUGH\n" + "| |- Weights elementwise operation: PASS_THROUGH\n" + "| +- Output elementwise operation: PASS_THROUGH\n" + "+- Algorithm\n" + " |- Thread block size: 256\n" + " |- Data tile size: 128x128x16\n" + " |- Convolution specialization: DEFAULT\n" + " |- Pipeline version: V1\n" + " |- Pipeline scheduler: DEFAULT\n" + " |- Warp Gemm parameters:\n" + " | |- subtile size: 32x32\n" + " | +- Number of warp gemm iterations: 4x4\n" + " |- Memory access:\n" + " | |- A Tile transfer:\n" + " | | |- Tile dimensions: 2x128x8\n" + " | | |- The innermost K subdimension size: 8\n" + " | | |- Thread cluster lengths (threads per axis): 4x64x1\n" + " | | |- Spatial thread distribution over the data tile: 1x0x2\n" + " | | |- The order of accessing data tile axes: 1x0x2\n" + " | | |- Vectorized memory 
access axis index (with contiguous memory): 2\n" + " | | |- Vector access (GMEM read) instruction size: 8\n" + " | | |- Vector access (LDS write) instruction size: 8\n" + " | | +- LDS data layout padding (to prevent bank conflicts): 1\n" + " | |- B Tile transfer:\n" + " | | |- Tile dimensions: 2x128x8\n" + " | | |- The innermost K subdimension size: 8\n" + " | | |- Thread cluster lengths (threads per axis): 4x64x1\n" + " | | |- Spatial thread distribution over the data tile: 1x0x2\n" + " | | |- The order of accessing data tile axes: 1x0x2\n" + " | | |- Vectorized memory access axis index (with contiguous memory): 2\n" + " | | |- Vector access (GMEM read) instruction size: 8\n" + " | | |- Vector access (LDS write) instruction size: 8\n" + " | | +- LDS data layout padding (to prevent bank conflicts): 1\n" + " | +- C Tile transfer:\n" + " | |- Data shuffle (number of gemm instructions per iteration): 1x1\n" + " | |- Spatial thread distribution used to store data: 1x32x1x8\n" + " | +- Vector access (GMEM write) instruction size: 8\n" + " |- Max Transpose transfer src scalar per vector: 1\n" + " |- Max Transpose dst scalar per vector: 1\n" + " +- Num groups to merge: 4")); } // Test printing of optional parameters num_groups_to_merge, @@ -471,49 +471,49 @@ TEST(ConvDescriptionTest, BwdWeightWmmaCshuffleV3DescriptionTest) EXPECT_THAT(ckr::describe().detailed(), ckt::StringEqWithDiff( // "3D Backward Weight Convolution Kernel\n" - "├─ Signature\n" - "│ ├─ Tensor Type: FP16\n" - "│ ├─ Input Layout: GNDHWC\n" - "│ ├─ Weight Layout: GKZYXC\n" - "│ ├─ Output Layout: GNDHWK\n" - "│ ├─ Input elementwise operation: PASS_THROUGH\n" - "│ ├─ Weights elementwise operation: PASS_THROUGH\n" - "│ └─ Output elementwise operation: PASS_THROUGH\n" - "└─ Algorithm\n" - " ├─ Thread block size: 256\n" - " ├─ Data tile size: 128×128×16\n" - " ├─ Convolution specialization: DEFAULT\n" - " ├─ Pipeline version: V1\n" - " ├─ Pipeline scheduler: DEFAULT\n" - " ├─ Warp Gemm parameters:\n" - " │ ├─ 
subtile size: 32×32\n" - " │ └─ Number of warp gemm iterations: 4×4\n" - " ├─ Memory access:\n" - " │ ├─ A Tile transfer:\n" - " │ │ ├─ Tile dimensions: 2×128×8\n" - " │ │ ├─ The innermost K subdimension size: 8\n" - " │ │ ├─ Thread cluster lengths (threads per axis): 4×64×1\n" - " │ │ ├─ Spatial thread distribution over the data tile: 1×0×2\n" - " │ │ ├─ The order of accessing data tile axes: 1×0×2\n" - " │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" - " │ │ ├─ Vector access (GMEM read) instruction size: 8\n" - " │ │ ├─ Vector access (LDS write) instruction size: 8\n" - " │ │ └─ LDS data layout padding (to prevent bank conflicts): 1\n" - " │ ├─ B Tile transfer:\n" - " │ │ ├─ Tile dimensions: 2×128×8\n" - " │ │ ├─ The innermost K subdimension size: 8\n" - " │ │ ├─ Thread cluster lengths (threads per axis): 4×64×1\n" - " │ │ ├─ Spatial thread distribution over the data tile: 1×0×2\n" - " │ │ ├─ The order of accessing data tile axes: 1×0×2\n" - " │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" - " │ │ ├─ Vector access (GMEM read) instruction size: 8\n" - " │ │ ├─ Vector access (LDS write) instruction size: 8\n" - " │ │ └─ LDS data layout padding (to prevent bank conflicts): 1\n" - " │ └─ C Tile transfer:\n" - " │ ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n" - " │ ├─ Spatial thread distribution used to store data: 1×32×1×8\n" - " │ └─ Vector access (GMEM write) instruction size: 8\n" - " └─ Num gemm k prefetch stage: 1")); + "|- Signature\n" + "| |- Tensor Type: FP16\n" + "| |- Input Layout: GNDHWC\n" + "| |- Weight Layout: GKZYXC\n" + "| |- Output Layout: GNDHWK\n" + "| |- Input elementwise operation: PASS_THROUGH\n" + "| |- Weights elementwise operation: PASS_THROUGH\n" + "| +- Output elementwise operation: PASS_THROUGH\n" + "+- Algorithm\n" + " |- Thread block size: 256\n" + " |- Data tile size: 128x128x16\n" + " |- Convolution specialization: DEFAULT\n" + " |- Pipeline version: V1\n" + " 
|- Pipeline scheduler: DEFAULT\n" + " |- Warp Gemm parameters:\n" + " | |- subtile size: 32x32\n" + " | +- Number of warp gemm iterations: 4x4\n" + " |- Memory access:\n" + " | |- A Tile transfer:\n" + " | | |- Tile dimensions: 2x128x8\n" + " | | |- The innermost K subdimension size: 8\n" + " | | |- Thread cluster lengths (threads per axis): 4x64x1\n" + " | | |- Spatial thread distribution over the data tile: 1x0x2\n" + " | | |- The order of accessing data tile axes: 1x0x2\n" + " | | |- Vectorized memory access axis index (with contiguous memory): 2\n" + " | | |- Vector access (GMEM read) instruction size: 8\n" + " | | |- Vector access (LDS write) instruction size: 8\n" + " | | +- LDS data layout padding (to prevent bank conflicts): 1\n" + " | |- B Tile transfer:\n" + " | | |- Tile dimensions: 2x128x8\n" + " | | |- The innermost K subdimension size: 8\n" + " | | |- Thread cluster lengths (threads per axis): 4x64x1\n" + " | | |- Spatial thread distribution over the data tile: 1x0x2\n" + " | | |- The order of accessing data tile axes: 1x0x2\n" + " | | |- Vectorized memory access axis index (with contiguous memory): 2\n" + " | | |- Vector access (GMEM read) instruction size: 8\n" + " | | |- Vector access (LDS write) instruction size: 8\n" + " | | +- LDS data layout padding (to prevent bank conflicts): 1\n" + " | +- C Tile transfer:\n" + " | |- Data shuffle (number of gemm instructions per iteration): 1x1\n" + " | |- Spatial thread distribution used to store data: 1x32x1x8\n" + " | +- Vector access (GMEM write) instruction size: 8\n" + " +- Num gemm k prefetch stage: 1")); } TEST(ConvDescriptionTest, DefaultInstanceHasInstanceString) diff --git a/include/ck/utility/sequence_helper.hpp b/include/ck/utility/sequence_helper.hpp index fc1dc795d4..6fb0bb06d7 100644 --- a/include/ck/utility/sequence_helper.hpp +++ b/include/ck/utility/sequence_helper.hpp @@ -57,10 +57,11 @@ __host__ __device__ constexpr auto unpack_and_merge_sequences(TupleOfSequences t // Optimization: 
Constexpr loop with array lookup instead of recursive template pattern // // Why this approach: -// - Recursive template (OLD): template instantiation for each recursion level → O(N) instantiations +// - Recursive template (OLD): template instantiation for each recursion level -> O(N) +// instantiations // Example: Finding value in Sequence<1,2,3,4,5> requires 5 recursive instantiations // -// - Constexpr loop (NEW): Single function instantiation with runtime loop → O(1) instantiation +// - Constexpr loop (NEW): Single function instantiation with runtime loop -> O(1) instantiation // Same search requires only 1 function instantiation, loop executes at compile-time // // Implementation details: diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 448394dd43..f38c819858 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1443,7 +1443,7 @@ CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0) // // !!! M0 PRECONDITION - IMPLICIT INPUT NOT VISIBLE IN OPERAND LIST !!! // -// The LDS destination address is taken from M0 (per AMD CDNA3 ISA §10.3: +// The LDS destination address is taken from M0 (per AMD CDNA3 ISA Sec.10.3: // `LDS_ADDR = LDSbase + LDSoffset(M0[17:2] * 4) + INST.OFFSET + ThreadID*4`). // M0 does NOT appear as an operand of these instructions or of the inline // asm below - the compiler cannot see the dependency. Caller must: @@ -1472,7 +1472,7 @@ CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0) // Verified instruction emission (HIP 6.4 / clang 19, gfx942 + gfx950): // `global_load_lds_dwordx4` is a single instruction (encoding 0xDDF48000 // 0x007F0000), NOT software-expanded into 4x dword. Same encoding on both -// arches. The opcode is undocumented in CDNA3 ISA spec §13.6.2 but +// arches. 
The opcode is undocumented in CDNA3 ISA spec Sec.13.6.2 but // supported by the LLVM AMDGPU backend. // // Available on gfx940+ (CDNA3: MI300, MI355, MI350 series). diff --git a/include/ck_tile/core/arch/inst_prefetch.hpp b/include/ck_tile/core/arch/inst_prefetch.hpp index 47170f63e5..1f043d03a1 100644 --- a/include/ck_tile/core/arch/inst_prefetch.hpp +++ b/include/ck_tile/core/arch/inst_prefetch.hpp @@ -5,7 +5,7 @@ #include "ck_tile/core.hpp" -// ─── ISA label markers for two-pass instruction-prefetch offset patching ─── +// --- ISA label markers for two-pass instruction-prefetch offset patching --- // Used by script/patch_prefetch_offset.py to locate prefetch sites and targets // in compiled GPU assembly and patch the koffset field. @@ -16,8 +16,8 @@ #define CK_TILE_XSTR_(x) CK_TILE_STR_(x) #endif -// INST_PREFETCH_TARGET(label) — default mode (mode=0): target is first instruction after -// comment. INST_PREFETCH_TARGET(label, mode) — mode=1 (BLOCK_ENTRY): script scans backward to +// INST_PREFETCH_TARGET(label) -- default mode (mode=0): target is first instruction after +// comment. INST_PREFETCH_TARGET(label, mode) -- mode=1 (BLOCK_ENTRY): script scans backward to // nearest block // label (.LBB*:) and uses the first instruction after that. // Use when the compiler hoists ALU between the block entry diff --git a/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp b/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp index 7b2f24dea8..4eb83c509b 100644 --- a/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp +++ b/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp @@ -27,7 +27,7 @@ constexpr inline int getPipelineFlags() * input WaveTiles of (A: WaveTileM x WaveTileK) x (B: WaveTileK x WaveTileN) and accumulates * results into output WaveTile (C: WaveTileM x WaveTileN). * Like WaveWiseMmaPipeline, this decomposes WaveTile dimensions into fragments and iterates - * internally over FragsM × FragsN × FragsK. 
The A operand is provided in uncompressed form; + * internally over FragsM x FragsN x FragsK. The A operand is provided in uncompressed form; * 2:4 structured sparsity compression (SparseCompressTransform) is applied. * @tparam ADataType Data type of input WaveTile A * @tparam BDataType Data type of input WaveTile B diff --git a/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp b/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp index f89a062240..0889117902 100644 --- a/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp +++ b/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp @@ -23,7 +23,7 @@ static constexpr index_t idx_words_needed = (CompressedSize * 2 + 31) / 32; * @brief Variable-length container for 2:4 structured sparsity index metadata. * * Each compressed element produces a 2-bit index field encoding the original - * position (0–3) within its group of 4. When composing multiple MMA fragments + * position (0-3) within its group of 4. When composing multiple MMA fragments * in M and K dimensions within a WaveTile, the total number of index bits can * exceed 32. This struct packs the index fields into an array of int32_t words, * sized at compile time. @@ -44,22 +44,22 @@ struct SparseIdxPack * @tparam ADataType The data type of a_vec * @tparam CompressedSize The target compression size * @tparam AVec The vector type of a_vec (deduced) - * @return SparseIdxPack containing **CompressedSize** 2‑bit fields packed + * @return SparseIdxPack containing **CompressedSize** 2-bit fields packed * across one or more int32_t words. Each field encodes the original - * position (0–3) of the corresponding non‑zero element in the input. - * If fewer than CompressedSize non‑zeros are found, remaining fields + * position (0-3) of the corresponding non-zero element in the input. + * If fewer than CompressedSize non-zeros are found, remaining fields * default to 2 (see below). 
*/ template static CK_TILE_DEVICE auto compress_a_impl(AVec& a_vec) { static constexpr index_t NumIdxWords = idx_words_needed; - // idx holds one 2‑bit index per output element (total CompressedSize entries), + // idx holds one 2-bit index per output element (total CompressedSize entries), // packed across NumIdxWords int32_t words. // It is initialized to the pattern 0b10 for every field. This matches - // what the hardware expects when there are fewer than two non‑zero values - // in a 4‑element group – the unused output is treated as coming from slot 2. - // The loop below will clear and set each field as real non‑zeros are seen. + // what the hardware expects when there are fewer than two non-zero values + // in a 4-element group - the unused output is treated as coming from slot 2. + // The loop below will clear and set each field as real non-zeros are seen. SparseIdxPack idx{}; static_for<0, CompressedSize, 1>{}([&](auto k) { constexpr uint32_t bit_pos = static_cast(k) * 2u; @@ -76,7 +76,7 @@ static CK_TILE_DEVICE auto compress_a_impl(AVec& a_vec) if(static_cast(a_vec[i * 4 + j]) != 0.0f) { nonzero_elems[non_zero_pos] = a_vec[i * 4 + j]; - // clear the two‑bit field for this output and insert j + // clear the two-bit field for this output and insert j const uint32_t field_idx = static_cast(i) * 2u + static_cast(non_zero_pos); const uint32_t bit_pos = field_idx * 2u; diff --git a/include/ck_tile/core/numeric/pk_f6.hpp b/include/ck_tile/core/numeric/pk_f6.hpp index 3c808ec2f3..6910b5c1c3 100644 --- a/include/ck_tile/core/numeric/pk_f6.hpp +++ b/include/ck_tile/core/numeric/pk_f6.hpp @@ -391,16 +391,17 @@ struct numeric // Value layout (positive): // exp=000,mant=00 -> 0 (zero) // exp=000,mant=01 -> smallest positive subnormal - // exp=000,mant=11 -> largest positive subnormal (≈ 0.0625) - // exp=001,mant=00 -> smallest positive normal (≈ 0.25) - // exp=111,mant=11 -> largest positive normal (≈ 28.0) + // exp=000,mant=11 -> largest positive subnormal (~= 0.0625) 
+ // exp=001,mant=00 -> smallest positive normal (~= 0.25) + // exp=111,mant=11 -> largest positive normal (~= 28.0) - static constexpr uint8_t binary_min_normal = 0b000100; // smallest positive normal (≈ 0.25) - static constexpr uint8_t binary_max_normal = 0b011111; // largest positive normal (≈ 28.0) - static constexpr uint8_t binary_lowest_normal = 0b111111; // most negative normal (≈ -28.0) + static constexpr uint8_t binary_min_normal = 0b000100; // smallest positive normal (~= 0.25) + static constexpr uint8_t binary_max_normal = 0b011111; // largest positive normal (~= 28.0) + static constexpr uint8_t binary_lowest_normal = 0b111111; // most negative normal (~= -28.0) static constexpr uint8_t binary_min_subnorm = 0b000001; // smallest positive subnormal - static constexpr uint8_t binary_max_subnorm = 0b000011; // largest positive subnormal (≈ 0.0625) - static constexpr uint8_t binary_zero = 0b000000; // zero + static constexpr uint8_t binary_max_subnorm = + 0b000011; // largest positive subnormal (~= 0.0625) + static constexpr uint8_t binary_zero = 0b000000; // zero CK_TILE_HOST_DEVICE static constexpr pk_bf6_t min() { diff --git a/include/ck_tile/core/tensor/store_tile.hpp b/include/ck_tile/core/tensor/store_tile.hpp index 78974acdc6..178e65bb62 100644 --- a/include/ck_tile/core/tensor/store_tile.hpp +++ b/include/ck_tile/core/tensor/store_tile.hpp @@ -72,7 +72,7 @@ store_tile(tile_window_with_static_lengths& t tile_window.store(dstr_tensor); } -// Raw variant — same reconstruction cost as store_tile above. +// Raw variant -- same reconstruction cost as store_tile above. template {}); } -// Raw variant — same fast path as above. +// Raw variant -- same fast path as above. template , sequence<1, 0>>{}; - // Use Bwd WarpGemm to ensure that Fwd's random values ​​are consistent with Bwd. + // Use Bwd WarpGemm to ensure that Fwd's random values are consistent with Bwd. 
constexpr auto randval_block_inner_part_dstr_encoding = typename WarpGemmDispatcher static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { const auto randval = generate_randval(i_m0, i_n0); // Drop values of P based on the generated probabilities, negative sign is used to - // distinguish such values ​​later in bwd pipeline. + // distinguish such values later in bwd pipeline. constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 75b91f3569..497cf8d7a3 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -30,7 +30,7 @@ namespace ck_tile { // Per-CU state for group-mode deterministic persistent scheduling. -// alignas(16): enables aligned 128-bit loads; sizeof == 32 (6×4 + 8 pad). +// alignas(16): enables aligned 128-bit loads; sizeof == 32 (6x4 + 8 pad). struct alignas(16) FmhaBwdGroupPersistentCuState { index_t w_lo; // global position of this CU's first K-chunk (= pb + head*hw + c*sq) @@ -44,7 +44,7 @@ struct alignas(16) FmhaBwdGroupPersistentCuState // Per-batch precomputed values used in the group-mode persistent dispatch loop. // Avoids per-iteration reads from seqstart_q/k_ptr and nsplits_ptr. -// alignas(16): sizeof == 16 (3×4 + 4 pad), fits in a single 128-bit load. +// alignas(16): sizeof == 16 (3x4 + 4 pad), fits in a single 128-bit load. 
struct alignas(16) FmhaBwdBatchState { index_t sq; // seqlen_q for this batch (seqstart_q[b+1] - seqstart_q[b]) @@ -67,9 +67,9 @@ struct FmhaBwdWorkspaceManager // [OPTIONAL, only for deterministic group mode persistent] // FmhaBwdGroupPersistentCuState cu_state[num_cus] - // — per-CU packed dispatch state (ibatch, isplit, head_start, c_start, w_lo) + // -- per-CU packed dispatch state (ibatch, isplit, head_start, c_start, w_lo) // FmhaBwdBatchState batch_state[batch] - // — per-batch precomputed sq / nc / nsplits + // -- per-batch precomputed sq / nc / nsplits // GPU WORKSPACE BELOW (read & written by kernels): @@ -289,7 +289,7 @@ struct FmhaBwdWorkspaceManager // nsplits matches the set of slots actually written by atomic_add. // CUs with c_start >= nc start past the head's K-rows (advance to // next head); their isplit would otherwise pad nsplits with a slot - // that nobody writes — reduction would read garbage from it. + // that nobody writes -- reduction would read garbage from it. if(c_start < nc) batch_states[b].nsplits = max(batch_states[b].nsplits, cu_states[c].isplit + 1); @@ -311,7 +311,7 @@ struct FmhaBwdWorkspaceManager { cu_states[c].w_lo = total_w; cu_states[c].w_hi = total_w; - cu_states[c].ibatch = batch_size; // sentinel → early return on GPU + cu_states[c].ibatch = batch_size; // sentinel -> early return on GPU cu_states[c].isplit = 0; cu_states[c].head_start = 0; cu_states[c].c_start = 0; @@ -384,13 +384,13 @@ struct FmhaBwdWorkspaceManager // (~20x larger than the actual region for large seqlen_k). if constexpr(kUsePersistent && kIsGroupMode) return false; - // Persistent (batch and group): uses atomic_add → buffer must start at zero + // Persistent (batch and group): uses atomic_add -> buffer must start at zero // so that accumulated dq values are correct. - // Non-deterministic: uses atomic_add → buffer must start at zero. + // Non-deterministic: uses atomic_add -> buffer must start at zero. 
if constexpr(kUsePersistent || !kIsDeterministic) return true; // Non-persistent deterministic: uses set, but causal mask may skip some tiles - // leaving dq_acc slots unwritten — zero them out first. + // leaving dq_acc slots unwritten -- zero them out first. return kHasMask; } @@ -1631,7 +1631,7 @@ struct FmhaBwdDQDKDVKernel }(); // kUseKSplit && !kUsePersistent is true only for QrQtrDor+deterministic, - // which writes dq directly (not through dq_acc splits) — use 'set'. + // which writes dq directly (not through dq_acc splits) -- use 'set'. // All other deterministic paths are persistent and use 'atomic_add': // a single CU may process multiple chunks of the same (batch, head, isplit) // sequentially, so contributions must accumulate rather than overwrite. diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 444fdef69b..68f54662d4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -1315,7 +1315,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync #if CK_TILE_FMHA_FWD_FAST_EXP2 // For KV_BLOCKSCALE: precompute (m - shift) once per row // exp2(s - (m - shift)) = exp2(s - m + shift) = exp2(s - m) * 2^shift - // This scales P by 2^shift (≈448 for fp8_e4m3) without explicit multiply + // This scales P by 2^shift (~=448 for fp8_e4m3) without explicit multiply auto validated_m = get_validated_m(m[i_idx]); auto row_max = scale_s * validated_m; if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 15e6e5eb43..09b73c03d3 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -852,7 +852,7 @@ struct BlockFmhaPipelineQRKSVS }; // Conditional rescaling: skip o_acc rescale when correction factor - // exp2(acc_scale_log2) is negligible (< exp2(-8) ≈ 0.004, below BF16 + // exp2(acc_scale_log2) is negligible (< exp2(-8) ~= 0.004, below BF16 // precision). Adapted from FlashAttention-4 (Tri Dao, 2025). // Eliminates 70-90% of rescale operations in practice. // diff --git a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp index 2aca7527be..600c088778 100644 --- a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp @@ -51,14 +51,14 @@ // // * different from vLLM // 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id -// 2)need sorted_weight_ptr +// 2)need sorted_weight_ptr // 3) use num_sorted_tiles_ptr, already divided by M_a // // * below used for indexing // 1) sorted_token_ids_ptr [max_num_tokens_padded] // 2) sorted_weight_ptr // 3) sorted_expert_ids_ptr -// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one) +// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one) // // max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1) // diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 2223b6bb32..85e224c6b0 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -122,14 +122,14 @@ namespace ck_tile { // // * different from vLLM // 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id -// 2)need sorted_weight_ptr +// 2)need sorted_weight_ptr // 3) use num_sorted_tiles_ptr, already divided by M_a // // * below used for indexing // 1) 
sorted_token_ids_ptr [max_num_tokens_padded] // 2) sorted_weight_ptr // 3) sorted_expert_ids_ptr -// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one) +// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one) // // max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1) diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index cdfbafc6c0..43062cd470 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -505,9 +505,9 @@ struct GemmClusterTilePartitioner each output TILE position (tile_m, tile_n). Values 0-5 represent the 6 different clusters in the grid. - ═══════════════════════════════════════════════════════════════════════════ + =========================================================================== Pattern::ContiguousBlock (ClusterTilePattern::ContiguousBlock) - ═══════════════════════════════════════════════════════════════════════════ + =========================================================================== DESCRIPTION: Tiles are assigned in CONTIGUOUS blocks within each cluster. 
Each cluster @@ -516,19 +516,19 @@ struct GemmClusterTilePartitioner TILE ASSIGNMENT (each cell shows which cluster processes that tile): N-> 0 1 2 3 - ┌────────────────────────┐ - M 0 │ │ 0 │ 0 │ 3 │ 3 │ - │ ├────────────────────┤ - │ 1 │ │ 0 │ 0 │ 3 │ 3 │ - │ ├────────────────────┤ - ↓ 2 │ │ 1 │ 1 │ 4 │ 4 │ - │ ├────────────────────┤ - 3 │ │ 1 │ 1 │ 4 │ 4 │ - │ ├────────────────────┤ - 4 │ │ 2 │ 2 │ 5 │ 5 │ - │ ├────────────────────┤ - 5 │ │ 2 │ 2 │ 5 │ 5 │ - └────────────────────────┘ + +------------------------+ + M 0 | | 0 | 0 | 3 | 3 | + | +--------------------+ + | 1 | | 0 | 0 | 3 | 3 | + | +--------------------+ + v 2 | | 1 | 1 | 4 | 4 | + | +--------------------+ + 3 | | 1 | 1 | 4 | 4 | + | +--------------------+ + 4 | | 2 | 2 | 5 | 5 | + | +--------------------+ + 5 | | 2 | 2 | 5 | 5 | + +------------------------+ CLUSTER LAYOUT: - Cluster 0: tiles (0,0), (0,1), (1,0), (1,1) - Top-left block @@ -538,9 +538,9 @@ struct GemmClusterTilePartitioner - Cluster 4: tiles (2,2), (2,3), (3,2), (3,3) - Middle-right block - Cluster 5: tiles (4,2), (4,3), (5,2), (5,3) - Bottom-right block - ═══════════════════════════════════════════════════════════════════════════ + =========================================================================== Pattern::InterleavedBoth (ClusterTilePattern::InterleavedBoth) - ═══════════════════════════════════════════════════════════════════════════ + =========================================================================== DESCRIPTION: Tiles are INTERLEAVED along both M and N dimensions. 
Within each cluster, @@ -549,19 +549,19 @@ struct GemmClusterTilePartitioner TILE ASSIGNMENT (interleaved along both M and N): N-> 0 1 2 3 - ┌────────────────────────┐ - M 0 │ │ 0 │ 3 │ 0 │ 3 │ - │ ├────────────────────┤ - │ 1 │ │ 1 │ 4 │ 1 │ 4 │ - │ ├────────────────────┤ - ↓ 2 │ │ 2 │ 5 │ 2 │ 5 │ - │ ├────────────────────┤ - 3 │ │ 0 │ 3 │ 0 │ 3 │ - │ ├────────────────────┤ - 4 │ │ 1 │ 4 │ 1 │ 4 │ - │ ├────────────────────┤ - 5 │ │ 2 │ 5 │ 2 │ 5 │ - └────────────────────────┘ + +------------------------+ + M 0 | | 0 | 3 | 0 | 3 | + | +--------------------+ + | 1 | | 1 | 4 | 1 | 4 | + | +--------------------+ + v 2 | | 2 | 5 | 2 | 5 | + | +--------------------+ + 3 | | 0 | 3 | 0 | 3 | + | +--------------------+ + 4 | | 1 | 4 | 1 | 4 | + | +--------------------+ + 5 | | 2 | 5 | 2 | 5 | + +------------------------+ CLUSTER LAYOUT: - Cluster 0: tiles (0,0), (0,2), (3,0), (3,2) - Strided along M and N @@ -571,9 +571,9 @@ struct GemmClusterTilePartitioner - Cluster 4: tiles (1,1), (1,3), (4,1), (4,3) - Strided along M and N - Cluster 5: tiles (2,1), (2,3), (5,1), (5,3) - Strided along M and N - ═══════════════════════════════════════════════════════════════════════════ + =========================================================================== Pattern::InterleavedM (ClusterTilePattern::InterleavedM) - ═══════════════════════════════════════════════════════════════════════════ + =========================================================================== DESCRIPTION: Tiles are INTERLEAVED along the M dimension while contiguous along N. 
@@ -582,19 +582,19 @@ struct GemmClusterTilePartitioner TILE ASSIGNMENT (interleaved along M, contiguous along N): N-> 0 1 2 3 - ┌────────────────────────┐ - M 0 │ │ 0 │ 0 │ 3 │ 3 │ - │ ├────────────────────┤ - │ 1 │ │ 1 │ 1 │ 4 │ 4 │ - │ ├────────────────────┤ - ↓ 2 │ │ 2 │ 2 │ 5 │ 5 │ - │ ├────────────────────┤ - 3 │ │ 0 │ 0 │ 3 │ 3 │ - │ ├────────────────────┤ - 4 │ │ 1 │ 1 │ 4 │ 4 │ - │ ├────────────────────┤ - 5 │ │ 2 │ 2 │ 5 │ 5 │ - └────────────────────────┘ + +------------------------+ + M 0 | | 0 | 0 | 3 | 3 | + | +--------------------+ + | 1 | | 1 | 1 | 4 | 4 | + | +--------------------+ + v 2 | | 2 | 2 | 5 | 5 | + | +--------------------+ + 3 | | 0 | 0 | 3 | 3 | + | +--------------------+ + 4 | | 1 | 1 | 4 | 4 | + | +--------------------+ + 5 | | 2 | 2 | 5 | 5 | + +------------------------+ CLUSTER LAYOUT: - Cluster 0: tiles (0,0), (0,1), (3,0), (3,1) - Strided along M, contiguous N diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index d2c5e4598c..c2b5ea3951 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -200,7 +200,7 @@ struct GroupedGemmKernel HIP_CHECK_ERROR( hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0)); // TODO: the below is a temporary fix which is due to kernel metadata - // .workgroup_processor_mode isn’t used correctly in clr for gfx1250. Will removed when clr + // .workgroup_processor_mode isn't used correctly in clr for gfx1250. Will removed when clr // and compiler team fix this. occupancy = occupancy > 0 ? 
occupancy : 1; const int grid_size = get_available_compute_units(s) * occupancy; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp index b1bdb5afcf..f6e7f14bc3 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp @@ -147,7 +147,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe // How often (in MFMA units) we should insert DS (LDS) operations. constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst); // How often (in MFMA units) we should insert VMEM buffer loads. - // buffer_load_rep ≈ "MFMA per VMEM_READ", clamped so that one buffer_load + // buffer_load_rep ~= "MFMA per VMEM_READ", clamped so that one buffer_load // is assumed to cover at most 4 MFMA instructions. constexpr index_t buffer_load_rep = min(mfma_inst / buffer_load_inst, 4); // 1 buffer_load cover 4 mfma diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index d0f52e73fc..516ea19f6e 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -136,7 +136,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV // How often (in MFMA units) we should insert DS (LDS) operations. constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst); // How often (in MFMA units) we should insert VMEM buffer loads. - // buffer_load_rep ≈ "MFMA per VMEM_READ", clamped so that one buffer_load + // buffer_load_rep ~= "MFMA per VMEM_READ", clamped so that one buffer_load // is assumed to cover at most 4 MFMA instructions. 
constexpr index_t buffer_load_rep = min(mfma_inst / buffer_load_inst, 4); // 1 buffer_load cover 4 mfma diff --git a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp index 8e9d808d4c..2d47b7d802 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp @@ -6,9 +6,9 @@ #include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp" namespace ck_tile { -// ═══════════════════════════════════════════════════════════════════════ +// ======================================================================= // Split-Image Information Structure -// ═══════════════════════════════════════════════════════════════════════ +// ======================================================================= // This structure holds all information needed to perform split-image // NOTE: SplitImageInfo struct deleted - was only used by deleted recursive split code // Current split-image implementation is in grouped_convolution_forward_invoker.hpp @@ -1517,9 +1517,9 @@ struct TransformConvFwdToGemm } } - // ═══════════════════════════════════════════════════════════════════════ + // ======================================================================= // Split-Image Calculation (AFTER Split-N) - // ═══════════════════════════════════════════════════════════════════════ + // ======================================================================= // This method calculates split-image information using N_ (after Split-N). // This ensures correct offset calculations when both Split-N and Split-Image // are active simultaneously. 
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp index de5cf4e1cc..d6b6fdef8d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp @@ -167,7 +167,7 @@ using device_grouped_conv_bwd_data_xdl_f16_noshuffle_instances = // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // f16_f16_f32_f16 — noshuffle epilogue (CDEBlockTransferScalarPerVector_NPerBlock = 1) + // f16_f16_f32_f16 -- noshuffle epilogue (CDEBlockTransferScalarPerVector_NPerBlock = 1) DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 1>, @@ -333,9 +333,9 @@ using device_grouped_conv_bwd_data_xdl_bf16_instances = std::tuple< // clang-format on >; -// bf16_bf16_f32_bf16 — noshuffle epilogue (CDEBlockTransferScalarPerVector_NPerBlock = 1) +// bf16_bf16_f32_bf16 -- noshuffle epilogue (CDEBlockTransferScalarPerVector_NPerBlock = 1) // Same tile shapes as bf16_instances but with ScalarPerVector=1, enabling the no-shuffle fast path -// (VGPR → Global direct write, 0 LDS barriers) instead of CShuffle (VGPR → LDS → Global, 8 +// (VGPR -> Global direct write, 0 LDS barriers) instead of CShuffle (VGPR -> LDS -> Global, 8 // barriers). template , S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 1>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 1>, @@ -552,7 +552,7 @@ using device_grouped_conv_bwd_data_xdl_f32_noshuffle_instances = // clang-format on >; -// bf16 — BBlockTransfer parameters matching the non-grouped DeviceConvNdBwdDataNwcKxcNwk_Xdl +// bf16 -- BBlockTransfer parameters matching the non-grouped DeviceConvNdBwdDataNwcKxcNwk_Xdl // instances. The key difference from bf16_instances: BBlockTransfer uses S<4, BlockSize/4, 1> // thread cluster and S<2, 0, 1> arrange order, which gives full thread utilization for B-matrix // loads. These are optimal when opt3 flat descriptor path is active (G=1, 2D convolutions). 
@@ -584,7 +584,7 @@ using device_grouped_conv_bwd_data_xdl_bf16_nongrouped_match_instances = std::tu // clang-format on >; -// f16 — BBlockTransfer parameters matching the non-grouped DeviceConvNdBwdDataNwcKxcNwk_Xdl +// f16 -- BBlockTransfer parameters matching the non-grouped DeviceConvNdBwdDataNwcKxcNwk_Xdl // instances. template ; -// f32 — BBlockTransfer parameters matching the non-grouped DeviceConvNdBwdDataNwcKxcNwk_Xdl +// f32 -- BBlockTransfer parameters matching the non-grouped DeviceConvNdBwdDataNwcKxcNwk_Xdl // instances. F32 uses K1=4, KPerBlock=16, and smaller scalar-per-vector values. template {}); - // 3. Default — noshuffle epilogue + // 3. Default -- noshuffle epilogue add_device_operation_instances( instances, device_grouped_conv_bwd_data_xdl_bf16_noshuffle_instances<2, @@ -50,7 +50,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( Empty_Tuple, NHWGC, ConvBwdDataDefault>{}); - // 4. Filter1x1Stride1Pad0 — noshuffle epilogue + // 4. Filter1x1Stride1Pad0 -- noshuffle epilogue add_device_operation_instances(instances, device_grouped_conv_bwd_data_xdl_bf16_noshuffle_instances< 2, @@ -59,7 +59,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( Empty_Tuple, NHWGC, ConvBwdDataFilter1x1Stride1Pad0>{}); - // 5. Default — nongrouped_match instances + // 5. Default -- nongrouped_match instances add_device_operation_instances( instances, device_grouped_conv_bwd_data_xdl_bf16_nongrouped_match_instances<2, @@ -68,7 +68,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( Empty_Tuple, NHWGC, ConvBwdDataDefault>{}); - // 6. Filter1x1Stride1Pad0 — nongrouped_match instances + // 6. 
Filter1x1Stride1Pad0 -- nongrouped_match instances add_device_operation_instances(instances, device_grouped_conv_bwd_data_xdl_bf16_nongrouped_match_instances< 2, diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp index 3bbd4a37e5..e7ec4c54ab 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -41,7 +41,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances( Empty_Tuple, NHWGC, ConvBwdDataFilter1x1Stride1Pad0>{}); - // 3. Default — noshuffle epilogue + // 3. Default -- noshuffle epilogue add_device_operation_instances( instances, device_grouped_conv_bwd_data_xdl_f16_noshuffle_instances<2, @@ -50,7 +50,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances( Empty_Tuple, NHWGC, ConvBwdDataDefault>{}); - // 4. Filter1x1Stride1Pad0 — noshuffle epilogue + // 4. Filter1x1Stride1Pad0 -- noshuffle epilogue add_device_operation_instances(instances, device_grouped_conv_bwd_data_xdl_f16_noshuffle_instances< 2, @@ -59,7 +59,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances( Empty_Tuple, NHWGC, ConvBwdDataFilter1x1Stride1Pad0>{}); - // 5. Default — nongrouped_match instances + // 5. Default -- nongrouped_match instances add_device_operation_instances( instances, device_grouped_conv_bwd_data_xdl_f16_nongrouped_match_instances<2, @@ -68,7 +68,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances( Empty_Tuple, NHWGC, ConvBwdDataDefault>{}); - // 6. 
Filter1x1Stride1Pad0 — nongrouped_match instances + // 6. Filter1x1Stride1Pad0 -- nongrouped_match instances add_device_operation_instances(instances, device_grouped_conv_bwd_data_xdl_f16_nongrouped_match_instances< 2, diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp index 344c35c5ca..9df22a7bd8 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp @@ -41,7 +41,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( Empty_Tuple, NHWGC, ConvBwdDataFilter1x1Stride1Pad0>{}); - // 3. Default — noshuffle epilogue + // 3. Default -- noshuffle epilogue add_device_operation_instances( instances, device_grouped_conv_bwd_data_xdl_f32_noshuffle_instances<2, @@ -50,7 +50,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( Empty_Tuple, NHWGC, ConvBwdDataDefault>{}); - // 4. Filter1x1Stride1Pad0 — noshuffle epilogue + // 4. Filter1x1Stride1Pad0 -- noshuffle epilogue add_device_operation_instances(instances, device_grouped_conv_bwd_data_xdl_f32_noshuffle_instances< 2, @@ -59,7 +59,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( Empty_Tuple, NHWGC, ConvBwdDataFilter1x1Stride1Pad0>{}); - // 5. Default — nongrouped_match instances + // 5. 
Default -- nongrouped_match instances add_device_operation_instances( instances, device_grouped_conv_bwd_data_xdl_f32_nongrouped_match_instances<2, @@ -68,7 +68,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( Empty_Tuple, NHWGC, ConvBwdDataDefault>{}); - // 6. Filter1x1Stride1Pad0 — nongrouped_match instances + // 6. Filter1x1Stride1Pad0 -- nongrouped_match instances add_device_operation_instances(instances, device_grouped_conv_bwd_data_xdl_f32_nongrouped_match_instances< 2, diff --git a/rocm_ck/include/rocm_ck/arch_properties.hpp b/rocm_ck/include/rocm_ck/arch_properties.hpp index 6e58ba2900..baeec61c05 100644 --- a/rocm_ck/include/rocm_ck/arch_properties.hpp +++ b/rocm_ck/include/rocm_ck/arch_properties.hpp @@ -1,7 +1,7 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT // -// Role: types — GPU target properties and target set (consteval bitset). +// Role: types -- GPU target properties and target set (consteval bitset). 
// // UPSTREAM CANDIDATE: This header prototypes functionality that should // eventually live in ck_tile/core/arch/arch_properties.hpp as the single @@ -60,7 +60,7 @@ constexpr TargetProperties properties(GpuTarget target) case GpuTarget::gfx1150: return {.wavefront_size = 32, .arch_family = ArchFamily::RDNA}; case GpuTarget::gfx1151: return {.wavefront_size = 32, .arch_family = ArchFamily::RDNA}; case GpuTarget::_count: - default: throw "unsupported GpuTarget — add a case to properties() for new targets"; + default: throw "unsupported GpuTarget -- add a case to properties() for new targets"; } } @@ -80,7 +80,7 @@ constexpr bool isRDNA(GpuTarget target) constexpr int wavefrontSize(GpuTarget target) { return properties(target).wavefront_size; } // ============================================================================ -// TargetSet — consteval bitset over GpuTarget values +// TargetSet -- consteval bitset over GpuTarget values // ============================================================================ /// Compile-time set of GPU targets. Structural type (usable as NTTP). 
@@ -93,13 +93,13 @@ constexpr int wavefrontSize(GpuTarget target) { return properties(target).wavefr /// Specific: TargetSet::only(GpuTarget::gfx942) /// /// CK Tile mapping: -/// TargetSet::cdna() → enable_if_target_arch_cdna_t -/// TargetSet::rdna() → enable_if_target_arch_rdna_t -/// TargetSet::family_gfx9() → enable_if_target_family_gfx9_t -/// TargetSet::family_gfx11() → __gfx11__ preprocessor grouping -/// TargetSet::family_gfx115() → __gfx115__ preprocessor grouping -/// TargetSet::only(gfx942, gfx950) → enable_if_target_id_t -/// TargetSet::cdna().excluding(gfx90a) → is_any_value_of(T::TARGET_ID, GFX942, GFX950) +/// TargetSet::cdna() -> enable_if_target_arch_cdna_t +/// TargetSet::rdna() -> enable_if_target_arch_rdna_t +/// TargetSet::family_gfx9() -> enable_if_target_family_gfx9_t +/// TargetSet::family_gfx11() -> __gfx11__ preprocessor grouping +/// TargetSet::family_gfx115() -> __gfx115__ preprocessor grouping +/// TargetSet::only(gfx942, gfx950) -> enable_if_target_id_t +/// TargetSet::cdna().excluding(gfx90a) -> is_any_value_of(T::TARGET_ID, GFX942, GFX950) struct TargetSet { uint64_t bits = 0; @@ -114,7 +114,7 @@ struct TargetSet static constexpr int bitIndex(GpuTarget target) { if(target >= GpuTarget::_count) - throw "GpuTarget out of range — value must be a valid enum member, not _count"; + throw "GpuTarget out of range -- value must be a valid enum member, not _count"; return static_cast(target); } @@ -252,7 +252,7 @@ struct TargetSet wf = target_wf; else if(wf != target_wf) throw "wavefront_size() requires all targets in the set to have " - "the same wavefront size — this set mixes wave64 (CDNA) and " + "the same wavefront size -- this set mixes wave64 (CDNA) and " "wave32 (RDNA) targets. 
Split with intersect_with(cdna()) or " "intersect_with(rdna())."; } @@ -290,7 +290,7 @@ struct TargetSet }; // ============================================================================ -// Wave tile validation — single source of truth +// Wave tile validation -- single source of truth // ============================================================================ // Based on CK Tile's WarpGemmDispatcher specializations. // See: ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp (MFMA builtins) @@ -301,16 +301,16 @@ struct TargetSet /// on a specific target. consteval bool isValidWaveTile(DataType a_dtype, int m, int n, int k, GpuTarget target) { - // RDNA targets: WMMA — fixed 16x16x16 tile shape + // RDNA targets: WMMA -- fixed 16x16x16 tile shape if(isRDNA(target)) { if(m != 16 || n != 16 || k != 16) return false; - // RDNA (gfx11xx) WMMA: fp16, bf16, int8 — all targets share 16×16×16 tile + // RDNA (gfx11xx) WMMA: fp16, bf16, int8 -- all targets share 16x16x16 tile return a_dtype == DataType::FP16 || a_dtype == DataType::BF16 || a_dtype == DataType::I8; } - // CDNA MFMA tiles — common across gfx90a, gfx942, gfx950 + // CDNA MFMA tiles -- common across gfx90a, gfx942, gfx950 if(a_dtype == DataType::FP32) { if(m == 16 && n == 16 && (k == 4 || k == 8 || k == 16)) @@ -333,7 +333,7 @@ consteval bool isValidWaveTile(DataType a_dtype, int m, int n, int k, GpuTarget return true; } - // INT8 MFMA — int8x int8→int32 accumulation + // INT8 MFMA -- int8x int8->int32 accumulation if(a_dtype == DataType::I8) { if(m == 32 && n == 32 && k == 16) @@ -342,7 +342,7 @@ consteval bool isValidWaveTile(DataType a_dtype, int m, int n, int k, GpuTarget return true; } - // FP8/BF8 MFMA — architecture-dependent + // FP8/BF8 MFMA -- architecture-dependent if(a_dtype == DataType::FP8_FNUZ || a_dtype == DataType::BF8_FNUZ) { // gfx90a: no FP8 MFMA support @@ -362,9 +362,9 @@ consteval bool isValidWaveTile(DataType a_dtype, int m, int n, int k, GpuTarget return true; } - // FP8_OCP/BF8_OCP — not yet 
supported + // FP8_OCP/BF8_OCP -- not yet supported if(a_dtype == DataType::FP8_OCP || a_dtype == DataType::BF8_OCP) - throw "FP8_OCP/BF8_OCP not yet supported in GEMM — use FP8_FNUZ/BF8_FNUZ"; + throw "FP8_OCP/BF8_OCP not yet supported in GEMM -- use FP8_FNUZ/BF8_FNUZ"; return false; } diff --git a/rocm_ck/include/rocm_ck/gemm_spec.hpp b/rocm_ck/include/rocm_ck/gemm_spec.hpp index e3ff06a61e..669770c7e6 100644 --- a/rocm_ck/include/rocm_ck/gemm_spec.hpp +++ b/rocm_ck/include/rocm_ck/gemm_spec.hpp @@ -1,7 +1,7 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT // -// Role: meta — GemmSpec structural NTTP descriptor and consteval factory. +// Role: meta -- GemmSpec structural NTTP descriptor and consteval factory. // // SHARED header: compiled in both host and device (--cuda-device-only) passes. // Contains structural types, consteval makeSpec() factory, and named accessors. @@ -10,8 +10,8 @@ // Wave tile validation and target properties live in arch_properties.hpp. // // Compilation boundary: -// _spec.hpp (this) — schema types + consteval factory (both passes) -// _dev.hpp — CK Tile bridge + __device__ code (device pass only, #error on host) +// _spec.hpp (this) -- schema types + consteval factory (both passes) +// _dev.hpp -- CK Tile bridge + __device__ code (device pass only, #error on host) #pragma once @@ -38,15 +38,15 @@ namespace rocm_ck { /// epilogue chain. Enum is structural (NTTP-compatible); std::variant is not. 
/// /// Binary ops (Add, Mul) fold over D tensors via parameter pack: -/// Add — result += D0 [+ D1] (bias addition) -/// Mul — result *= D0 [* D1] (scaling) +/// Add -- result += D0 [+ D1] (bias addition) +/// Mul -- result *= D0 [* D1] (scaling) /// /// Unary ops transform the accumulator in place: -/// Relu — max(0, x) -/// FastGelu — approximate GELU: x * sigmoid(1.702 * x) -/// Gelu — exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))) -/// Silu — x * sigmoid(x) (aka Swish with beta=1) -/// Sigmoid — 1 / (1 + exp(-x)) +/// Relu -- max(0, x) +/// FastGelu -- approximate GELU: x * sigmoid(1.702 * x) +/// Gelu -- exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))) +/// Silu -- x * sigmoid(x) (aka Swish with beta=1) +/// Sigmoid -- 1 / (1 + exp(-x)) /// /// Operations compose as an ordered sequence in GemmSpec::epilogue_ops[]. /// The Signature's operator chain (AddOp -> ReluOp) maps directly to this @@ -71,22 +71,22 @@ inline constexpr int kMaxEpilogueOps = 4; /// Pipeline implementation strategy for the GEMM kernel. /// -/// V1: Simple pipeline — A/B from global memory, C in registers. +/// V1: Simple pipeline -- A/B from global memory, C in registers. /// Uses GemmPipelineProblem + GemmPipelineAGmemBGmemCRegV1. /// -/// V3: Compute-optimized pipeline — software-pipelined loads. +/// V3: Compute-optimized pipeline -- software-pipelined loads. /// Uses UniversalGemmPipelineProblem + GemmPipelineAgBgCrCompV3. /// Better compute utilization through overlapped memory/compute. /// -/// V4: Compute double-buffer — ping-pong LDS layout. +/// V4: Compute double-buffer -- ping-pong LDS layout. /// Uses UniversalGemmPipelineProblem + GemmPipelineAgBgCrCompV4. /// Better compute/memory overlap through dual LDS buffers. /// -/// Memory: Memory-optimized pipeline — A/B from global memory through LDS. +/// Memory: Memory-optimized pipeline -- A/B from global memory through LDS. /// Uses UniversalGemmPipelineProblem + GemmPipelineAgBgCrMem. 
/// Supports both Intrawave and Interwave scheduling. /// -/// Preshuffle: Weight preshuffle pipeline — B matrix pre-rearranged for +/// Preshuffle: Weight preshuffle pipeline -- B matrix pre-rearranged for /// optimal LDS loads. Uses WeightPreshufflePipelineAGmemBGmemCRegV2. /// Requires A=RowMajor, B=ColumnMajor. Host must call preshuffle on B /// before kernel launch. @@ -105,11 +105,11 @@ enum class Pipeline /// operations within each wave. This is instruction-level scheduling, /// not spatial decomposition (which is TilePartitioner's concern). /// -/// Intrawave: Synchronous — all waves in a workgroup synchronize after each +/// Intrawave: Synchronous -- all waves in a workgroup synchronize after each /// k-iteration. Memory loads and compute are interleaved within a single wave. /// Two block_sync_lds() calls per iteration. /// -/// Interwave: Asynchronous — waves proceed independently with minimal +/// Interwave: Asynchronous -- waves proceed independently with minimal /// synchronization. Only one block_sync_lds() per iteration. Overlaps /// compute from one wave with memory loads from another. /// Only valid with Pipeline::Memory. @@ -127,7 +127,7 @@ enum class PipelineScheduler /// /// Direct: 2D grid with direct blockIdx mapping. /// Grid: (M/TileM) x (N/TileN) x k_batch. -/// Mapping: blockIdx.x → M tiles, blockIdx.y → N tiles. +/// Mapping: blockIdx.x -> M tiles, blockIdx.y -> N tiles. /// Uses GemmKernel. /// /// Linear: 1D grid with row-major linearized tile indexing (default). @@ -174,7 +174,7 @@ struct Dim3 }; /// Algorithm: describes HOW a GEMM executes (tile geometry, partitioning). -/// Independent of data types — paired with Signature in makeSpec(). +/// Independent of data types -- paired with Signature in makeSpec(). struct GemmAlgorithm { Dim3 block_tile; // Elements per workgroup {M, N, K} @@ -195,31 +195,31 @@ struct GemmAlgorithm /// When true, kernel handles boundaries with bounds checks. 
/// /// Note: K must always be divisible by block_tile.k. CK Tile's kPadK flag only - /// controls vector load width (scalar vs vectorized) — it does NOT mask the K-tail. + /// controls vector load width (scalar vs vectorized) -- it does NOT mask the K-tail. /// Passing non-aligned K produces silently wrong results. bool pad_m = false; bool pad_n = false; }; // ============================================================================ -// GemmSpec — structural NTTP for template instantiation +// GemmSpec -- structural NTTP for template instantiation // ============================================================================ /// Validated kernel descriptor with all types, layouts, and tile geometry resolved. /// All members are structural types (enums, ints, aggregates) so this works as NTTP. /// /// Physical tensor table layout (ordered by args_slot): -/// [0] = lhs (GEMM left operand — name is user-chosen, e.g., "A", "Q") -/// [1] = rhs (GEMM right operand — name is user-chosen, e.g., "B", "K") -/// [2] = output (final output — name varies by epilogue chain) -/// [3] = D0 (optional — first auxiliary epilogue tensor, e.g., "bias") -/// [4] = D1 (optional — second auxiliary epilogue tensor) +/// [0] = lhs (GEMM left operand -- name is user-chosen, e.g., "A", "Q") +/// [1] = rhs (GEMM right operand -- name is user-chosen, e.g., "B", "K") +/// [2] = output (final output -- name varies by epilogue chain) +/// [3] = D0 (optional -- first auxiliary epilogue tensor, e.g., "bias") +/// [4] = D1 (optional -- second auxiliary epilogue tensor) /// /// "D tensor" is CK Tile's convention for auxiliary tensors that participate /// in the epilogue (bias, scale, residual) but are not GEMM operands. 
struct GemmSpec { - // Physical tensor table — the kernel's view of Args::tensors[] + // Physical tensor table -- the kernel's view of Args::tensors[] int num_physical_tensors; std::array physical_tensors; @@ -256,7 +256,7 @@ struct GemmSpec // Quantization group size (0 = not quantized, >0 = elements per group along K) int group_size; - /// Number of auxiliary D tensors (bias, etc.) — excludes scale tensor. + /// Number of auxiliary D tensors (bias, etc.) -- excludes scale tensor. /// Derived from the physical tensor table: total slots minus lhs/rhs/output minus scale. constexpr int numDTensors() const { @@ -282,11 +282,11 @@ struct GemmSpec /// Name varies by epilogue chain: "C" (plain), "D" (with combine), "E" (with activation). constexpr PhysicalTensor output() const { return physical_tensors[2]; } - /// First auxiliary tensor D0 (position 3 — e.g., bias for AddOp). + /// First auxiliary tensor D0 (position 3 -- e.g., bias for AddOp). /// Only valid when num_physical_tensors > 3. constexpr PhysicalTensor d0() const { return physical_tensors[3]; } - /// Second auxiliary tensor D1 (position 4 — e.g., second bias/scale). + /// Second auxiliary tensor D1 (position 4 -- e.g., second bias/scale). /// Only valid when num_physical_tensors > 4. constexpr PhysicalTensor d1() const { return physical_tensors[4]; } @@ -296,10 +296,10 @@ struct GemmSpec }; // ============================================================================ -// Named tensor accessors (consteval — compile-time only) +// Named tensor accessors (consteval -- compile-time only) // ============================================================================ -/// Lookup a physical tensor by name. consteval — compile-time only. +/// Lookup a physical tensor by name. consteval -- compile-time only. /// Used in static_asserts and consteval makeSpec() result inspection. /// For runtime access, use GemmSpec::output() or physical_tensors[] directly. 
consteval PhysicalTensor tensor(const GemmSpec& k, std::string_view name) @@ -310,13 +310,13 @@ consteval PhysicalTensor tensor(const GemmSpec& k, std::string_view name) throw "tensor is not a physical slot in this kernel"; } -/// Slot index lookup by name. consteval — compile-time only. +/// Slot index lookup by name. consteval -- compile-time only. consteval int slot(const GemmSpec& k, std::string_view name) { return tensor(k, name).args_slot; } -/// Dtype lookup by name. consteval — compile-time only. +/// Dtype lookup by name. consteval -- compile-time only. consteval DataType dtype(const GemmSpec& k, std::string_view name) { return tensor(k, name).dtype; } -/// Layout lookup by name. consteval — compile-time only. +/// Layout lookup by name. consteval -- compile-time only. consteval Layout layout(const GemmSpec& k, std::string_view name) { return tensor(k, name).layout; } // ============================================================================ @@ -418,15 +418,15 @@ makeSpec(const Signature& sig, const GemmAlgorithm& algo, const TargetSet& targe bool is_i8 = (a_td.dtype == DataType::I8 || b_td.dtype == DataType::I8); if(is_i8 && acc != DataType::I32) - throw "INT8 GEMM requires I32 accumulator — set GemmOp::acc_dtype = DataType::I32"; + throw "INT8 GEMM requires I32 accumulator -- set GemmOp::acc_dtype = DataType::I32"; if(is_i8 && targets.contains(GpuTarget::gfx90a)) - throw "INT8 GEMM requires gfx942+ — gfx90a emulates int8 MFMA with float MFMA, " + throw "INT8 GEMM requires gfx942+ -- gfx90a emulates int8 MFMA with float MFMA, " "producing corrupted output. 
Use TargetSet::family_gfx94() or exclude gfx90a."; // INT4 rhs requires .quantize (no unquantized INT4 path exists in CK Tile) if(b_td.dtype == DataType::I4 && !b_td.quantize.has_value()) - throw "rhs dtype is I4 but Tensor.quantize is not set — " + throw "rhs dtype is I4 but Tensor.quantize is not set -- " "INT4 requires quantization metadata (scale tensor and group_size)"; // Build epilogue op chain from remaining ops after GemmOp. @@ -488,7 +488,7 @@ makeSpec(const Signature& sig, const GemmAlgorithm& algo, const TargetSet& targe // Direct2D epilogue does not support D tensors if(algo.store_strategy == StoreStrategy::Direct2D && num_d_tensors > 0) - throw "Direct2D epilogue does not support D tensors — use CShuffle or remove binary ops " + throw "Direct2D epilogue does not support D tensors -- use CShuffle or remove binary ops " "(Add/Mul)"; // Tile validation @@ -507,7 +507,7 @@ makeSpec(const Signature& sig, const GemmAlgorithm& algo, const TargetSet& targe // Pipeline-specific constraints if(is_i8 && algo.pipeline == Pipeline::V1) - throw "INT8 GEMM requires V3/V4/Memory pipeline — V1 does not support int8"; + throw "INT8 GEMM requires V3/V4/Memory pipeline -- V1 does not support int8"; if(algo.pipeline == Pipeline::Preshuffle) { diff --git a/rocm_ck/include/rocm_ck/gpu_target.hpp b/rocm_ck/include/rocm_ck/gpu_target.hpp index fa4230b249..35dcc72e67 100644 --- a/rocm_ck/include/rocm_ck/gpu_target.hpp +++ b/rocm_ck/include/rocm_ck/gpu_target.hpp @@ -19,7 +19,7 @@ enum class GpuTarget : uint8_t gfx1102, // RDNA 3 gfx1150, // RDNA 3.5 gfx1151, // RDNA 3.5 - _count // must be last — new targets go above this line + _count // must be last -- new targets go above this line }; } // namespace rocm_ck diff --git a/rocm_ck/include/rocm_ck/ops.hpp b/rocm_ck/include/rocm_ck/ops.hpp index 5472159961..9ac1d991fa 100644 --- a/rocm_ck/include/rocm_ck/ops.hpp +++ b/rocm_ck/include/rocm_ck/ops.hpp @@ -27,7 +27,7 @@ namespace rocm_ck { // Matrix multiplication: out = lhs x rhs. 
// Layout defaults (applied during resolve): lhs=Row, rhs=Col, out=Row. -// acc_dtype is the accumulation type — defaults to FP32, the universal safe +// acc_dtype is the accumulation type -- defaults to FP32, the universal safe // choice across all input types. struct GemmOp { diff --git a/rocm_ck/include/rocm_ck/resolve.hpp b/rocm_ck/include/rocm_ck/resolve.hpp index 080c9a0565..e6512ec73f 100644 --- a/rocm_ck/include/rocm_ck/resolve.hpp +++ b/rocm_ck/include/rocm_ck/resolve.hpp @@ -1,7 +1,7 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT // -// Role: meta — consteval resolve(), C++20 concepts. No runtime, no CK deps. +// Role: meta -- consteval resolve(), C++20 concepts. No runtime, no CK deps. // // Signature resolution: resolves a Signature into concrete tensor descriptors. // @@ -10,8 +10,8 @@ // Tensor entries, and applies the dtype cascade. All at compile time (consteval). // // Op dispatch uses C++20 concepts to classify ops by structural shape: -// BinaryOpLike — has {lhs, rhs, out} string_view members -// UnaryOpLike — has {in, out} string_view members +// BinaryOpLike -- has {lhs, rhs, out} string_view members +// UnaryOpLike -- has {in, out} string_view members // A single visitOp() function is the only place the Op variant type list // appears. Adding a new op requires one line in visitOp(); concepts handle // generic registration and propagation automatically. @@ -29,7 +29,7 @@ namespace rocm_ck { // ============================================================================ -// Op structural concepts — classify ops by their tensor slot shape +// Op structural concepts -- classify ops by their tensor slot shape // ============================================================================ /// Ops with three tensor slots: lhs, rhs, out (e.g., AddOp, MulOp). 
@@ -53,9 +53,9 @@ concept UnaryOpLike = requires(const T& t) { }; // ============================================================================ -// visitOp — single consteval dispatch point for all Op types +// visitOp -- single consteval dispatch point for all Op types // -// Adding a new Op alternative requires no changes here — std::visit +// Adding a new Op alternative requires no changes here -- std::visit // enforces exhaustiveness at compile time via the visitor's operator(). // ============================================================================ @@ -87,7 +87,7 @@ struct ResolvedSignature for(int i = 0; i < num_tensors; ++i) if(tensors[i].name == name) return tensors[i]; - throw "tensor() — name not found in resolved signature; " + throw "tensor() -- name not found in resolved signature; " "check that it appears in an op slot or Tensor entry"; } @@ -97,7 +97,7 @@ struct ResolvedSignature for(int i = 0; i < num_tensors; ++i) if(tensors[i].name == name) return i; - throw "tensorIndex() — name not found in resolved signature; " + throw "tensorIndex() -- name not found in resolved signature; " "check that it appears in an op slot or Tensor entry"; } @@ -107,7 +107,7 @@ struct ResolvedSignature for(int i = 0; i < num_scalars; ++i) if(scalars[i].name == name) return scalars[i]; - throw "scalar() — name not found; " + throw "scalar() -- name not found; " "add a Scalar entry with this name to the Signature"; } @@ -117,7 +117,7 @@ struct ResolvedSignature for(int i = 0; i < num_scalars; ++i) if(scalars[i].name == name) return i; - throw "scalarIndex() — name not found; " + throw "scalarIndex() -- name not found; " "add a Scalar entry with this name to the Signature"; } @@ -143,16 +143,16 @@ struct ResolvedSignature }; // ============================================================================ -// Op slot visitors — concept-driven, generic handling for binary/unary ops +// Op slot visitors -- concept-driven, generic handling for binary/unary ops // 
============================================================================ /// Register all tensor slots of an op. Returns the output tensor name. /// /// Uses concepts for generic dispatch: -/// GemmOp — special case: sets operator-implied rank/layout defaults -/// ScaleOp — special case: validates scalar reference against sig.scalars[] -/// BinaryOpLike — generic: registers lhs, rhs, out -/// UnaryOpLike — generic: registers in, out +/// GemmOp -- special case: sets operator-implied rank/layout defaults +/// ScaleOp -- special case: validates scalar reference against sig.scalars[] +/// BinaryOpLike -- generic: registers lhs, rhs, out +/// UnaryOpLike -- generic: registers in, out /// /// Adding a new BinaryOpLike or UnaryOpLike op requires no changes here. consteval std::string_view collectTensorSlotsFromOp(const Op& op, @@ -192,7 +192,7 @@ consteval std::string_view collectTensorSlotsFromOp(const Op& op, } } if(!found_scalar) - throw "ScaleOp.scale references undeclared Scalar — " + throw "ScaleOp.scale references undeclared Scalar -- " "add a matching Scalar entry to the Signature"; return typed_op.out; } @@ -211,7 +211,7 @@ consteval std::string_view collectTensorSlotsFromOp(const Op& op, } else { - throw "unhandled Op type in collectTensorSlotsFromOp — " + throw "unhandled Op type in collectTensorSlotsFromOp -- " "add explicit handling or satisfy BinaryOpLike/UnaryOpLike"; } }); @@ -244,7 +244,7 @@ consteval void propagateRankLayout(const Op& op, auto& propagate_binary, auto& p } else { - throw "unhandled Op type in propagateRankLayout — " + throw "unhandled Op type in propagateRankLayout -- " "add explicit handling or satisfy BinaryOpLike/UnaryOpLike"; } }); @@ -318,7 +318,7 @@ consteval ResolvedSignature resolve(const Signature& sig) changed = true; } else if(infos[idx].rank != rank) - throw "conflicting rank for tensor — two operators imply different ranks; " + throw "conflicting rank for tensor -- two operators imply different ranks; " "check that 
shared tensor names are intentional"; } if(layout != Layout::Auto) @@ -329,7 +329,7 @@ consteval ResolvedSignature resolve(const Signature& sig) changed = true; } else if(infos[idx].layout != layout) - throw "conflicting layout for tensor — two operators imply different layouts; " + throw "conflicting layout for tensor -- two operators imply different layouts; " "check that shared tensor names are intentional"; } return changed; @@ -415,7 +415,7 @@ consteval ResolvedSignature resolve(const Signature& sig) break; } if(changed) - throw "could not infer rank/layout for all tensors — " + throw "could not infer rank/layout for all tensors -- " "set rank and layout explicitly on Tensor entries, " "or reduce chained operations"; @@ -490,7 +490,7 @@ consteval ResolvedSignature resolve(const Signature& sig) infos[i].quantize_info.group_size}; } - // Collect declared scalars (pass-through — no inference needed) + // Collect declared scalars (pass-through -- no inference needed) for(int i = 0; i < kMaxScalars; ++i) { if(sig.scalars[i].name.empty()) diff --git a/rocm_ck/include/rocm_ck/signature.hpp b/rocm_ck/include/rocm_ck/signature.hpp index 1e950d3e5b..62ca540a67 100644 --- a/rocm_ck/include/rocm_ck/signature.hpp +++ b/rocm_ck/include/rocm_ck/signature.hpp @@ -1,7 +1,7 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT // -// Role: meta — Signature, Tensor, Scalar. No runtime, no CK deps. +// Role: meta -- Signature, Tensor, Scalar. No runtime, no CK deps. // // Signature: the complete description of WHAT a kernel computes. // @@ -65,10 +65,10 @@ struct Scalar /// A directed compute graph where tensors are nodes and operators are edges. /// Each operator output gets a unique name; shared names form graph edges. 
/// -/// Example — simple fp16 GEMM: +/// Example -- simple fp16 GEMM: /// {.dtype = FP16, .ops = {GemmOp{.lhs="A", .rhs="B", .out="C"}}} /// -/// Example — GEMM + bias + ReLU: +/// Example -- GEMM + bias + ReLU: /// {.dtype = FP16, /// .ops = {GemmOp{.lhs="A", .rhs="B", .out="C"}, /// AddOp{.lhs="C", .rhs="bias", .out="D"}, diff --git a/rocm_ck/include/rocm_ck/spec_json.hpp b/rocm_ck/include/rocm_ck/spec_json.hpp index 6f8c58e373..30e1f22309 100644 --- a/rocm_ck/include/rocm_ck/spec_json.hpp +++ b/rocm_ck/include/rocm_ck/spec_json.hpp @@ -1,13 +1,13 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT // -// Role: host — JSON serialization for spec types. No CK deps. +// Role: host -- JSON serialization for spec types. No CK deps. // // Runtime to_json() functions for GemmSpec. // Used by build-time extractors to emit .spec.json files that pack.py // reads to embed structured metadata in the kpack TOC. // -// Hand-written JSON — no library dependency. The schema is fixed and small. +// Hand-written JSON -- no library dependency. The schema is fixed and small. #pragma once diff --git a/rocm_ck/include/rocm_ck/validate.hpp b/rocm_ck/include/rocm_ck/validate.hpp index 86f94f48b5..3cee3e7d93 100644 --- a/rocm_ck/include/rocm_ck/validate.hpp +++ b/rocm_ck/include/rocm_ck/validate.hpp @@ -1,7 +1,7 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT // -// Role: host — debug-only runtime validation. No CK deps. +// Role: host -- debug-only runtime validation. No CK deps. // // Validates Args against a spec's physical tensor table before kernel launch. 
// Catches "forgot to fill a tensor slot" at launch time instead of silent GPU diff --git a/rocm_ck/tests/compile_fail/conflicting_layout.cpp b/rocm_ck/tests/compile_fail/conflicting_layout.cpp index fbae32317c..94b22764ac 100644 --- a/rocm_ck/tests/compile_fail/conflicting_layout.cpp +++ b/rocm_ck/tests/compile_fail/conflicting_layout.cpp @@ -3,8 +3,8 @@ // // Must fail: two GemmOps imply conflicting layouts for the same tensor. // -// GemmOp1 outputs "C" → implied Row layout. -// GemmOp2 uses "C" as rhs → implied Col layout. +// GemmOp1 outputs "C" -> implied Row layout. +// GemmOp2 uses "C" as rhs -> implied Col layout. // These conflict: "C" can't be both Row and Col. // // Expected error: "conflicting layout for tensor" diff --git a/rocm_ck/tests/unit/unit_arch_properties.cpp b/rocm_ck/tests/unit/unit_arch_properties.cpp index 7d45be7c27..d1b8b9247a 100644 --- a/rocm_ck/tests/unit/unit_arch_properties.cpp +++ b/rocm_ck/tests/unit/unit_arch_properties.cpp @@ -93,7 +93,7 @@ TEST(WavefrontSize, MatchesTargetProperties) } // ============================================================================ -// TargetSet — construction +// TargetSet -- construction // ============================================================================ TEST(TargetSet, DefaultConstructsToEmpty) @@ -112,7 +112,7 @@ TEST(TargetSet, ImplicitConversionFromSingleTarget) } // ============================================================================ -// TargetSet — named constructors +// TargetSet -- named constructors // ============================================================================ TEST(TargetSet, AllContainsEveryTarget) @@ -204,7 +204,7 @@ TEST(TargetSet, OnlyWithThreeTargets) // variadic arity of TargetSet::only() overloads (1, 2, 3 parameters). 
// ============================================================================ -// TargetSet — set operations +// TargetSet -- set operations // ============================================================================ TEST(TargetSet, ExcludingRemovesOneTarget) @@ -271,7 +271,7 @@ TEST(TargetSet, EmptyIntersectIsEmpty) } // ============================================================================ -// TargetSet — operators +// TargetSet -- operators // ============================================================================ TEST(TargetSet, OperatorOrDelegatesToUnion) @@ -299,7 +299,7 @@ TEST(TargetSet, InequalityDifferentSetsReturnsTrue) } // ============================================================================ -// TargetSet — queries +// TargetSet -- queries // ============================================================================ TEST(TargetSet, ContainsReturnsTrueForMember) @@ -362,7 +362,7 @@ TEST(TargetSet, IsAllRdnaPredicateWorks) } // ============================================================================ -// TargetSet — iteration +// TargetSet -- iteration // ============================================================================ TEST(TargetSet, ForEachIteratesAllTargets) @@ -380,7 +380,7 @@ TEST(TargetSet, ForEachOnEmptySetDoesNothing) } // ============================================================================ -// isValidWaveTile — single target +// isValidWaveTile -- single target // ============================================================================ TEST(IsValidWaveTile, FP32MFMATilesOnCDNA) @@ -431,7 +431,7 @@ TEST(IsValidWaveTile, FP8FNUZBaseTilesOnGfx942) TEST(IsValidWaveTile, FP8FNUZIterateKTilesOnGfx942Plus) { - // IterateK compositions of base FP8 MFMA — available on gfx942+ + // IterateK compositions of base FP8 MFMA -- available on gfx942+ EXPECT_TRUE(isValidWaveTile(DataType::FP8_FNUZ, 32, 32, 32, GpuTarget::gfx942)); EXPECT_TRUE(isValidWaveTile(DataType::FP8_FNUZ, 32, 32, 64, GpuTarget::gfx942)); 
EXPECT_TRUE(isValidWaveTile(DataType::FP8_FNUZ, 16, 16, 64, GpuTarget::gfx942)); @@ -444,7 +444,7 @@ TEST(IsValidWaveTile, FP8FNUZIterateKTilesOnGfx942Plus) TEST(IsValidWaveTile, WMMATilesOnRDNA) { - // All RDNA targets share identical WMMA: 16×16×16 for FP16, BF16, INT8 + // All RDNA targets share identical WMMA: 16x16x16 for FP16, BF16, INT8 EXPECT_TRUE(isValidWaveTile(DataType::FP16, 16, 16, 16, GpuTarget::gfx1100)); EXPECT_TRUE(isValidWaveTile(DataType::BF16, 16, 16, 16, GpuTarget::gfx1100)); EXPECT_TRUE(isValidWaveTile(DataType::I8, 16, 16, 16, GpuTarget::gfx1100)); @@ -469,7 +469,7 @@ TEST(IsValidWaveTile, WMMARejectsFP32) } // ============================================================================ -// isValidWaveTile — TargetSet (intersection semantics) +// isValidWaveTile -- TargetSet (intersection semantics) // ============================================================================ TEST(IsValidWaveTile, IntersectionAcrossCDNATargets) diff --git a/rocm_ck/tests/unit/unit_gemm_spec.cpp b/rocm_ck/tests/unit/unit_gemm_spec.cpp index c914d9d433..c6e9abdbb8 100644 --- a/rocm_ck/tests/unit/unit_gemm_spec.cpp +++ b/rocm_ck/tests/unit/unit_gemm_spec.cpp @@ -444,30 +444,30 @@ TEST(WaveTileValidation, TargetSetAllMeansIntersectionAcrossAllTargets) // Only 16x16x16 FP16/BF16 pass (valid on both MFMA and WMMA). 
EXPECT_TRUE(isValidWaveTile(DataType::FP16, 16, 16, 16, TargetSet::all())); EXPECT_TRUE(isValidWaveTile(DataType::BF16, 16, 16, 16, TargetSet::all())); - // I8 16x16x16 fails — CDNA MFMA I8 tiles are 32x32x16 and 16x16x32, not 16x16x16 + // I8 16x16x16 fails -- CDNA MFMA I8 tiles are 32x32x16 and 16x16x32, not 16x16x16 EXPECT_FALSE(isValidWaveTile(DataType::I8, 16, 16, 16, TargetSet::all())); - // 32x32 tiles fail — WMMA only has 16x16x16 + // 32x32 tiles fail -- WMMA only has 16x16x16 EXPECT_FALSE(isValidWaveTile(DataType::FP16, 32, 32, 16, TargetSet::all())); - // FP8 fails — gfx90a has no FP8, gfx1151 has no FP8 + // FP8 fails -- gfx90a has no FP8, gfx1151 has no FP8 EXPECT_FALSE(isValidWaveTile(DataType::FP8_FNUZ, 32, 32, 16, TargetSet::all())); EXPECT_FALSE(isValidWaveTile(DataType::FP8_FNUZ, 16, 16, 32, TargetSet::all())); - // FP32 fails — WMMA doesn't support FP32 + // FP32 fails -- WMMA doesn't support FP32 EXPECT_FALSE(isValidWaveTile(DataType::FP32, 16, 16, 4, TargetSet::all())); } TEST(WaveTileValidation, TargetSetCdnaRejectsFP8BecauseGfx90a) { - // cdna() includes gfx90a which has no FP8 — intersection rejects all FP8 tiles + // cdna() includes gfx90a which has no FP8 -- intersection rejects all FP8 tiles EXPECT_FALSE(isValidWaveTile(DataType::FP8_FNUZ, 32, 32, 16, TargetSet::cdna())); EXPECT_FALSE(isValidWaveTile(DataType::FP8_FNUZ, 32, 32, 64, TargetSet::cdna())); } TEST(WaveTileValidation, TargetSetGfx94AcceptsFP8) { - // family_gfx94() = gfx942 + gfx950 — both support FP8 + // family_gfx94() = gfx942 + gfx950 -- both support FP8 EXPECT_TRUE(isValidWaveTile(DataType::FP8_FNUZ, 32, 32, 16, TargetSet::family_gfx94())); EXPECT_TRUE(isValidWaveTile(DataType::FP8_FNUZ, 16, 16, 32, TargetSet::family_gfx94())); // IterateK compositions valid across gfx94 family @@ -706,7 +706,7 @@ TEST(MakeSpec, QuantizedGemmHasZeroDTensors) GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}}, TargetSet::cdna()); - // Scale is NOT a D tensor — num_d_tensors excludes it 
+ // Scale is NOT a D tensor -- num_d_tensors excludes it EXPECT_EQ(k.numDTensors(), 0); EXPECT_EQ(k.num_physical_tensors, 4); // A, B, C, scale } @@ -723,7 +723,7 @@ TEST(MakeSpec, QuantizedGemmAddHasOneDTensor) GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}}, TargetSet::cdna()); - // bias is D0, scale is separate — num_d_tensors counts only bias + // bias is D0, scale is separate -- num_d_tensors counts only bias EXPECT_EQ(k.numDTensors(), 1); EXPECT_EQ(k.num_physical_tensors, 5); // A, B, D, bias, scale } @@ -944,7 +944,7 @@ TEST(MakeSpec, QuantizedGemmWithMultipleEpilogueOps) } // ============================================================================ -// makeSpec: two consecutive AddOps (Add+Add → 2 D tensors) +// makeSpec: two consecutive AddOps (Add+Add -> 2 D tensors) // ============================================================================ TEST(MakeSpec, TwoConsecutiveAddOpsProduceTwoDTensors) @@ -979,7 +979,7 @@ TEST(MakeSpec, AcceptsMaxEpilogueOps) GemmAlgorithm{{128, 128, 32}, {2, 2, 1}, {16, 16, 16}}, TargetSet::cdna()); - // 2 epilogue ops (Add + Relu) — well under the limit of 4 + // 2 epilogue ops (Add + Relu) -- well under the limit of 4 EXPECT_EQ(k.num_epilogue_ops, 2); EXPECT_TRUE(k.hasEpilogueOp(EpilogueOp::Add)); EXPECT_TRUE(k.hasEpilogueOp(EpilogueOp::Relu)); diff --git a/rocm_ck/tests/unit/unit_resolve.cpp b/rocm_ck/tests/unit/unit_resolve.cpp index 2b90f36596..f8a361fd0d 100644 --- a/rocm_ck/tests/unit/unit_resolve.cpp +++ b/rocm_ck/tests/unit/unit_resolve.cpp @@ -122,7 +122,7 @@ TEST(Resolve, AllowsPerTensorRankOverride) TEST(Resolve, AllowsPerTensorLayoutOverride) { - // Override B from default Col to Row (R×R layout) + // Override B from default Col to Row (RxR layout) constexpr auto r = resolve( // Signature{.dtype = DataType::FP16, .tensors = {Tensor{.name = "B", .layout = Layout::Row}}, @@ -135,7 +135,7 @@ TEST(Resolve, AllowsPerTensorLayoutOverride) TEST(Resolve, AllowsMultipleLayoutOverrides) { - // Override both A 
and B (C×C layout) + // Override both A and B (CxC layout) constexpr auto r = resolve( // Signature{.dtype = DataType::FP16, .tensors = {Tensor{.name = "A", .layout = Layout::Col}, @@ -181,9 +181,9 @@ TEST(Resolve, PropagatesRankAndLayoutThroughEpilogueChain) TEST(Resolve, PropagatesRankAndLayoutThroughDiamondDAG) { - // Diamond: GEMM→C splits into two Add paths, then joins. - // C → Add(C,bias1)→D1 ─→ Add(D1,D2)→E - // C → Add(C,bias2)→D2 ─┘ + // Diamond: GEMM->C splits into two Add paths, then joins. + // C -> Add(C,bias1)->D1 --> Add(D1,D2)->E + // C -> Add(C,bias2)->D2 -+ constexpr auto r = resolve( // Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}, @@ -230,13 +230,13 @@ TEST(Resolve, ResolvesStandaloneAddWithoutImpliedRank) } // ============================================================================ -// Conflict detection — redundant identical sets are silent +// Conflict detection -- redundant identical sets are silent // ============================================================================ TEST(Resolve, AllowsRedundantIdenticalLayoutFromTwoGemmOps) { // GemmOp1 outputs "C" as Row. GemmOp2 uses "C" as lhs (also Row). - // Two ops set the same layout → no conflict. + // Two ops set the same layout -> no conflict. constexpr auto r = resolve( // Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}, @@ -249,7 +249,7 @@ TEST(Resolve, AllowsRedundantIdenticalLayoutFromTwoGemmOps) TEST(Resolve, AllowsPropagationThroughAddWithConsistentLayout) { // GemmOp sets C=Row. AddOp connects C to bias and D. - // Propagation sets bias and D to Row (matching C) → no conflict. + // Propagation sets bias and D to Row (matching C) -> no conflict. 
constexpr auto r = resolve( // Signature{.dtype = DataType::FP16, .ops = {GemmOp{.lhs = "A", .rhs = "B", .out = "C"}, @@ -307,7 +307,7 @@ TEST(Resolve, ReportsZeroScalarsWhenNoneDeclared) } // ============================================================================ -// findTensor / findScalar (constexpr, not consteval — returns -1 on miss) +// findTensor / findScalar (constexpr, not consteval -- returns -1 on miss) // ============================================================================ TEST(Resolve, FindsTensorByName) diff --git a/rocm_ck/tests/unit/unit_schema_compatibility.cpp b/rocm_ck/tests/unit/unit_schema_compatibility.cpp index 2d31c75cb7..4052071f1e 100644 --- a/rocm_ck/tests/unit/unit_schema_compatibility.cpp +++ b/rocm_ck/tests/unit/unit_schema_compatibility.cpp @@ -28,7 +28,7 @@ using ::rocm_ck::Signature; using ::rocm_ck::TargetSet; // Frozen baseline tests: these assert ALL fields of each spec variant. -// This is intentionally brittle — adding a new field to GemmSpec will +// This is intentionally brittle -- adding a new field to GemmSpec will // break these tests, forcing explicit review of the change's impact on // existing variants. Update the expected values when making intentional // schema changes. 
diff --git a/rocm_ck/tests/unit/unit_validate.cpp b/rocm_ck/tests/unit/unit_validate.cpp index bf91bfce9f..4d91133204 100644 --- a/rocm_ck/tests/unit/unit_validate.cpp +++ b/rocm_ck/tests/unit/unit_validate.cpp @@ -36,7 +36,7 @@ static constexpr auto test_spec_d0 = TargetSet::cdna()); // ============================================================================ -// validate() — passes when all tensors are filled +// validate() -- passes when all tensors are filled // ============================================================================ TEST(Validate, PassesWhenAllTensorsFilled) @@ -69,7 +69,7 @@ TEST(Validate, PassesWithD0TensorFilled) } // ============================================================================ -// validate() — aborts when a tensor is missing +// validate() -- aborts when a tensor is missing // ============================================================================ #ifndef NDEBUG @@ -101,7 +101,7 @@ TEST(ValidateDeathTest, AbortsOnMissingD0Tensor) TEST(ValidateDeathTest, ReportsFirstMissingTensor) { - // All slots null — should report the first one (lhs = "A", slot 0) + // All slots null -- should report the first one (lhs = "A", slot 0) Args args{}; EXPECT_DEATH(validate(args, test_spec), "tensor \"A\" \\(slot 0\\) has null pointer"); diff --git a/script/check_ascii_only.sh b/script/check_ascii_only.sh new file mode 100755 index 0000000000..0a0062d97d --- /dev/null +++ b/script/check_ascii_only.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# Rejects any byte outside printable ASCII (plus tab, LF, CR) in the +# files passed as arguments. Used both by the local pre-commit hook +# and by the Jenkinsfile "ASCII Only Check" static-check stage. +# +# Usage: ./check_ascii_only.sh ... 
+ +exit_code=0 + +for file in "$@"; do + [[ -f "$file" ]] || continue + if LC_ALL=C grep -qP '[^\x09\x0A\x0D\x20-\x7E]' "$file" 2>/dev/null; then + echo "ERROR: $file contains non-ASCII bytes:" + LC_ALL=C grep -nP '[^\x09\x0A\x0D\x20-\x7E]' "$file" | head -20 + echo " Fix: replace with ASCII (em-dash -> --, smart quotes -> \", arrows -> ->, etc.)" + exit_code=1 + fi +done + +exit $exit_code diff --git a/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_sparse_mma.cpp b/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_sparse_mma.cpp index 57d7441b17..c4fc1415cb 100644 --- a/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_sparse_mma.cpp +++ b/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_sparse_mma.cpp @@ -244,7 +244,7 @@ void sparse_transform_verify( } // Semantic index validation: each 2-bit field in h_idx encodes the original - // slot (0–3) within the group of 4 that the corresponding compressed element + // slot (0-3) within the group of 4 that the corresponding compressed element // came from. Verify that the index is consistent with input and output. // // Note: when a group has fewer than 2 non-zeros, unused output slots contain diff --git a/test/ck_tile/data_type/test_bf16.cpp b/test/ck_tile/data_type/test_bf16.cpp index d1de2da438..81ae1ec7a5 100644 --- a/test/ck_tile/data_type/test_bf16.cpp +++ b/test/ck_tile/data_type/test_bf16.cpp @@ -426,7 +426,7 @@ TEST_F(Bf16ConversionTest, DenormalHandling) TEST_F(Bf16ConversionTest, OverflowHandling) { // Note: BF16 has the same 8-bit exponent as float32, but only 7 mantissa bits vs 23. - // This means bf16::max (0x7F7F) ≈ 3.39e38 is LESS than float::max ≈ 3.40e38. + // This means bf16::max (0x7F7F) ~= 3.39e38 is LESS than float::max ~= 3.40e38. 
// // Hardware behavior differs by architecture: // - gfx950: RTN rounding -> float::max rounds to infinity (IEEE-754 compliant) @@ -1611,7 +1611,7 @@ TEST_F(Bf16PlatformTest, DevicePerformance) std::cout << "\n=== Performance Results ===" << std::endl; std::cout << "Float to BF16 conversion:" << std::endl; - std::cout << " Average time: " << avg_time << " μs" << std::endl; + std::cout << " Average time: " << avg_time << " us" << std::endl; std::cout << " Throughput: " << throughput << " GB/s" << std::endl; // Initialize bf16 data for arithmetic test @@ -1632,7 +1632,7 @@ TEST_F(Bf16PlatformTest, DevicePerformance) throughput = (3 * n * sizeof(bf16_t)) / (avg_time * 1e6); // GB/s std::cout << "\nBF16 arithmetic operations:" << std::endl; - std::cout << " Average time: " << avg_time << " μs" << std::endl; + std::cout << " Average time: " << avg_time << " us" << std::endl; std::cout << " Throughput: " << throughput << " GB/s" << std::endl; std::cout << "===========================" << std::endl; diff --git a/test/ck_tile/data_type/test_mx_scale.cpp b/test/ck_tile/data_type/test_mx_scale.cpp index 2a7d820c7e..c37a979fc6 100644 --- a/test/ck_tile/data_type/test_mx_scale.cpp +++ b/test/ck_tile/data_type/test_mx_scale.cpp @@ -489,7 +489,7 @@ void test_f8_pkscale_type_convert_device() // Option 0: Fixed pattern with wide dynamic range (same for all rows) for(int m = 0; m < M; m++) { - fscale[m * N_scale + 0] = std::pow(2.0f, -10.0f); // 2^-10 ≈ 0.000977 + fscale[m * N_scale + 0] = std::pow(2.0f, -10.0f); // 2^-10 ~= 0.000977 fscale[m * N_scale + 1] = std::pow(2.0f, -5.0f); // 2^-5 = 0.03125 fscale[m * N_scale + 2] = std::pow(2.0f, 8.0f); // 2^8 = 256 fscale[m * N_scale + 3] = std::pow(2.0f, 16.0f); // 2^16 = 65536 diff --git a/test/ck_tile/data_type/test_pk_fp4.cpp b/test/ck_tile/data_type/test_pk_fp4.cpp index 8c5084dd20..07f3dc1ace 100644 --- a/test/ck_tile/data_type/test_pk_fp4.cpp +++ b/test/ck_tile/data_type/test_pk_fp4.cpp @@ -384,7 +384,7 @@ void 
test_pkscale_type_convert_device() // Option 0: Fixed pattern with wide dynamic range (same for all rows) for(int m = 0; m < M; m++) { - fscale[m * N_scale + 0] = std::pow(2.0f, -10.0f); // 2^-10 ≈ 0.000977 + fscale[m * N_scale + 0] = std::pow(2.0f, -10.0f); // 2^-10 ~= 0.000977 fscale[m * N_scale + 1] = std::pow(2.0f, -5.0f); // 2^-5 = 0.03125 fscale[m * N_scale + 2] = std::pow(2.0f, 8.0f); // 2^8 = 256 fscale[m * N_scale + 3] = std::pow(2.0f, 16.0f); // 2^16 = 65536 diff --git a/test/ck_tile/data_type/test_pk_fp6.cpp b/test/ck_tile/data_type/test_pk_fp6.cpp index 7d852ce8d5..5d962325ab 100644 --- a/test/ck_tile/data_type/test_pk_fp6.cpp +++ b/test/ck_tile/data_type/test_pk_fp6.cpp @@ -703,10 +703,10 @@ void test_pkscale_type_convert_device() if(scale_init_option == 0) { // Option 0: Fixed pattern with wide dynamic range (same for all rows) - // Note: Values chosen to be safe for fp16 (max ≈ 65504) + // Note: Values chosen to be safe for fp16 (max ~= 65504) for(int m = 0; m < M; m++) { - fscale[m * N_scale + 0] = std::pow(2.0f, -10.0f); // 2^-10 ≈ 0.000977 + fscale[m * N_scale + 0] = std::pow(2.0f, -10.0f); // 2^-10 ~= 0.000977 fscale[m * N_scale + 1] = std::pow(2.0f, -5.0f); // 2^-5 = 0.03125 fscale[m * N_scale + 2] = std::pow(2.0f, 8.0f); // 2^8 = 256 fscale[m * N_scale + 3] = std::pow(2.0f, 15.0f); // 2^15 = 32768 (safe for fp16) diff --git a/test/ck_tile/fmha/test_fmha_bwd.cpp b/test/ck_tile/fmha/test_fmha_bwd.cpp index 8d90ad9143..632e179dd1 100644 --- a/test/ck_tile/fmha/test_fmha_bwd.cpp +++ b/test/ck_tile/fmha/test_fmha_bwd.cpp @@ -658,7 +658,7 @@ INSTANTIATE_TEST_SUITE_P( ValuesIn([]() { std::vector test_cases; - // Minimal waste: ~1-5% padding (logical ≈ physical - small delta) + // Minimal waste: ~1-5% padding (logical ~= physical - small delta) test_cases.push_back( std::tuple{2, 2, 2, 127, 127, "0"}); // Q:127->128 (~0.8%), K:127->128 test_cases.push_back( diff --git a/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_fwd.cpp 
b/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_fwd.cpp index d9ad9559bf..0574953198 100644 --- a/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_fwd.cpp +++ b/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_fwd.cpp @@ -223,7 +223,7 @@ TEST_F(GroupedConvFwdIsSupportedArgumentTest, SplitImageFullImageAfterMakeKernel 2, true /*EnableSplitImage*/>::type; - // K=64, C=64 — MakeKernelArgs should set up a single-piece split_image. + // K=64, C=64 - MakeKernelArgs should set up a single-piece split_image. // 3x3 filter, stride 1, no padding => output H/W = 5x5. auto host_args = create_2d_fwd_host_args(1, 2, 64, 64, 3, 3, 7, 7); auto kargs = Kernel::MakeKernelArgs(host_args); @@ -280,7 +280,7 @@ TEST_F(GroupedConvFwdIsSupportedArgumentTest, SplitImageFullImageLargeK) 2, true /*EnableSplitImage*/>::type; - // K=96 — the case that caused flaky failures + // K=96 - the case that caused flaky failures // 1x1 filter, stride 1, no padding => output H/W = 73x128. auto host_args = create_2d_fwd_host_args(3, 5, 96, 200, 1, 1, 73, 128); auto kargs = Kernel::MakeKernelArgs(host_args); diff --git a/test/ck_tile/multicast_load/test_cluster_load_async_to_lds.cpp b/test/ck_tile/multicast_load/test_cluster_load_async_to_lds.cpp index 1519597219..cd3ba7702c 100644 --- a/test/ck_tile/multicast_load/test_cluster_load_async_to_lds.cpp +++ b/test/ck_tile/multicast_load/test_cluster_load_async_to_lds.cpp @@ -154,12 +154,12 @@ TEST(AsyncLDS, B128_SingleWGP) // 4 waves per WG (128 threads). Wave 0 issues the async cluster load into // LDS[0..31], then all waves synchronize via block_sync_lds_direct_load // (which waits ASYNCcnt=0 then does s_barrier_signal/wait). -// Waves 1–3 read from the same LDS buffer after the barrier. +// Waves 1-3 read from the same LDS buffer after the barrier. // Verifies the core guarantee: non-requesting waves see correct LDS data. 
// // block_sync_lds_direct_load<0>() is used for all waves: // - wave 0: asynccnt may be non-zero; it waits before signaling the barrier -// - waves 1–3: asynccnt is already 0 (no-op), then they signal and wait +// - waves 1-3: asynccnt is already 0 (no-op), then they signal and wait // The barrier ensures LDS writes from wave 0 are visible to all waves before // any wave reads from LDS. @@ -184,7 +184,7 @@ struct LDSVisibilityKernel } // All waves call block_sync_lds_direct_load: it issues s_wait_asynccnt (a - // no-op for waves 1–3 whose count is already 0), then s_barrier_signal/wait. + // no-op for waves 1-3 whose count is already 0), then s_barrier_signal/wait. // After this call all waves are past the barrier and LDS is safe to read. ck_tile::block_sync_lds_direct_load<0>(); @@ -827,7 +827,7 @@ TEST(PartialBroadcast, B32_4WGP_Mask0x5) // --------------------------------------------------------------------------- // 4 WGPs in a cluster, 4 waves per WG (128 threads). Wave 0 of each WG issues // cluster_multicast_load_async_to_lds (true broadcast: all lanes load from the -// same source address). After block_sync_lds_direct_load, waves 1–3 read from +// same source address). After block_sync_lds_direct_load, waves 1-3 read from // the same LDS buffer and write to global for host verification. // // This is the canonical GEMM prefetch pattern: @@ -836,12 +836,12 @@ TEST(PartialBroadcast, B32_4WGP_Mask0x5) // // Groups 2 and 4 test LDS visibility and multi-WGP broadcast in isolation; // this group tests the combination. A bug where the barrier doesn't fence -// the async LDS write from wave 0 before waves 1–3 read would appear here +// the async LDS write from wave 0 before waves 1-3 read would appear here // but not in Groups 2 or 4 individually. 
// // Verification: // - Wave 0: each lane loaded src_val -> lds_buf[lane] (confirmed via dst) -// - Waves 1–3: each lane read lds_buf[lane_id] = src_val (cross-wave visibility) +// - Waves 1-3: each lane read lds_buf[lane_id] = src_val (cross-wave visibility) // - All WGPs: same src_val in every LDS slot (multi-WGP broadcast) template @@ -875,7 +875,7 @@ struct MultiWGPLDSVisibilityKernel } // All waves call block_sync_lds_direct_load: it issues s_wait_asynccnt - // (a no-op for waves 1–3 whose count is already 0), then + // (a no-op for waves 1-3 whose count is already 0), then // s_barrier_signal/wait. Barrier ensures LDS is visible to all waves // before any wave reads from lds_buf. ck_tile::block_sync_lds_direct_load<0>(); @@ -927,7 +927,7 @@ void run_multiwgp_lds_visibility_test(int num_wgs, const T& src_val, const char* EXPECT_EQ(h_diag_ids[i], i) << test_name << ": blockIdx.x=" << i << " expected flat_id=" << i << " got " << h_diag_ids[i]; - // Every thread in every WGP must read src_val from LDS (waves 0–3, all WGPs). + // Every thread in every WGP must read src_val from LDS (waves 0-3, all WGPs). 
for(int wgp = 0; wgp < num_wgs; wgp++) { for(int wave = 0; wave < 4; wave++) diff --git a/test/data_type/test_bhalf.cpp b/test/data_type/test_bhalf.cpp index a933d94a33..ef3d58e04b 100644 --- a/test/data_type/test_bhalf.cpp +++ b/test/data_type/test_bhalf.cpp @@ -349,7 +349,7 @@ TYPED_TEST(BhalfConvertTest, F16RoundTrip) {0x3880u, f16_minnorm}, // fp16 min normal (2^-14) -> bf16 exact {0x4780u, f16_max, false, true}, // fp16 max (65504) -> bf16 65536 -> fp16 +inf // normal values spanning fp16 range, exact in both fp16 and bf16 - {0xBA80u, 0x9400u}, // -2^-10 ≈ -9.77e-4 + {0xBA80u, 0x9400u}, // -2^-10 ~= -9.77e-4 {0x3C80u, 0x2400u}, // 2^-6 = 0.015625 {0xC060u, 0xC300u}, // -3.5 {0x3F80u, 0x3C00u}, // 1.0 diff --git a/test/s_prefetch_inst_op/s_prefetch_inst_op_util.hpp b/test/s_prefetch_inst_op/s_prefetch_inst_op_util.hpp index d8100e0f80..8fd40dfed6 100644 --- a/test/s_prefetch_inst_op/s_prefetch_inst_op_util.hpp +++ b/test/s_prefetch_inst_op/s_prefetch_inst_op_util.hpp @@ -1,217 +1,217 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/host_utility/hip_check_error.hpp" -#include "ck/host_utility/kernel_launch.hpp" -#include "ck/host_utility/flush_cache.hpp" - -#include "ck_tile/core/arch/inst_prefetch.hpp" - -#include - -namespace ck { -namespace s_prefetch_inst_op_util { - -template -struct KernelArgs -{ - const T* p_a_grid; - T* dst; - const T* p_b_grid; - uint32_t num_iters; -}; - -// --------------------------------------------------------------------------- -// A simple kernel that exercises INST_PREFETCH / INST_PREFETCH_TARGET macros. 
-// -// The kernel does: dst[tid] = src[tid] + scalar_sum -// -// Between the prefetch site and the target we place a deliberate computation -// loop so that the prefetched instruction cache lines have time to arrive. -// Correctness does not depend on prefetching — it is pure performance hint. -// We verify correctness to ensure the asm volatile markers do not break -// code generation. -// --------------------------------------------------------------------------- - -template -__global__ void kernel_with_inst_prefetch(KernelArgs args) -{ - if constexpr(prefetch_inst_on) - { - enable_scalar_prefetch(); - // Prefetch the tail section of this kernel into L1I. - // We try to load 32 cachelines but gets clamped to smaller number inside if needed, to not - // go oob - INST_PREFETCH(INST_TEST_TAIL, 32); - } - - __builtin_amdgcn_sched_barrier(0); - - const T* src = args.p_a_grid; - T* dst = args.dst; - uint32_t num_iters = args.num_iters; - - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - - T sum = 0; - - if(tid < NUM_THREADS) - { - sum += src[tid]; - } - - __builtin_amdgcn_sched_barrier(0); - - // Hot loop — the PLACE target sits after it. - for(uint32_t iter = 0; iter < num_iters; ++iter) - { -#pragma unroll NUM_SCALARS - for(uint32_t i = 0; i < NUM_SCALARS; ++i) - { - sum += 1; - } - } - - __builtin_amdgcn_sched_barrier(0); - - INST_PREFETCH_TARGET(INST_TEST_TAIL, CK_PLACE_MODE_BLOCK_ENTRY); - -// Tail section (the code we prefetched). 
-#pragma unroll NUM_THREADS - for(uint32_t i = 0; i < NUM_THREADS; ++i) - { - sum += src[0]; - } - - if(tid < NUM_THREADS) - { - dst[tid] = sum; - } -} - -template -bool test_inst_prefetch_impl(bool time_kernels, const std::string& test_name) -{ - constexpr index_t num_elements = NUM_THREADS; - constexpr index_t num_scalars = 1; - constexpr index_t num_scalar_additions = NUM_SCALARS; - constexpr index_t block_size = 256; - constexpr index_t grid_size = (NUM_THREADS + block_size - 1) / block_size; - constexpr uint32_t num_iters = - 0; // always 0 so we jump to the prefetch target right after the barrier. - - std::cout << "Testing " << test_name << " for type: " << typeid(T).name() << std::endl; - std::cout << "Elements: " << num_elements - << ", Scalar additions(the one we jump over): " << num_scalar_additions << std::endl; - - // Host data - std::vector h_src(num_elements); - std::vector h_scalar(num_scalars); - std::vector h_dst(NUM_THREADS); - std::vector h_expected(NUM_THREADS); - - for(index_t i = 0; i < num_elements; i++) - { - h_src[i] = static_cast((i % 100) + 1); - } - - for(index_t i = 0; i < static_cast(NUM_THREADS); i++) - { - h_expected[i] = h_src[i] + h_src[0] * NUM_THREADS; - } - - DeviceMem d_src(sizeof(T) * num_elements); - DeviceMem d_scalar(sizeof(T) * num_scalars); - DeviceMem d_dst(sizeof(T) * NUM_THREADS); - - d_src.ToDevice(h_src.data()); - d_scalar.ToDevice(h_scalar.data()); - - KernelArgs args{static_cast(d_src.GetDeviceBuffer()), - static_cast(d_dst.GetDeviceBuffer()), - static_cast(d_scalar.GetDeviceBuffer()), - num_iters}; - - auto run_single = [&](auto kernel_fn, const std::string& label) -> std::pair { - float avg_us = 0; - - if(time_kernels) - { - constexpr int num_warmup = 10; - constexpr int num_iterations = 50; - constexpr int rotating_count = num_iterations; - auto size_a_buffer = d_src.GetBufferSize(); - auto size_b_buffer = d_scalar.GetBufferSize(); - - ck::utility::RotatingMemWrapper> rotating_mem( - args, rotating_count, 
size_a_buffer, size_b_buffer); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - ck::utility::flush_icache(); - rotating_mem.Next(); - }; - float avg_time_ms = ck::utility::launch_and_time_kernel_with_preprocess( - StreamConfig{nullptr, true, 0, num_warmup, num_iterations, true, rotating_count}, - run_flush_cache, - kernel_fn, - dim3(grid_size), - dim3(block_size), - 0, - args); - - avg_us = avg_time_ms * 1000.0f; - std::cout << " " << label << ": avg " << avg_us << " us" << std::endl; - } - else - { - launch_and_time_kernel(StreamConfig{nullptr, false}, - kernel_fn, - dim3(grid_size), - dim3(block_size), - 0, - args); - } - - d_dst.FromDevice(h_dst.data()); - bool pass = ck::utils::check_err(h_dst, h_expected); - std::cout << " " << label << " correctness: " << (pass ? "PASS" : "FAIL") << std::endl; - return {pass, avg_us}; - }; - - bool pass = true; - auto [pass_pf1, time_pf1] = - run_single(kernel_with_inst_prefetch, "single_prefetch"); - auto [pass_base, time_base] = run_single( - kernel_with_inst_prefetch, "no_prefetch (baseline)"); - - pass &= pass_base; - pass &= pass_pf1; - - if(time_kernels && time_base > 0) - { - auto pct = [&](float t) { return (t - time_base) / time_base * 100.0f; }; - std::cout << " --- Performance ---" << std::endl; - std::cout << " no_prefetch (baseline): " << time_base << " us" << std::endl; - std::cout << " single_prefetch: " << time_pf1 << " us (" << pct(time_pf1) - << " %)" << std::endl; - } - - std::cout << std::endl; - return pass; -} - -} // namespace s_prefetch_inst_op_util -} // namespace ck +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +#include "ck_tile/core/arch/inst_prefetch.hpp" + +#include + +namespace ck { +namespace s_prefetch_inst_op_util { + +template +struct KernelArgs +{ + const T* p_a_grid; + T* dst; + const T* p_b_grid; + uint32_t num_iters; +}; + +// --------------------------------------------------------------------------- +// A simple kernel that exercises INST_PREFETCH / INST_PREFETCH_TARGET macros. +// +// The kernel does: dst[tid] = src[tid] + scalar_sum +// +// Between the prefetch site and the target we place a deliberate computation +// loop so that the prefetched instruction cache lines have time to arrive. +// Correctness does not depend on prefetching -- it is pure performance hint. +// We verify correctness to ensure the asm volatile markers do not break +// code generation. +// --------------------------------------------------------------------------- + +template +__global__ void kernel_with_inst_prefetch(KernelArgs args) +{ + if constexpr(prefetch_inst_on) + { + enable_scalar_prefetch(); + // Prefetch the tail section of this kernel into L1I. + // We try to load 32 cachelines but gets clamped to smaller number inside if needed, to not + // go oob + INST_PREFETCH(INST_TEST_TAIL, 32); + } + + __builtin_amdgcn_sched_barrier(0); + + const T* src = args.p_a_grid; + T* dst = args.dst; + uint32_t num_iters = args.num_iters; + + uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; + + T sum = 0; + + if(tid < NUM_THREADS) + { + sum += src[tid]; + } + + __builtin_amdgcn_sched_barrier(0); + + // Hot loop -- the PLACE target sits after it. 
+ for(uint32_t iter = 0; iter < num_iters; ++iter) + { +#pragma unroll NUM_SCALARS + for(uint32_t i = 0; i < NUM_SCALARS; ++i) + { + sum += 1; + } + } + + __builtin_amdgcn_sched_barrier(0); + + INST_PREFETCH_TARGET(INST_TEST_TAIL, CK_PLACE_MODE_BLOCK_ENTRY); + +// Tail section (the code we prefetched). +#pragma unroll NUM_THREADS + for(uint32_t i = 0; i < NUM_THREADS; ++i) + { + sum += src[0]; + } + + if(tid < NUM_THREADS) + { + dst[tid] = sum; + } +} + +template +bool test_inst_prefetch_impl(bool time_kernels, const std::string& test_name) +{ + constexpr index_t num_elements = NUM_THREADS; + constexpr index_t num_scalars = 1; + constexpr index_t num_scalar_additions = NUM_SCALARS; + constexpr index_t block_size = 256; + constexpr index_t grid_size = (NUM_THREADS + block_size - 1) / block_size; + constexpr uint32_t num_iters = + 0; // always 0 so we jump to the prefetch target right after the barrier. + + std::cout << "Testing " << test_name << " for type: " << typeid(T).name() << std::endl; + std::cout << "Elements: " << num_elements + << ", Scalar additions(the one we jump over): " << num_scalar_additions << std::endl; + + // Host data + std::vector h_src(num_elements); + std::vector h_scalar(num_scalars); + std::vector h_dst(NUM_THREADS); + std::vector h_expected(NUM_THREADS); + + for(index_t i = 0; i < num_elements; i++) + { + h_src[i] = static_cast((i % 100) + 1); + } + + for(index_t i = 0; i < static_cast(NUM_THREADS); i++) + { + h_expected[i] = h_src[i] + h_src[0] * NUM_THREADS; + } + + DeviceMem d_src(sizeof(T) * num_elements); + DeviceMem d_scalar(sizeof(T) * num_scalars); + DeviceMem d_dst(sizeof(T) * NUM_THREADS); + + d_src.ToDevice(h_src.data()); + d_scalar.ToDevice(h_scalar.data()); + + KernelArgs args{static_cast(d_src.GetDeviceBuffer()), + static_cast(d_dst.GetDeviceBuffer()), + static_cast(d_scalar.GetDeviceBuffer()), + num_iters}; + + auto run_single = [&](auto kernel_fn, const std::string& label) -> std::pair { + float avg_us = 0; + + 
if(time_kernels) + { + constexpr int num_warmup = 10; + constexpr int num_iterations = 50; + constexpr int rotating_count = num_iterations; + auto size_a_buffer = d_src.GetBufferSize(); + auto size_b_buffer = d_scalar.GetBufferSize(); + + ck::utility::RotatingMemWrapper> rotating_mem( + args, rotating_count, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + ck::utility::flush_icache(); + rotating_mem.Next(); + }; + float avg_time_ms = ck::utility::launch_and_time_kernel_with_preprocess( + StreamConfig{nullptr, true, 0, num_warmup, num_iterations, true, rotating_count}, + run_flush_cache, + kernel_fn, + dim3(grid_size), + dim3(block_size), + 0, + args); + + avg_us = avg_time_ms * 1000.0f; + std::cout << " " << label << ": avg " << avg_us << " us" << std::endl; + } + else + { + launch_and_time_kernel(StreamConfig{nullptr, false}, + kernel_fn, + dim3(grid_size), + dim3(block_size), + 0, + args); + } + + d_dst.FromDevice(h_dst.data()); + bool pass = ck::utils::check_err(h_dst, h_expected); + std::cout << " " << label << " correctness: " << (pass ? 
"PASS" : "FAIL") << std::endl; + return {pass, avg_us}; + }; + + bool pass = true; + auto [pass_pf1, time_pf1] = + run_single(kernel_with_inst_prefetch, "single_prefetch"); + auto [pass_base, time_base] = run_single( + kernel_with_inst_prefetch, "no_prefetch (baseline)"); + + pass &= pass_base; + pass &= pass_pf1; + + if(time_kernels && time_base > 0) + { + auto pct = [&](float t) { return (t - time_base) / time_base * 100.0f; }; + std::cout << " --- Performance ---" << std::endl; + std::cout << " no_prefetch (baseline): " << time_base << " us" << std::endl; + std::cout << " single_prefetch: " << time_pf1 << " us (" << pct(time_pf1) + << " %)" << std::endl; + } + + std::cout << std::endl; + return pass; +} + +} // namespace s_prefetch_inst_op_util +} // namespace ck diff --git a/tutorial/ck_tile/tile_distribution/tile_distribution_1.cpp b/tutorial/ck_tile/tile_distribution/tile_distribution_1.cpp index a764677c90..adc448746e 100644 --- a/tutorial/ck_tile/tile_distribution/tile_distribution_1.cpp +++ b/tutorial/ck_tile/tile_distribution/tile_distribution_1.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: MIT /* - * Tutorial: CK Tile Distribution Encoding — A Matrix DRAM Load + * Tutorial: CK Tile Distribution Encoding -- A Matrix DRAM Load * * Demonstrates how tile_distribution_encoding maps threads to A-matrix * elements during a DRAM load in the naive GEMM tutorial. @@ -10,7 +10,7 @@ * Source: block_gemm_pipeline_agmem_bgmem_creg_policy.hpp * MakeADramTileDistribution(), with fp16, BlockSize=256 * - * Tile: M=256 × K=32 (matches the naive GEMM's A block tile) + * Tile: M=256 x K=32 (matches the naive GEMM's A block tile) * Threads: 256 (4 warps on CDNA, 8 on RDNA) * * Host initialises A with sequential values 0, 1, 2, ... (row-major). @@ -22,7 +22,7 @@ * The distribution encoding is hardcoded to match the fp16 derivation * (K1=16/sizeof(fp16)=8), not recomputed from sizeof(int32_t). * - * No compute is performed — this is purely about data movement. 
+ * No compute is performed -- this is purely about data movement. * * Note: Comments and values assume CDNA (warp_size=64). On RDNA (warp_size=32), * the thread-to-data mapping will differ. @@ -37,89 +37,89 @@ using namespace ck_tile; // ============================================================================ // THE GOAL // ============================================================================ -// Matrix A: M=256 rows × K=32 columns, stored in DRAM (row-major, fp16). +// Matrix A: M=256 rows x K=32 columns, stored in DRAM (row-major, fp16). // Load the entire tile into registers using 256 threads (4 warps on CDNA). // // For coalesced memory access with fp16, each lane loads 8 contiguous -// K-values (8 × 2 bytes = 16 bytes = 128 bits). Since K=32, we need +// K-values (8 x 2 bytes = 16 bytes = 128 bits). Since K=32, we need // 32/8 = 4 lanes to cover one row: // // lane 0: K=0..7 lane 1: K=8..15 lane 2: K=16..23 lane 3: K=24..31 -// └──────────────── one row of 32 K-columns ──────────────────────────────┘ +// +---------------- one row of 32 K-columns ------------------------------+ // // With warp_size=64, each warp has 64 lanes. 4 lanes per row means -// 64/4 = 16 rows per warp. With 4 warps, one pass covers 4×16 = 64 rows. +// 64/4 = 16 rows per warp. With 4 warps, one pass covers 4x16 = 64 rows. // To cover all 256 rows, each thread iterates M0 = 256/64 = 4 times. // -// Per-thread buffer = 4 iterations × 8 K-values = 32 elements. +// Per-thread buffer = 4 iterations x 8 K-values = 32 elements. // -// Visually for warp 0 (lanes 0–63): +// Visually for warp 0 (lanes 0-63): // -// A matrix (256×32) lane_id decomposition -// ──────────────── ────────────────────── +// A matrix (256x32) lane_id decomposition +// ---------------- ---------------------- // row 0: [ K=0..7 | 8..15 | 16..23 | 24..31 ] -// L0 L1 L2 L3 ← iter 0 +// L0 L1 L2 L3 <- iter 0 // row 1: [ K=0..7 | 8..15 | 16..23 | 24..31 ] // L4 L5 L6 L7 // ... 
-// row 15: same pattern, lanes 60–63 -// ────── stride of 64 rows (4 warps × 16 rows/warp) ────── -// row 64: L0..L3 ← iter 1 +// row 15: same pattern, lanes 60-63 +// ------ stride of 64 rows (4 warps x 16 rows/warp) ------ +// row 64: L0..L3 <- iter 1 // ... -// row 128: L0..L3 ← iter 2 +// row 128: L0..L3 <- iter 2 // ... -// row 192: L0..L3 ← iter 3 +// row 192: L0..L3 <- iter 3 // // ============================================================================ // THE SOLUTION: tile_distribution_encoding // ============================================================================ // // Production code derives (fp16, BlockSize=256, MPerBlock=256, KPerBlock=32): -// K1 = 16/sizeof(fp16) = 8 → vector load width (8 values) -// K0 = KPerBlock/K1 = 4 → 4 K-chunks per row -// M2 = warp_size/K0 = 16 → 16 rows per warp -// M1 = BlockSize/warp_size = 4 → 4 warps -// M0 = MPerBlock/(M2*M1) = 4 → 4 iterations +// K1 = 16/sizeof(fp16) = 8 -> vector load width (8 values) +// K0 = KPerBlock/K1 = 4 -> 4 K-chunks per row +// M2 = warp_size/K0 = 16 -> 16 rows per warp +// M1 = BlockSize/warp_size = 4 -> 4 warps +// M0 = MPerBlock/(M2*M1) = 4 -> 4 iterations // -// Step 1 — Hierarchical dimensions (Hs): factor each axis. +// Step 1 -- Hierarchical dimensions (Hs): factor each axis. // -// Hs[0] = sequence<4, 4, 16> → M = 4 × 4 × 16 = 256 -// Hs[1] = sequence<4, 8> → K = 4 × 8 = 32 +// Hs[0] = sequence<4, 4, 16> -> M = 4 x 4 x 16 = 256 +// Hs[1] = sequence<4, 8> -> K = 4 x 8 = 32 // // Hs[0] Hs[1] -// ┌─────┼─────┐ ┌───┴───┐ +// +-----+-----+ +---+---+ // level 0 level 1 level 2 level 0 level 1 // = 4 = 4 = 16 = 4 = 8 // -// Step 2 — Parallel dimensions (Ps): NDimP=2 (P0=warp_id, P1=lane_id). +// Step 2 -- Parallel dimensions (Ps): NDimP=2 (P0=warp_id, P1=lane_id). 
// -// P0 = warp_id → Hs[0][1] = 4 (which warp → which M-group) -// P1 = lane_id → Hs[0][2]=16 AND Hs[1][0]=4 (merged, total=64) +// P0 = warp_id -> Hs[0][1] = 4 (which warp -> which M-group) +// P1 = lane_id -> Hs[0][2]=16 AND Hs[1][0]=4 (merged, total=64) // // The merge transform decomposes lane_id: // row_in_warp = lane_id / 4 (0..15, outer) -// k_chunk = lane_id % 4 (0..3, inner → coalesced!) +// k_chunk = lane_id % 4 (0..3, inner -> coalesced!) // // Ps_major = tuple, sequence<1, 2>> // Ps_minor = tuple, sequence<2, 0>> // -// How to read Ps: the tuple has 2 elements → NDimP=2. +// How to read Ps: the tuple has 2 elements -> NDimP=2. // First element = P0 = warp_id // Second element = P1 = lane_id // // Ps_major = tuple< seq<1>, seq<1, 2> > -// ─P0(warp)─ ─P1(lane)── +// -P0(warp)- -P1(lane)-- // Ps_minor = tuple< seq<1>, seq<2, 0> > -// ─P0(warp)─ ─P1(lane)── +// -P0(warp)- -P1(lane)-- // -// P0: major=<1>, minor=<1> → Hs[0], level 1 → M1=4 -// P1: major=<1,2>, minor=<2,0> → merged: -// Hs[0] level 2 → M2=16 (outer, changes slowly) -// Hs[1] level 0 → K0=4 (inner, changes every lane → coalesced!) -// Total: 16 × 4 = 64 = warp_size -// lane / 4 → row_in_warp (M2), lane % 4 → K-chunk (K0) +// P0: major=<1>, minor=<1> -> Hs[0], level 1 -> M1=4 +// P1: major=<1,2>, minor=<2,0> -> merged: +// Hs[0] level 2 -> M2=16 (outer, changes slowly) +// Hs[1] level 0 -> K0=4 (inner, changes every lane -> coalesced!) +// Total: 16 x 4 = 64 = warp_size +// lane / 4 -> row_in_warp (M2), lane % 4 -> K-chunk (K0) // -// Step 3 — Yield dimensions (Ys): what each thread owns. +// Step 3 -- Yield dimensions (Ys): what each thread owns. // // Y0 = Hs[0][0] = 4 (M-iterations) // Y1 = Hs[1][1] = 8 (vector load width) @@ -127,27 +127,27 @@ using namespace ck_tile; // Ys_major = sequence<1, 2> // Ys_minor = sequence<0, 1> // -// How to read Ys: parallel arrays — position i gives Yi. +// How to read Ys: parallel arrays -- position i gives Yi. 
// -// Ys_major = seq< 1, 2 > → Y0 is in Hs[0], Y1 is in Hs[1] -// Ys_minor = seq< 0, 1 > → Y0 is level 0, Y1 is level 1 -// ─Y0─ ─Y1─ +// Ys_major = seq< 1, 2 > -> Y0 is in Hs[0], Y1 is in Hs[1] +// Ys_minor = seq< 0, 1 > -> Y0 is level 0, Y1 is level 1 +// -Y0- -Y1- // -// Y0: Hs[0] level 0 → M0=4 (iterations along M) -// Y1: Hs[1] level 1 → K1=8 (vector load width) -// Buffer size = Y0 × Y1 = 4 × 8 = 32 elements per thread. +// Y0: Hs[0] level 0 -> M0=4 (iterations along M) +// Y1: Hs[1] level 1 -> K1=8 (vector load width) +// Buffer size = Y0 x Y1 = 4 x 8 = 32 elements per thread. // -// Step 4 — Replicate: Rs = sequence<1> (trivial, size 1). +// Step 4 -- Replicate: Rs = sequence<1> (trivial, size 1). // // Complete tree: // // Hs[0] Hs[1] -// ┌─────┼─────┐ ┌───┴───┐ +// +-----+-----+ +---+---+ // [Y0] [P0] [P1] [P1] [Y1] // = 4 = 4 = 16 = 4 = 8 // (iter) (warp) (row) (K-chunk) (vec load) // -// Buffer size = Y0 × Y1 = 4 × 8 = 32 elements per thread. +// Buffer size = Y0 x Y1 = 4 x 8 = 32 elements per thread. 
// // ============================================================================ @@ -182,7 +182,7 @@ struct TileDistKernelA const auto& buf = tile.get_thread_buffer(); constexpr index_t warp_size = get_warp_size(); - constexpr index_t kBufSize = 32; // 4 iterations × 8 K-values + constexpr index_t kBufSize = 32; // 4 iterations x 8 K-values int32_t local_buf[kBufSize]; static_for<0, kBufSize, 1>{}([&](auto i) { local_buf[i] = static_cast(buf[i]); }); @@ -232,13 +232,13 @@ struct TileDistKernelA } __syncthreads(); - // Lane 0: row_in_warp=0, k_chunk=0 → rows {0, 64, 128, 192}, K=0..7 + // Lane 0: row_in_warp=0, k_chunk=0 -> rows {0, 64, 128, 192}, K=0..7 print_thread(0); __syncthreads(); - // Lane 1: k_chunk=1 → same rows, K=8..15 (coalesced with lane 0) + // Lane 1: k_chunk=1 -> same rows, K=8..15 (coalesced with lane 0) print_thread(1); __syncthreads(); - // Lane 4: row_in_warp=1 → rows {1, 65, 129, 193}, K=0..7 + // Lane 4: row_in_warp=1 -> rows {1, 65, 129, 193}, K=0..7 print_thread(4); __syncthreads(); diff --git a/tutorial/ck_tile/tile_distribution/tile_distribution_2.cpp b/tutorial/ck_tile/tile_distribution/tile_distribution_2.cpp index 5d5ae3227f..ecc9dfb610 100644 --- a/tutorial/ck_tile/tile_distribution/tile_distribution_2.cpp +++ b/tutorial/ck_tile/tile_distribution/tile_distribution_2.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: MIT /* - * Tutorial: CK Tile Distribution Encoding — B Matrix DRAM Load + * Tutorial: CK Tile Distribution Encoding -- B Matrix DRAM Load * * Demonstrates how tile_distribution_encoding maps threads to B-matrix * elements during a DRAM load in the naive GEMM tutorial. 
@@ -10,7 +10,7 @@ * Source: block_gemm_pipeline_agmem_bgmem_creg_policy.hpp * MakeBDramTileDistribution(), with fp16, BlockSize=256 * - * Tile: N=128 × K=32 (matches the naive GEMM's B block tile) + * Tile: N=128 x K=32 (matches the naive GEMM's B block tile) * Threads: 256 (4 warps on CDNA, 8 on RDNA) * * The B encoding has the SAME structure as the A encoding (Tutorial 1), @@ -18,7 +18,7 @@ * count), showing how the same encoding pattern adapts to different * tile sizes. * - * No compute is performed — this is purely about data movement. + * No compute is performed -- this is purely about data movement. * * Note: int32_t is used instead of fp16 for readable printf output. * The distribution encoding is hardcoded to match the fp16 derivation. @@ -36,21 +36,21 @@ using namespace ck_tile; // ============================================================================ // THE GOAL // ============================================================================ -// Matrix B: N=128 rows × K=32 columns, stored in DRAM (row-major, fp16). -// (In GEMM, B is stored as [N, K] — each "row" is one output channel.) +// Matrix B: N=128 rows x K=32 columns, stored in DRAM (row-major, fp16). +// (In GEMM, B is stored as [N, K] -- each "row" is one output channel.) // Load the entire tile into registers using 256 threads (4 warps on CDNA). // // Same coalescing strategy as the A-matrix (Tutorial 1): -// - 4 lanes cover one K-row (4 × 8 = 32 K-values) +// - 4 lanes cover one K-row (4 x 8 = 32 K-values) // - Each warp (64 lanes) covers 16 N-rows // - 4 warps cover 64 N-rows per iteration // - N0 = 128/64 = 2 iterations (vs 4 for A's M=256) // -// Per-thread buffer = 2 iterations × 8 K-values = 16 elements. +// Per-thread buffer = 2 iterations x 8 K-values = 16 elements. // // Compare with Tutorial 1 (A-matrix): // A: M=256, M0=4, buffer=32 | B: N=128, N0=2, buffer=16 -// Everything else is identical — same K-splitting, same coalescing. 
+// Everything else is identical -- same K-splitting, same coalescing. // // ============================================================================ // THE SOLUTION: tile_distribution_encoding @@ -63,47 +63,47 @@ using namespace ck_tile; // N1 = BlockSize/warp_size = 4 // N0 = NPerBlock/(N2*N1) = 2 // -// Step 1 — Hierarchical dimensions (Hs): +// Step 1 -- Hierarchical dimensions (Hs): // -// Hs[0] = sequence<2, 4, 16> → N = 2 × 4 × 16 = 128 -// Hs[1] = sequence<4, 8> → K = 4 × 8 = 32 +// Hs[0] = sequence<2, 4, 16> -> N = 2 x 4 x 16 = 128 +// Hs[1] = sequence<4, 8> -> K = 4 x 8 = 32 // // Hs[0] Hs[1] -// ┌─────┼─────┐ ┌───┴───┐ +// +-----+-----+ +---+---+ // [Y0] [P0] [P1] [P1] [Y1] // = 2 = 4 = 16 = 4 = 8 // (iter) (warp) (row) (K-chunk) (vec load) // -// Step 2 — Parallel dimensions (Ps): NDimP=2 (P0=warp_id, P1=lane_id). +// Step 2 -- Parallel dimensions (Ps): NDimP=2 (P0=warp_id, P1=lane_id). // // Ps_major = tuple, sequence<1, 2>> // Ps_minor = tuple, sequence<2, 0>> // -// How to read Ps: the tuple has 2 elements → NDimP=2. +// How to read Ps: the tuple has 2 elements -> NDimP=2. // First element = P0 = warp_id // Second element = P1 = lane_id // -// P0: major=<1>, minor=<1> → Hs[0], level 1 → N1=4 (which warp) -// P1: major=<1,2>, minor=<2,0> → merged: -// Hs[0] level 2 → N2=16 (outer, row within warp) -// Hs[1] level 0 → K0=4 (inner, K-chunk → coalesced!) -// lane / 4 → row_in_warp, lane % 4 → K-chunk +// P0: major=<1>, minor=<1> -> Hs[0], level 1 -> N1=4 (which warp) +// P1: major=<1,2>, minor=<2,0> -> merged: +// Hs[0] level 2 -> N2=16 (outer, row within warp) +// Hs[1] level 0 -> K0=4 (inner, K-chunk -> coalesced!) +// lane / 4 -> row_in_warp, lane % 4 -> K-chunk // -// Step 3 — Yield dimensions (Ys): what each thread owns. +// Step 3 -- Yield dimensions (Ys): what each thread owns. // // Ys_major = sequence<1, 2> // Ys_minor = sequence<0, 1> // -// How to read Ys: parallel arrays — position i gives Yi. 
+// How to read Ys: parallel arrays -- position i gives Yi. // -// Ys_major = seq< 1, 2 > → Y0 is in Hs[0], Y1 is in Hs[1] -// Ys_minor = seq< 0, 1 > → Y0 is level 0, Y1 is level 1 -// ─Y0─ ─Y1─ +// Ys_major = seq< 1, 2 > -> Y0 is in Hs[0], Y1 is in Hs[1] +// Ys_minor = seq< 0, 1 > -> Y0 is level 0, Y1 is level 1 +// -Y0- -Y1- // -// Y0: Hs[0] level 0 → N0=2 (iterations along N) -// Y1: Hs[1] level 1 → K1=8 (vector load width) +// Y0: Hs[0] level 0 -> N0=2 (iterations along N) +// Y1: Hs[1] level 1 -> K1=8 (vector load width) // -// Buffer size = Y0 × Y1 = 2 × 8 = 16 elements per thread. +// Buffer size = Y0 x Y1 = 2 x 8 = 16 elements per thread. // // ============================================================================ @@ -138,7 +138,7 @@ struct TileDistKernelB const auto& buf = tile.get_thread_buffer(); constexpr index_t warp_size = get_warp_size(); - constexpr index_t kBufSize = 16; // 2 iterations × 8 K-values + constexpr index_t kBufSize = 16; // 2 iterations x 8 K-values int32_t local_buf[kBufSize]; static_for<0, kBufSize, 1>{}([&](auto i) { local_buf[i] = static_cast(buf[i]); }); @@ -187,13 +187,13 @@ struct TileDistKernelB } __syncthreads(); - // Lane 0: row_in_warp=0, k_chunk=0 → rows {0, 64}, K=0..7 + // Lane 0: row_in_warp=0, k_chunk=0 -> rows {0, 64}, K=0..7 print_thread(0); __syncthreads(); - // Lane 1: k_chunk=1 → same rows, K=8..15 + // Lane 1: k_chunk=1 -> same rows, K=8..15 print_thread(1); __syncthreads(); - // Lane 4: row_in_warp=1 → rows {1, 65}, K=0..7 + // Lane 4: row_in_warp=1 -> rows {1, 65}, K=0..7 print_thread(4); __syncthreads(); diff --git a/tutorial/ck_tile/tile_distribution/tile_distribution_3.cpp b/tutorial/ck_tile/tile_distribution/tile_distribution_3.cpp index 4a782b592b..32c2b3cd78 100644 --- a/tutorial/ck_tile/tile_distribution/tile_distribution_3.cpp +++ b/tutorial/ck_tile/tile_distribution/tile_distribution_3.cpp @@ -2,12 +2,12 @@ // SPDX-License-Identifier: MIT /* - * Tutorial: CK Tile Distribution Encoding — C Matrix 
Register Layout + * Tutorial: CK Tile Distribution Encoding -- C Matrix Register Layout * * Demonstrates how C-matrix elements are distributed across thread registers * after MFMA computation. Unlike A/B (which are DRAM loads), C lives entirely - * in registers — the distribution describes which thread holds which output - * element of C = A × B. + * in registers -- the distribution describes which thread holds which output + * element of C = A x B. * * This tutorial shows BOTH: * 1. The warp-level C distribution (from MFMA m32n32k8 output mapping) @@ -17,11 +17,11 @@ * The macro CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION (default 1) selects * between the standard and transposed C register layouts. * - * Tile: M=256 × N=128 (matches the naive GEMM's C block tile) + * Tile: M=256 x N=128 (matches the naive GEMM's C block tile) * Warp config: MWarp=4, NWarp=1 - * MFMA: m32n32k8 (each warp produces a 32×32 output) + * MFMA: m32n32k8 (each warp produces a 32x32 output) * - * No actual MFMA compute — we construct a C distributed tensor, fill it + * No actual MFMA compute -- we construct a C distributed tensor, fill it * with marker values (thread_id * 1000 + buffer_index), and print per-thread * contents to reveal which buffer slots belong to which thread. * @@ -45,24 +45,24 @@ using namespace ck_tile; // THE GOAL // ============================================================================ // After GEMM computation, each thread holds a subset of the C matrix -// (M=256 × N=128 = 32768 elements) in its registers. We want to understand +// (M=256 x N=128 = 32768 elements) in its registers. We want to understand // exactly which C[m][n] elements each thread owns. 
// // The mapping has two levels: // -// BLOCK LEVEL (256×128 → warps and iterations): +// BLOCK LEVEL (256x128 -> warps and iterations): // - 4 warps along M (MWarp=4), 1 warp along N (NWarp=1) -// - Each warp covers 32 M-rows × 128 N-cols of the block tile +// - Each warp covers 32 M-rows x 128 N-cols of the block tile // - Within each warp: MIterPerWarp=2, NIterPerWarp=4 -// → 2 × 4 = 8 warp-tile iterations per warp -// - Each warp-tile iteration is a 32×32 MFMA output +// -> 2 x 4 = 8 warp-tile iterations per warp +// - Each warp-tile iteration is a 32x32 MFMA output // -// WARP LEVEL (32×32 → threads): -// - 64 threads produce 32 × 32 = 1024 C elements +// WARP LEVEL (32x32 -> threads): +// - 64 threads produce 32 x 32 = 1024 C elements // - Each thread holds 1024/64 = 16 elements // - MFMA m32n32k8 arranges these 16 elements in a specific pattern // -// The per-thread register buffer = 8 iterations × 16 elements = 128 floats. +// The per-thread register buffer = 8 iterations x 16 elements = 128 floats. // // ============================================================================ // THE SOLUTION: Two-Level Distribution @@ -70,105 +70,105 @@ using namespace ck_tile; // // --- WARP-LEVEL C DISTRIBUTION (from MFMA m32n32k8) --- // -// For fp16→fp32 MFMA m32n32k8 output (kCM0PerLane=4, kCMLane=2, +// For fp16->fp32 MFMA m32n32k8 output (kCM0PerLane=4, kCMLane=2, // kCM1PerLane=4, kCNLane=32): // // STANDARD (CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION=0): // -// Hs[0] = sequence<4, 2, 4> → M-dim: 4 × 2 × 4 = 32 -// Hs[1] = sequence<32> → N-dim: 32 -// Ps_major = tuple> → lane maps to Hs[0][1] and Hs[1][0] +// Hs[0] = sequence<4, 2, 4> -> M-dim: 4 x 2 x 4 = 32 +// Hs[1] = sequence<32> -> N-dim: 32 +// Ps_major = tuple> -> lane maps to Hs[0][1] and Hs[1][0] // Ps_minor = tuple> // -// How to read Ps: the tuple has 1 element → NDimP=1 → P0 = lane_id. 
-// P0: major=<1,2>, minor=<1,0> → merged: -// Hs[0] level 1 → kCMLane=2 (outer, M-half) -// Hs[1] level 0 → kCNLane=32 (inner, N-col → contiguous!) -// lane / 32 → M-half, lane % 32 → N-col +// How to read Ps: the tuple has 1 element -> NDimP=1 -> P0 = lane_id. +// P0: major=<1,2>, minor=<1,0> -> merged: +// Hs[0] level 1 -> kCMLane=2 (outer, M-half) +// Hs[1] level 0 -> kCNLane=32 (inner, N-col -> contiguous!) +// lane / 32 -> M-half, lane % 32 -> N-col // // Ys_major = sequence<1, 1> // Ys_minor = sequence<0, 2> // -// How to read Ys: parallel arrays — position i gives Yi. +// How to read Ys: parallel arrays -- position i gives Yi. // -// Ys_major = seq< 1, 1 > → Y0 is in Hs[0], Y1 is in Hs[0] -// Ys_minor = seq< 0, 2 > → Y0 is level 0, Y1 is level 2 -// ─Y0─ ─Y1─ +// Ys_major = seq< 1, 1 > -> Y0 is in Hs[0], Y1 is in Hs[0] +// Ys_minor = seq< 0, 2 > -> Y0 is level 0, Y1 is level 2 +// -Y0- -Y1- // -// Y0: Hs[0] level 0 → kCM0PerLane=4 (M outer per lane) -// Y1: Hs[0] level 2 → kCM1PerLane=4 (M inner per lane) +// Y0: Hs[0] level 0 -> kCM0PerLane=4 (M outer per lane) +// Y1: Hs[0] level 2 -> kCM1PerLane=4 (M inner per lane) // // Hs[0] Hs[1] -// ┌─────┼─────┐ │ +// +-----+-----+ | // [Y0] [P0] [Y1] [P0] // = 4 = 2 = 4 = 32 -// (M outer)(lane) (M inner) (lane → N) +// (M outer)(lane) (M inner) (lane -> N) // -// Per-thread: Y0 × Y1 = 4 × 4 = 16 elements per warp-tile. -// Lane decomposition: lane / 32 → M-half (0..1), lane % 32 → N-col (0..31) +// Per-thread: Y0 x Y1 = 4 x 4 = 16 elements per warp-tile. +// Lane decomposition: lane / 32 -> M-half (0..1), lane % 32 -> N-col (0..31) // // TRANSPOSED (CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION=1): // -// Hs[0] = sequence<32> → First dim: N (swapped!) -// Hs[1] = sequence<4, 2, 4> → Second dim: M (swapped!) -// Ps_major = tuple> → lane maps to Hs[1][1] and Hs[0][0] +// Hs[0] = sequence<32> -> First dim: N (swapped!) +// Hs[1] = sequence<4, 2, 4> -> Second dim: M (swapped!) 
+// Ps_major = tuple> -> lane maps to Hs[1][1] and Hs[0][0] // Ps_minor = tuple> // -// How to read Ps: tuple has 1 element → NDimP=1 → P0 = lane_id. -// P0: major=<2,1>, minor=<1,0> → merged: -// Hs[1] level 1 → kCMLane=2 (outer, M-half) -// Hs[0] level 0 → kCNLane=32 (inner, N-col → contiguous!) +// How to read Ps: tuple has 1 element -> NDimP=1 -> P0 = lane_id. +// P0: major=<2,1>, minor=<1,0> -> merged: +// Hs[1] level 1 -> kCMLane=2 (outer, M-half) +// Hs[0] level 0 -> kCNLane=32 (inner, N-col -> contiguous!) // Same lane decomposition as standard, but dimensions are swapped. // // Ys_major = sequence<2, 2> // Ys_minor = sequence<0, 2> // // How to read Ys: -// Ys_major = seq< 2, 2 > → Y0 is in Hs[1], Y1 is in Hs[1] -// Ys_minor = seq< 0, 2 > → Y0 is level 0, Y1 is level 2 -// ─Y0─ ─Y1─ +// Ys_major = seq< 2, 2 > -> Y0 is in Hs[1], Y1 is in Hs[1] +// Ys_minor = seq< 0, 2 > -> Y0 is level 0, Y1 is level 2 +// -Y0- -Y1- // -// Y0: Hs[1] level 0 → kCM0PerLane=4 (M outer per lane) -// Y1: Hs[1] level 2 → kCM1PerLane=4 (M inner per lane) +// Y0: Hs[1] level 0 -> kCM0PerLane=4 (M outer per lane) +// Y1: Hs[1] level 2 -> kCM1PerLane=4 (M inner per lane) // Same 16 elements, but now both Y dims are in Hs[1] (M is second). // // Hs[0] Hs[1] -// │ ┌─────┼─────┐ +// | +-----+-----+ // [P0] [Y0] [P0] [Y1] // = 32 = 4 = 2 = 4 -// (lane → N) (M outer)(lane)(M inner) +// (lane -> N) (M outer)(lane)(M inner) // // Same 16 elements per thread, but N is the first dimension in the -// distribution — this changes which elements are contiguous in the +// distribution -- this changes which elements are contiguous in the // thread buffer, affecting downstream store coalescing. 
// // --- BLOCK-LEVEL OUTER DISTRIBUTION --- // -// MIterPerWarp = MPerBlock / (MWarp × WarpGemm::kM) = 256 / (4 × 32) = 2 -// NIterPerWarp = NPerBlock / (NWarp × WarpGemm::kN) = 128 / (1 × 32) = 4 +// MIterPerWarp = MPerBlock / (MWarp x WarpGemm::kM) = 256 / (4 x 32) = 2 +// NIterPerWarp = NPerBlock / (NWarp x WarpGemm::kN) = 128 / (1 x 32) = 4 // -// Hs[0] = sequence<2, 4> → M-dim: 2 iters × 4 warps -// Hs[1] = sequence<4, 1> → N-dim: 4 iters × 1 warp +// Hs[0] = sequence<2, 4> -> M-dim: 2 iters x 4 warps +// Hs[1] = sequence<4, 1> -> N-dim: 4 iters x 1 warp // Ps_major = tuple> // Ps_minor = tuple> // -// How to read Ps: tuple has 1 element → NDimP=1 → P0 = warp_id. -// P0: major=<1,2>, minor=<1,1> → merged: -// Hs[0] level 1 → MWarp=4 (outer) -// Hs[1] level 1 → NWarp=1 (inner, trivial) -// Total: 4 × 1 = 4 = number of warps +// How to read Ps: tuple has 1 element -> NDimP=1 -> P0 = warp_id. +// P0: major=<1,2>, minor=<1,1> -> merged: +// Hs[0] level 1 -> MWarp=4 (outer) +// Hs[1] level 1 -> NWarp=1 (inner, trivial) +// Total: 4 x 1 = 4 = number of warps // // Ys_major = sequence<1, 2> // Ys_minor = sequence<0, 0> // // How to read Ys: -// Ys_major = seq< 1, 2 > → Y0 is in Hs[0], Y1 is in Hs[1] -// Ys_minor = seq< 0, 0 > → Y0 is level 0, Y1 is level 0 -// ─Y0─ ─Y1─ +// Ys_major = seq< 1, 2 > -> Y0 is in Hs[0], Y1 is in Hs[1] +// Ys_minor = seq< 0, 0 > -> Y0 is level 0, Y1 is level 0 +// -Y0- -Y1- // -// Y0: Hs[0] level 0 → MIterPerWarp=2 -// Y1: Hs[1] level 0 → NIterPerWarp=4 -// Block-level buffer = Y0 × Y1 = 2 × 4 = 8 warp-tile slots. +// Y0: Hs[0] level 0 -> MIterPerWarp=2 +// Y1: Hs[1] level 0 -> NIterPerWarp=4 +// Block-level buffer = Y0 x Y1 = 2 x 4 = 8 warp-tile slots. // // tile_distribution_encoding, // tuple, sequence<4, 1>>, @@ -179,7 +179,7 @@ using namespace ck_tile; // // make_embed_tile_distribution_encoding(block_outer, warp_encoding) // embeds the warp encoding inside each (MIter, MWarp, NIter, NWarp) cell. 
-// Total per-thread buffer = 2 × 4 × 16 = 128 elements. +// Total per-thread buffer = 2 x 4 x 16 = 128 elements. // // ============================================================================ @@ -354,10 +354,10 @@ int main() printf("=== CK Tile Distribution Tutorial 3: C-Matrix Register Layout ===\n"); printf("=== Matches naive GEMM: MPerBlock=256, NPerBlock=128 ===\n\n"); printf("MFMA m32n32k8: each warp produces 32x32 = 1024 elements\n"); - printf(" 64 threads per warp → 16 elements per thread per warp-tile\n"); - printf(" MWarp=4, NWarp=1 → 4 warps along M, 1 along N\n"); - printf(" MIterPerWarp=2, NIterPerWarp=4 → 8 warp-tiles per warp\n"); - printf(" Total per thread: 8 × 16 = 128 elements\n\n"); + printf(" 64 threads per warp -> 16 elements per thread per warp-tile\n"); + printf(" MWarp=4, NWarp=1 -> 4 warps along M, 1 along N\n"); + printf(" MIterPerWarp=2, NIterPerWarp=4 -> 8 warp-tiles per warp\n"); + printf(" Total per thread: 8 x 16 = 128 elements\n\n"); #if CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION printf("Current mode: TRANSPOSED C distribution\n");