Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-06-28 02:37:01 +00:00)

[tile_engine] Integrate gemm_streamk into budget-based sampling system (#8079)

## Motivation

`gemm_streamk` was the only GEMM op not participating in the tile engine's budget-based sampling system. Without a budget cap, it always generated its full feasible set. This made build times unpredictable and inconsistent with the other ops.

## Technical Details

- **CMake budget propagation** (`ops/gemm/CMakeLists.txt`): Added `gemm_streamk` to the active-ops detection loop so it receives a share of the sampling budget. Because `gemm_streamk` lives in a sibling subdirectory (`ops/gemm_streamk/`), its allocation is written with `CACHE STRING "" FORCE`. A normal variable set in `ops/gemm/` is not visible in a sibling directory, whereas a cache variable is visible project-wide.
- **Per-combo budget division** (`ops/gemm_streamk/CMakeLists.txt`, `ops/gemm/grouped_gemm/CMakeLists.txt`): Added the same per-combo `MAX_INSTANCES` division that already exists in `gemm_universal` and `gemm_preshuffle`. The total budget is divided by `n_datatypes × n_layouts` before the inner `foreach` loop, so sampling fires independently for each `(dtype, layout)` combo instead of acting as a single global cap.
- **Sampling integration** (`gemm_streamk_instance_builder.py`): Added an `_apply_sampling()` method to `GemmKernelBuilder`, mirroring the Sobol + LHS + maximin sampling used by the other ops.
  - New constructor parameters: `gpu_target`, `max_instances`, `seed`, `tier`, `manifest_path`.
  - New CLI arguments: `--gpu_target`, `--max-instances`, `--seed`, `--tier`, `--manifest-path`.
  - `--gpu_target` is now also forwarded on the `--list_kernels` invocation.
- **`GEMM_STREAMK_AXES`** (`sampling/feasible_set.py`): Defined as `GEMM_AXES + ["reduction_strategy"]` to account for the extra axis unique to stream-K. Also added `reduction_strategy` to `CATEGORICAL_AXES`.
- **Weight rebalancing** (`sampling/op_weights.json`): Allocated a weight of 0.10 to `gemm_streamk` by reducing `gemm_universal` (0.35 → 0.30) and `gemm_preshuffle` (0.30 → 0.25) by 0.05 each. The total remains 1.00.

## Test Plan

- Configure with `TILE_ENGINE_SAMPLING_TIER=daily` and verify that `gemm_streamk` receives a non-zero budget allocation and that `GEMM_STREAMK_MAX_INSTANCES` is set correctly.
- Configure with `TILE_ENGINE_SAMPLING_TIER=daily` across multiple `(dtype, layout)` combos and confirm that the per-combo budget equals total / n_combos.
- Configure with an explicit `-DGEMM_STREAMK_MAX_INSTANCES=50` override and verify that the override is respected (budget allocation is skipped).
- Verify that the `chosen_instances.json` manifest is written to the working path when a tier is active.
- Confirm that the `op_weights.json` weights still sum to 1.00.

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
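The per-combo `MAX_INSTANCES` division described under Technical Details can be sketched in Python as follows. The real logic lives in CMake; the function names, the example dtype and layout strings, and the floor of one instance per combo are assumptions for illustration only.

```python
from itertools import product
from typing import Dict, List, Tuple


def per_combo_budget(total_budget: int, datatypes: List[str], layouts: List[str]) -> int:
    """Split an op's total instance budget evenly across (dtype, layout) combos."""
    n_combos = len(datatypes) * len(layouts)
    if n_combos == 0:
        raise ValueError("need at least one datatype and one layout")
    # Integer division, like CMake's math(EXPR) gives for positive values.
    # The floor of 1 is an assumption so every combo keeps at least one instance.
    return max(1, total_budget // n_combos)


def plan(total_budget: int, datatypes: List[str], layouts: List[str]) -> Dict[Tuple[str, str], int]:
    """Hypothetical helper: the cap each (dtype, layout) combo is sampled against."""
    per_combo = per_combo_budget(total_budget, datatypes, layouts)
    # Each combo is sampled independently, so no single combo can use up the whole budget.
    return {(dt, layout): per_combo for dt, layout in product(datatypes, layouts)}
```

With a 120-instance budget, 2 dtypes, and 3 layouts, each of the 6 combos is capped at 20 instances instead of all combos competing for one global cap of 120.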
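The weight rebalancing and the explicit-override test case can be illustrated together. This sketch assumes that `op_weights.json` is a flat op-name-to-weight mapping, that an op's share is `round(total × weight)` with no renormalization across active ops, and that an explicit per-op override (such as `-DGEMM_STREAMK_MAX_INSTANCES=50`) bypasses allocation. `other_ops` is a hypothetical placeholder for the remaining ops, which this PR does not list.

```python
import math
from typing import Dict, List, Optional

# Weights for the three ops touched by this PR; "other_ops" is a hypothetical
# stand-in for the remaining 0.35 held by ops not named here.
OP_WEIGHTS = {
    "gemm_universal": 0.30,
    "gemm_preshuffle": 0.25,
    "gemm_streamk": 0.10,
    "other_ops": 0.35,
}


def check_weights(weights: Dict[str, float]) -> None:
    """Raise if the weights do not sum to 1.00 (the last Test Plan check)."""
    total = math.fsum(weights.values())  # fsum avoids accumulated float error
    if not math.isclose(total, 1.0, abs_tol=1e-9):
        raise ValueError(f"op weights sum to {total:.6f}, expected 1.00")


def allocate(total_budget: int, weights: Dict[str, float], active_ops: List[str],
             overrides: Optional[Dict[str, int]] = None) -> Dict[str, int]:
    """Hypothetical budget split: explicit overrides win, other ops get a weighted share."""
    check_weights(weights)
    overrides = overrides or {}
    budgets = {}
    for op in active_ops:
        if op in overrides:
            # Explicit override is respected and the weighted allocation is skipped.
            budgets[op] = overrides[op]
        else:
            # Assumed rounding rule; the real CMake code may truncate instead.
            budgets[op] = round(total_budget * weights[op])
    return budgets
```

With a total budget of 1,000 instances, this gives `gemm_universal` 300, `gemm_preshuffle` 250, and `gemm_streamk` 100.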
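The Technical Details mention Sobol + LHS + maximin sampling and a new categorical `reduction_strategy` axis. The sketch below shows only the maximin idea, as a greedy farthest-point selection, and one plausible way an unordered categorical axis could enter the distance (a 0/1 mismatch instead of a numeric difference). It does not reproduce the Sobol or LHS stages. The `GEMM_AXES` subset, the strategy names, and the distance metric are assumptions, not the contents of `sampling/feasible_set.py`.

```python
import random
from typing import Dict, List, Sequence

# Hypothetical subset: the real GEMM_AXES list is not shown in this PR.
GEMM_AXES = ["tile_m", "tile_n", "tile_k"]
GEMM_STREAMK_AXES = GEMM_AXES + ["reduction_strategy"]
CATEGORICAL_AXES = {"reduction_strategy"}


def distance(a: Dict, b: Dict, axes: Sequence[str], spans: Dict[str, float]) -> float:
    """Assumed metric: range-normalized for numeric axes, 0/1 for categorical ones."""
    total = 0.0
    for axis in axes:
        if axis in CATEGORICAL_AXES:
            # Unordered choices: only "same" vs "different" is meaningful.
            total += 0.0 if a[axis] == b[axis] else 1.0
        else:
            total += abs(a[axis] - b[axis]) / spans[axis]
    return total


def maximin_select(feasible: List[Dict], axes: Sequence[str], k: int, seed: int = 0) -> List[Dict]:
    """Greedy farthest-point selection of k distinct configs from the feasible set."""
    if k <= 0:
        return []
    if k >= len(feasible):
        return list(feasible)
    spans = {}
    for axis in axes:
        if axis not in CATEGORICAL_AXES:
            values = [cfg[axis] for cfg in feasible]
            spans[axis] = (max(values) - min(values)) or 1.0  # avoid dividing by zero
    rng = random.Random(seed)  # seeded so the selection is reproducible
    chosen = [rng.choice(feasible)]
    remaining = [cfg for cfg in feasible if cfg is not chosen[0]]
    while len(chosen) < k:
        # Pick the candidate whose nearest already-chosen config is farthest away.
        best = max(remaining, key=lambda cfg: min(distance(cfg, c, axes, spans) for c in chosen))
        chosen.append(best)
        remaining.remove(best)
    return chosen


# Toy feasible set: 3 * 2 * 2 tile shapes times 2 hypothetical strategy names.
feasible = [
    {"tile_m": m, "tile_n": n, "tile_k": kk, "reduction_strategy": r}
    for m in (64, 128, 256)
    for n in (64, 128)
    for kk in (32, 64)
    for r in ("atomic", "linear")
]
picked = maximin_select(feasible, GEMM_STREAMK_AXES, k=4, seed=1)
```

Because the categorical term adds a full unit whenever the strategies differ, the second pick under this metric always uses a different `reduction_strategy` than the first, so the extra axis is covered early.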
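A minimal sketch of how the new CLI arguments listed under Sampling integration could be parsed with `argparse` and mapped onto the new constructor parameters. The flag names come from this PR; the defaults, help text, and the `builder_kwargs` helper are hypothetical, and `--list_kernels` is assumed to be a boolean flag.

```python
import argparse
from typing import Optional, Sequence


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="gemm_streamk instance builder (sketch)")
    # Assumed to be a boolean switch; --gpu_target is also honored on this path.
    parser.add_argument("--list_kernels", action="store_true", help="list kernels only")
    parser.add_argument("--gpu_target", default=None, help="GPU architecture, e.g. gfx942")
    parser.add_argument("--max-instances", type=int, default=None,
                        help="upper bound on generated instances")
    parser.add_argument("--seed", type=int, default=0, help="sampling seed (default is hypothetical)")
    parser.add_argument("--tier", default=None, help="sampling tier, e.g. daily")
    parser.add_argument("--manifest-path", default=None, help="where to write chosen_instances.json")
    return parser


def builder_kwargs(argv: Optional[Sequence[str]] = None) -> dict:
    """Hypothetical helper: map parsed CLI arguments onto GemmKernelBuilder parameters."""
    args = build_parser().parse_args(argv)
    # argparse turns dashes into underscores: --max-instances -> args.max_instances
    return {
        "gpu_target": args.gpu_target,
        "max_instances": args.max_instances,
        "seed": args.seed,
        "tier": args.tier,
        "manifest_path": args.manifest_path,
    }
```

The dash-to-underscore conversion is standard `argparse` behavior, which is why the dashed flags `--max-instances` and `--manifest-path` line up with the underscored constructor parameters `max_instances` and `manifest_path`.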