mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-10 16:28:38 +00:00
[CK Tile Engine] Daily tier sampling for tile engine GEMM (#7311)
Summary:
- Replace uniform random instance sampling (random.shuffle) with scrambled Sobol + Latin Hypercube + maximin space-filling sampling, per the Tile Engine Benchmark Sampling RFC.
- Add op-weighted budget allocation via a new TILE_ENGINE_SAMPLING_TIER=daily CMake knob that auto-distributes 8,000 instances across ops in proportion to the weights registered in op_weights.json.
- Emit chosen_instances.json manifests for reproducibility tracking.
- Consolidate 5 copies of the sampling logic into a single _apply_sampling() method on the base class.
Jenkinsfile changes: replace the per-op `-D *_MAX_INSTANCES=250` with a single `-D TILE_ENGINE_SAMPLING_TIER=daily` in the gfx942/gfx950/gfx1201 stages. The budget auto-distributes (8,000 total per GPU target).
Co-authored-by: Claude Sonnet 4 <noreply@anthropic.com>
109 lines
3.4 KiB
Python
109 lines
3.4 KiB
Python
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""Latin Hypercube Sampling padding for marginal coverage.
|
|
|
|
Adds indices so that every distinct value on every parameter axis appears at
least once in the selected sample (as far as the padding budget allows),
using a greedy set-cover heuristic.
|
|
"""
|
|
|
|
from collections import defaultdict
|
|
|
|
|
|
def lhs_pad(
    selected_indices,
    feasible_set,
    axes,
    budget_remaining,
    rng,
    *,
    max_candidates=500,
):
    """Add indices to guarantee marginal coverage on all axes.

    Every distinct value (compared via ``str()``) that any feasible item takes
    on any of ``axes`` should appear in at least one selected item. Values
    missing from ``selected_indices`` are covered greedily: each round picks
    the not-yet-chosen feasible index that covers the most still-uncovered
    (axis, value) pairs. The loop stops when everything is covered or
    ``budget_remaining`` indices have been added.

    Args:
        selected_indices: Already-selected feasible-set indices.
        feasible_set: Full list of parameter dicts.
        axes: List of axis names to ensure coverage for. Items lacking an
            axis key are ignored for that axis.
        budget_remaining: Max additional indices to add.
        rng: random.Random instance used to subsample the candidate pool
            when it exceeds ``max_candidates`` entries. Ties between equally
            good candidates go to the first one in iteration order.
        max_candidates: Keyword-only cap on how many candidates are scored
            per greedy round (default 500). Bounds the cost on large pools.

    Returns:
        List of additional indices to include, in the order they were chosen.
        Some values may remain uncovered if the budget runs out.

    Raises:
        ValueError: If ``max_candidates`` is less than 1.
    """
    if max_candidates < 1:
        raise ValueError("max_candidates must be >= 1")
    if budget_remaining <= 0:
        return []

    selected_set = set(selected_indices)

    # Build per-axis coverage maps: axis -> value -> set of feasible indices
    # with that value. Values are keyed by str() so heterogeneous types
    # compare uniformly (e.g. 1 and "1" are treated as the same value).
    axis_value_indices = {}
    for ax in axes:
        value_map = defaultdict(set)
        for i, item in enumerate(feasible_set):
            if ax in item:
                value_map[str(item[ax])].add(i)
        axis_value_indices[ax] = value_map

    # Axis values already present among the selected items.
    covered = {
        ax: {
            str(feasible_set[idx][ax])
            for idx in selected_set
            if ax in feasible_set[idx]
        }
        for ax in axes
    }

    # (axis, value) pairs not yet covered, in axis order then first-seen
    # value order.
    uncovered_pairs = [
        (ax, val)
        for ax in axes
        for val in axis_value_indices[ax]
        if val not in covered[ax]
    ]
    if not uncovered_pairs:
        return []

    # Reverse lookup (axis, value) -> positions in uncovered_pairs. Scoring a
    # candidate then costs O(len(axes)) instead of O(len(uncovered_pairs)).
    # A list is kept per pair because duplicate names in ``axes`` yield
    # duplicate pairs.
    pair_positions = defaultdict(list)
    for ui, pair in enumerate(uncovered_pairs):
        pair_positions[pair].append(ui)
    distinct_axes = list(dict.fromkeys(axes))

    # Greedy set-cover over integer positions into uncovered_pairs. Integers,
    # not the string pairs, are tracked so set iteration order (and hence
    # the chosen sample) does not depend on string hash randomization.
    additional = []
    chosen = set()
    uncovered_set = set(range(len(uncovered_pairs)))

    while uncovered_set and len(additional) < budget_remaining:
        # Candidate pool: unchosen indices carrying at least one uncovered value.
        candidates = set()
        for ui in uncovered_set:
            ax, val = uncovered_pairs[ui]
            candidates.update(axis_value_indices[ax][val])
        candidates -= selected_set
        candidates -= chosen

        if not candidates:
            break

        # Score only a random subset when the pool is large, to bound cost.
        candidate_list = list(candidates)
        if len(candidate_list) > max_candidates:
            rng.shuffle(candidate_list)
            candidate_list = candidate_list[:max_candidates]

        best_idx = -1
        best_count = 0
        best_covers = set()
        for ci in candidate_list:
            item = feasible_set[ci]
            covers = {
                ui
                for ax in distinct_axes
                if ax in item
                for ui in pair_positions.get((ax, str(item[ax])), ())
                if ui in uncovered_set
            }
            # Strict ">" keeps the first best candidate on ties.
            if len(covers) > best_count:
                best_count = len(covers)
                best_idx = ci
                best_covers = covers

        if best_idx < 0:
            break

        additional.append(best_idx)
        chosen.add(best_idx)
        uncovered_set -= best_covers

    return additional
|