mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-28 18:56:59 +00:00
[tile_engine] Integrate gemm_streamk into budget-based sampling system (#8079)

## Motivation

`gemm_streamk` was the only GEMM op not participating in the tile engine's budget-based sampling system. Without a budget cap, it would always generate its full feasible set, making build times unpredictable and inconsistent with the other ops.

## Technical Details

- **CMake budget propagation** (`ops/gemm/CMakeLists.txt`): Added `gemm_streamk` to the active-ops detection loop so it receives a share of the sampling budget. Because `gemm_streamk` lives in a sibling subdirectory (`ops/gemm_streamk/`), its allocation is written via `CACHE STRING "" FORCE` to make the variable visible across the CMake directory boundary.
- **Per-combo budget division** (`ops/gemm_streamk/CMakeLists.txt`, `ops/gemm/grouped_gemm/CMakeLists.txt`): Added the same per-combo `MAX_INSTANCES` division that exists in `gemm_universal` and `gemm_preshuffle`. The total budget is divided by `n_datatypes × n_layouts` before the inner `foreach` loop so that sampling fires independently per `(dtype, layout)` combo rather than acting as a single global cap.
- **Sampling integration** (`gemm_streamk_instance_builder.py`): Added an `_apply_sampling()` method to `GemmKernelBuilder`, mirroring the Sobol+LHS+maximin sampling used by other ops. New constructor parameters: `gpu_target`, `max_instances`, `seed`, `tier`, `manifest_path`. New CLI arguments: `--gpu_target`, `--max-instances`, `--seed`, `--tier`, `--manifest-path`. The `--gpu_target` argument is now also forwarded on the `--list_kernels` invocation.
- **`GEMM_STREAMK_AXES`** (`sampling/feasible_set.py`): Defined as `GEMM_AXES + ["reduction_strategy"]` to account for the extra axis unique to stream-K. Added `reduction_strategy` to `CATEGORICAL_AXES`.
- **Weight rebalancing** (`sampling/op_weights.json`): Allocated 10% weight to `gemm_streamk` by taking 0.05 each from `gemm_universal` (0.35 → 0.30) and `gemm_preshuffle` (0.30 → 0.25). Total remains 1.00.

## Test Plan

- Configure with `TILE_ENGINE_SAMPLING_TIER=daily` and verify that `gemm_streamk` receives a non-zero budget allocation and that `GEMM_STREAMK_MAX_INSTANCES` is set correctly.
- Configure with `TILE_ENGINE_SAMPLING_TIER=daily` across multiple `(dtype, layout)` combos and confirm that per-combo budget = total / n_combos.
- Configure with an explicit `-DGEMM_STREAMK_MAX_INSTANCES=50` override and verify that the override is respected (budget allocation skipped).
- Verify that the `chosen_instances.json` manifest is written to the working path when a tier is active.
- Confirm that the `op_weights.json` weights still sum to 1.00.

## Test Result

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
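The budget flow described above can be modeled in Python. Each op receives a weighted share of the tier budget. That share is split evenly across `(dtype, layout)` combos, and an explicit `-DGEMM_STREAMK_MAX_INSTANCES` takes precedence over the allocation.

This is a hypothetical sketch, not the CMake logic from the PR. The function names, the 1000-instance tier budget, rounding the weight share with `round()`, and floor division per combo are all assumptions. `other_ops` is a placeholder that lumps together the ops whose weights the description does not list.

```python
import math

# Weights from the PR description. "other_ops" is a hypothetical placeholder
# for the unlisted ops: 1.00 - 0.30 - 0.25 - 0.10 = 0.35.
OP_WEIGHTS = {
    "gemm_universal": 0.30,
    "gemm_preshuffle": 0.25,
    "gemm_streamk": 0.10,
    "other_ops": 0.35,
}


def weights_sum_to_one(weights):
    """Test-plan check: the weights still sum to 1.00 (within float tolerance)."""
    return math.isclose(sum(weights.values()), 1.0)


def op_budget(total_budget, weights, op):
    """Assumed allocation: the op's weight times the tier budget, rounded."""
    return round(total_budget * weights[op])


def per_combo_budget(op_total, n_datatypes, n_layouts):
    """Split an op's budget evenly over its (dtype, layout) combos (floor assumed)."""
    return op_total // (n_datatypes * n_layouts)


def resolve_max_instances(explicit_override, allocated):
    """An explicit override wins; otherwise the allocated budget is used."""
    return explicit_override if explicit_override is not None else allocated


total = 1000  # hypothetical tier budget
streamk_total = op_budget(total, OP_WEIGHTS, "gemm_streamk")
print(streamk_total, per_combo_budget(streamk_total, n_datatypes=2, n_layouts=2))  # 100 25
```

Under these assumptions, each of the four combos gets its own cap of 25 rather than a single global cap of 100 across all combos. Dividing before the inner `foreach` loop is what produces this per-combo behavior.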
86 lines
2.1 KiB
Python
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

GEMM_AXES = [
    "tile_m",
    "tile_n",
    "tile_k",
    "warp_m",
    "warp_n",
    "warp_k",
    "warp_tile_m",
    "warp_tile_n",
    "warp_tile_k",
    "pipeline",
    "epilogue",
    "scheduler",
    "pad_m",
    "pad_n",
    "pad_k",
    "persistent",
]

GEMM_STREAMK_AXES = GEMM_AXES + ["reduction_strategy"]

CATEGORICAL_AXES = {
    "pipeline",
    "epilogue",
    "scheduler",
    "reduction_strategy",
    "pad_m",
    "pad_n",
    "pad_k",
    "persistent",
}


def normalize_axis_values(feasible_set, axes=None):
    """Compute normalization metadata for each axis.

    ``axes`` defaults to GEMM_AXES; pass GEMM_STREAMK_AXES to include
    ``reduction_strategy``. Axes with no values in ``feasible_set`` are skipped.

    Returns dict mapping axis name to:
    - For numeric axes: {"type": "numeric", "min": v, "max": v, "range": max - min (1.0 if all equal)}
    - For categorical axes: {"type": "categorical", "values": sorted list, "map": value->index, "count": n}
    """
    if axes is None:
        axes = GEMM_AXES

    meta = {}
    for ax in axes:
        values = [item[ax] for item in feasible_set if ax in item]
        if not values:
            continue

        if ax in CATEGORICAL_AXES:
            unique = sorted(set(str(v) for v in values))
            meta[ax] = {
                "type": "categorical",
                "values": unique,
                "map": {v: i for i, v in enumerate(unique)},
                "count": len(unique),
            }
        else:
            num_values = [float(v) for v in values]
            mn, mx = min(num_values), max(num_values)
            meta[ax] = {
                "type": "numeric",
                "min": mn,
                "max": mx,
                "range": mx - mn if mx != mn else 1.0,
            }
    return meta


def normalize_point(item, axes, meta):
    """Normalize a single point to [0, 1] per axis.

    Axes absent from ``meta`` or from ``item`` contribute 0.0.
    """
    coords = []
    for ax in axes:
        if ax not in meta or ax not in item:
            coords.append(0.0)
            continue
        m = meta[ax]
        if m["type"] == "numeric":
            coords.append((float(item[ax]) - m["min"]) / m["range"])
        else:
            coords.append(m["map"].get(str(item[ax]), 0) / max(m["count"] - 1, 1))
    return coords
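Together, `normalize_axis_values` and `normalize_point` place every feasible configuration in the unit hypercube. Numeric axes map to (v - min) / range, and categorical axes map to sorted index / (count - 1). This keeps distances between configurations comparable across axes, which distance-based (maximin) selection depends on.

The sketch below is a hypothetical standalone illustration, not the tile engine's Sobol+LHS+maximin sampler:

- It inlines a condensed copy of the two normalization rules instead of importing this module.
- It uses a made-up four-point feasible set over a subset of `GEMM_STREAMK_AXES`.
- It applies a plain greedy farthest-point pick.

```python
import math

# Illustrative values only: a hypothetical feasible set over a subset of
# GEMM_STREAMK_AXES, not taken from any tile engine configuration.
AXES = ["tile_m", "tile_n", "pipeline", "reduction_strategy"]
CATEGORICAL = {"pipeline", "reduction_strategy"}
FEASIBLE = [
    {"tile_m": 128, "tile_n": 128, "pipeline": "compv3", "reduction_strategy": "atomic"},
    {"tile_m": 256, "tile_n": 128, "pipeline": "compv4", "reduction_strategy": "reduction"},
    {"tile_m": 128, "tile_n": 256, "pipeline": "mem", "reduction_strategy": "atomic"},
    {"tile_m": 256, "tile_n": 256, "pipeline": "compv3", "reduction_strategy": "reduction"},
]


def unit_point(item, items):
    """Condensed copy of the normalize_axis_values/normalize_point rules."""
    coords = []
    for ax in AXES:
        if ax in CATEGORICAL:
            # Sorted string values: index / (count - 1).
            uniq = sorted({str(it[ax]) for it in items})
            coords.append(uniq.index(str(item[ax])) / max(len(uniq) - 1, 1))
        else:
            # (value - min) / range, with range forced to 1.0 when min == max.
            nums = [float(it[ax]) for it in items]
            lo, hi = min(nums), max(nums)
            coords.append((float(item[ax]) - lo) / (hi - lo if hi != lo else 1.0))
    return coords


def greedy_maximin(points, k, start=0):
    """Greedy farthest-point pick: each new index maximizes its distance
    to the nearest already-chosen point."""
    chosen = [start]
    while len(chosen) < min(k, len(points)):
        best, best_d = None, -1.0
        for i, p in enumerate(points):
            if i in chosen:
                continue
            d = min(math.dist(p, points[j]) for j in chosen)
            if d > best_d:
                best, best_d = i, d
        chosen.append(best)
    return chosen


points = [unit_point(it, FEASIBLE) for it in FEASIBLE]
print(points[1])                    # [1.0, 0.0, 0.5, 1.0]
print(greedy_maximin(points, k=3))  # [0, 3, 2]
```

Index-based categorical coordinates have a consequence worth noting. `compv4` lands at 0.5 only because it sorts between `compv3` and `mem`. The distance between two categorical values therefore reflects their sort order, not any similarity between them.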