composable_kernel/tile_engine/sampling/feasible_set.py
Thrupti Raj Lakshmana Gowda 054436ca4a [rocm-libraries] ROCm/rocm-libraries#8079 (commit cf1e8f2)
[tile_engine] Integrate gemm_streamk into budget-based sampling system (#8079)

## Motivation

`gemm_streamk` was the only GEMM op not participating in the tile
engine's budget-based sampling system. Without a budget cap, it would
always generate its full feasible set, making build times unpredictable
and inconsistent with the other ops.

## Technical Details

- **CMake budget propagation** (`ops/gemm/CMakeLists.txt`): Added
`gemm_streamk` to the active-ops detection loop so it receives a share
of the sampling budget. Because `gemm_streamk` lives in a sibling
subdirectory (`ops/gemm_streamk/`), its allocation is written via `CACHE
STRING "" FORCE` to make the variable visible across the CMake directory
boundary.

- **Per-combo budget division** (`ops/gemm_streamk/CMakeLists.txt`,
`ops/gemm/grouped_gemm/CMakeLists.txt`): Added the same per-combo
`MAX_INSTANCES` division that exists in `gemm_universal` and
`gemm_preshuffle`. The total budget is divided by `n_datatypes ×
n_layouts` before the inner `foreach` loop so that sampling fires
independently per `(dtype, layout)` combo rather than acting as a single
global cap.

- **Sampling integration** (`gemm_streamk_instance_builder.py`): Added
an `_apply_sampling()` method to `GemmKernelBuilder`, mirroring the
Sobol+LHS+maximin sampling used by the other ops. New constructor
parameters: `gpu_target`, `max_instances`, `seed`, `tier`,
`manifest_path`. New CLI arguments: `--gpu_target`, `--max-instances`,
`--seed`, `--tier`, `--manifest-path`. The `--gpu_target` argument is
now also forwarded on the `--list_kernels` invocation.

- **`GEMM_STREAMK_AXES`** (`sampling/feasible_set.py`): Defined as
`GEMM_AXES + ["reduction_strategy"]` to account for the extra axis
unique to stream-K. Added `reduction_strategy` to `CATEGORICAL_AXES`.

- **Weight rebalancing** (`sampling/op_weights.json`): Allocated 10%
weight to `gemm_streamk` by reducing `gemm_universal` (0.35 → 0.30) and
`gemm_preshuffle` (0.30 → 0.25) by 0.05 each. Total remains 1.00.
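
To make the budget arithmetic above concrete, here is a minimal sketch in Python. The real logic lives in CMake; the helper names `op_budget` and `per_combo_budget`, the total budget of 1000, the combo counts, and the rounding rules are all assumptions for illustration. Only the three weights come from this change.

```python
import math

# Weights named in this change; weights of other ops are not listed here.
OLD = {"gemm_universal": 0.35, "gemm_preshuffle": 0.30}
NEW = {"gemm_universal": 0.30, "gemm_preshuffle": 0.25, "gemm_streamk": 0.10}


def op_budget(total_budget, weight):
    """Share of the total budget for one op (rounding rule is assumed)."""
    return int(round(total_budget * weight))


def per_combo_budget(op_total, n_datatypes, n_layouts):
    """Split an op's budget evenly across (dtype, layout) combos.

    Floor division with a minimum of 1 is an assumption, not confirmed
    behavior of the CMake code.
    """
    return max(1, op_total // (n_datatypes * n_layouts))


# Weight removed from the two existing ops should equal the new op's weight.
freed = sum(OLD[op] - NEW[op] for op in OLD)

# Hypothetical numbers: 1000 total instances, 2 dtypes, 4 layouts.
streamk_total = op_budget(1000, NEW["gemm_streamk"])
streamk_per_combo = per_combo_budget(streamk_total, n_datatypes=2, n_layouts=4)

print(math.isclose(freed, NEW["gemm_streamk"]), streamk_total, streamk_per_combo)
```

With these hypothetical inputs, `gemm_streamk` gets 100 instances in total and 12 per `(dtype, layout)` combo, so each combo is sampled against its own cap instead of competing for one global limit.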

## Test Plan

- Configure with `TILE_ENGINE_SAMPLING_TIER=daily` and verify that
`gemm_streamk` receives a non-zero budget allocation and that
`GEMM_STREAMK_MAX_INSTANCES` is set correctly.
- Configure with `TILE_ENGINE_SAMPLING_TIER=daily` across multiple
`(dtype, layout)` combos and confirm per-combo budget = total /
n_combos.
- Configure with `-DGEMM_STREAMK_MAX_INSTANCES=50` explicit override and
verify the override is respected (budget allocation skipped).
- Verify `chosen_instances.json` manifest is written to the working path
when tier is active.
- Confirm `op_weights.json` weights still sum to 1.00.
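
The last check in the list can be scripted. The sketch below assumes `op_weights.json` is a flat mapping of op name to weight, which this description does not confirm. The `other_ops` entry is a placeholder for weights not listed here, and `weights_sum_to_one` is a hypothetical helper.

```python
import json
import math

# Placeholder content: only the three gemm_* values come from the change
# description; "other_ops" stands in for weights that are not listed there.
SAMPLE = """
{
  "gemm_universal": 0.30,
  "gemm_preshuffle": 0.25,
  "gemm_streamk": 0.10,
  "other_ops": 0.35
}
"""


def weights_sum_to_one(weights, tol=1e-9):
    """Return True when the weights add up to 1.0 within a small tolerance."""
    return math.isclose(sum(weights.values()), 1.0, abs_tol=tol)


weights = json.loads(SAMPLE)
ok = weights_sum_to_one(weights)
print(ok)
```

Comparing with a tolerance rather than `==` avoids false failures from floating-point rounding when the weights are summed.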

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-06-05 17:06:11 +00:00


# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

GEMM_AXES = [
    "tile_m",
    "tile_n",
    "tile_k",
    "warp_m",
    "warp_n",
    "warp_k",
    "warp_tile_m",
    "warp_tile_n",
    "warp_tile_k",
    "pipeline",
    "epilogue",
    "scheduler",
    "pad_m",
    "pad_n",
    "pad_k",
    "persistent",
]

GEMM_STREAMK_AXES = GEMM_AXES + ["reduction_strategy"]

CATEGORICAL_AXES = {
    "pipeline",
    "epilogue",
    "scheduler",
    "reduction_strategy",
    "pad_m",
    "pad_n",
    "pad_k",
    "persistent",
}


def normalize_axis_values(feasible_set, axes=None):
    """Compute normalization metadata for each axis.

    `axes` defaults to GEMM_AXES; pass GEMM_STREAMK_AXES to include
    `reduction_strategy`. Axes with no values in `feasible_set` are omitted.

    Returns dict mapping axis name to:
      - For numeric axes: {"type": "numeric", "min": v, "max": v, "range": v}
      - For categorical axes: {"type": "categorical", "values": sorted list,
        "map": value->index, "count": number of distinct values}
    """
    if axes is None:
        axes = GEMM_AXES

    meta = {}
    for ax in axes:
        values = [item[ax] for item in feasible_set if ax in item]
        if not values:
            continue

        if ax in CATEGORICAL_AXES:
            # Compare categories as strings so mixed types sort consistently.
            unique = sorted(set(str(v) for v in values))
            meta[ax] = {
                "type": "categorical",
                "values": unique,
                "map": {v: i for i, v in enumerate(unique)},
                "count": len(unique),
            }
        else:
            num_values = [float(v) for v in values]
            mn, mx = min(num_values), max(num_values)
            meta[ax] = {
                "type": "numeric",
                "min": mn,
                "max": mx,
                # Avoid division by zero when the axis has a single value.
                "range": mx - mn if mx != mn else 1.0,
            }

    return meta


def normalize_point(item, axes, meta):
    """Normalize a single point to [0, 1] per axis."""
    coords = []
    for ax in axes:
        if ax not in meta or ax not in item:
            coords.append(0.0)
            continue
        m = meta[ax]
        if m["type"] == "numeric":
            coords.append((float(item[ax]) - m["min"]) / m["range"])
        else:
            coords.append(m["map"].get(str(item[ax]), 0) / max(m["count"] - 1, 1))
    return coords
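
The sketch below illustrates how the two helpers map a small feasible set into the unit cube. So that it runs standalone, it uses condensed copies of the logic above under the hypothetical names `axis_meta` and `to_unit_cube`. The instance values (`compv3`, `atomic`, and so on) and the reduced axis list are illustrative assumptions, not taken from a real feasible set.

```python
CATEGORICAL = {"pipeline", "reduction_strategy"}  # subset of CATEGORICAL_AXES


def axis_meta(feasible_set, axes):
    """Condensed stand-in for normalize_axis_values."""
    meta = {}
    for ax in axes:
        values = [item[ax] for item in feasible_set if ax in item]
        if not values:
            continue
        if ax in CATEGORICAL:
            unique = sorted({str(v) for v in values})
            meta[ax] = {
                "type": "categorical",
                "map": {v: i for i, v in enumerate(unique)},
                "count": len(unique),
            }
        else:
            nums = [float(v) for v in values]
            mn, mx = min(nums), max(nums)
            meta[ax] = {"type": "numeric", "min": mn,
                        "range": mx - mn if mx != mn else 1.0}
    return meta


def to_unit_cube(item, axes, meta):
    """Condensed stand-in for normalize_point."""
    coords = []
    for ax in axes:
        if ax not in meta or ax not in item:
            coords.append(0.0)
            continue
        m = meta[ax]
        if m["type"] == "numeric":
            coords.append((float(item[ax]) - m["min"]) / m["range"])
        else:
            coords.append(m["map"].get(str(item[ax]), 0) / max(m["count"] - 1, 1))
    return coords


# Hypothetical stream-K instances; real sets carry every axis in GEMM_STREAMK_AXES.
instances = [
    {"tile_m": 128, "tile_n": 128, "pipeline": "compv3", "reduction_strategy": "atomic"},
    {"tile_m": 192, "tile_n": 128, "pipeline": "compv4", "reduction_strategy": "reduction"},
    {"tile_m": 256, "tile_n": 128, "pipeline": "mem", "reduction_strategy": "atomic"},
]
axes = ["tile_m", "tile_n", "pipeline", "reduction_strategy"]

meta = axis_meta(instances, axes)
points = [to_unit_cube(item, axes, meta) for item in instances]
for p in points:
    print(p)
```

Two details are worth noting. `tile_n` has a single value, so its range falls back to 1.0 and every point gets 0.0 on that axis. And because `normalize_axis_values` defaults to `GEMM_AXES`, a stream-K caller has to pass `GEMM_STREAMK_AXES` explicitly, otherwise `reduction_strategy` is left out of the metadata and every point gets 0.0 on that axis.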