Files
composable_kernel/tile_engine/sampling/lhs.py
Thrupti Raj Lakshmana Gowda c31fc4df52 [rocm-libraries] ROCm/rocm-libraries#7311 (commit 79d8cae)
[CK Tile Engine] Daily tier sampling for tile engine GEMM  (#7311)

Summary
- Replace uniform random instance sampling (random.shuffle) with
scrambled Sobol + Latin Hypercube + maximin space-filling
sampling, per the Tile Engine Benchmark Sampling RFC
- Add op-weighted budget allocation via a new
TILE_ENGINE_SAMPLING_TIER=daily CMake knob that automatically distributes
8,000 instances across ops in proportion to the weights registered in
op_weights.json
- Emit chosen_instances.json manifests for reproducibility tracking
- Consolidate 5 copies of the sampling logic into a single _apply_sampling()
method on the base class
Jenkinsfile changes
Replace the per-op -D *_MAX_INSTANCES=250 flags with a single -D
TILE_ENGINE_SAMPLING_TIER=daily in the gfx942/gfx950/gfx1201 stages. The
budget is distributed automatically (8,000 instances total per GPU target).

---------

Co-authored-by: Claude Sonnet 4 <noreply@anthropic.com>
2026-05-21 02:17:42 -05:00

109 lines
3.4 KiB
Python

# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""Latin Hypercube Sampling padding for marginal coverage.
Ensures every distinct value on every parameter axis appears at least once
in the selected sample, using a greedy set-cover heuristic.
"""
from collections import defaultdict
def lhs_pad(selected_indices, feasible_set, axes, budget_remaining, rng,
            max_candidates=500):
    """Add indices to guarantee marginal coverage on all axes.

    Runs a greedy set-cover. Every distinct (axis, value) pair that occurs
    somewhere in ``feasible_set`` but not among the already-selected items
    is "uncovered". Each round adds the not-yet-selected item that covers
    the most uncovered pairs, until every pair is covered or the budget is
    exhausted.

    Values are compared by their ``str()`` form, so e.g. ``1`` and ``"1"``
    count as the same value. An item that lacks an axis key neither covers
    nor requires coverage on that axis.

    Args:
        selected_indices: Already-selected feasible-set indices.
        feasible_set: Full list of parameter dicts.
        axes: List of axis names to ensure coverage for.
        budget_remaining: Max additional indices to add.
        rng: random.Random instance used to shuffle and subsample the
            candidate pool when it holds more than ``max_candidates`` items.
            Ties between equally good candidates are NOT broken with rng;
            the first one in candidate order wins.
        max_candidates: Upper bound on the number of candidates scored per
            greedy round (default 500). Must be >= 1.

    Returns:
        List of additional indices to include, in the order they were
        chosen. It may leave some values uncovered if the budget runs out.

    Raises:
        ValueError: If ``max_candidates`` is less than 1.
    """
    if max_candidates < 1:
        raise ValueError("max_candidates must be >= 1")
    if budget_remaining <= 0:
        return []
    selected_set = set(selected_indices)

    # axis -> value (as str) -> set of feasible indices carrying that value
    axis_value_indices = {}
    for ax in axes:
        value_map = defaultdict(set)
        for i, item in enumerate(feasible_set):
            if ax in item:
                value_map[str(item[ax])].add(i)
        axis_value_indices[ax] = value_map

    # Values on each axis that the selected sample already contains.
    covered = {}
    for ax in axes:
        covered[ax] = {
            str(feasible_set[idx][ax])
            for idx in selected_set
            if ax in feasible_set[idx]
        }

    # (axis, value) pairs present in the feasible set but not yet covered.
    uncovered_pairs = [
        (ax, val)
        for ax in axes
        for val in axis_value_indices[ax]
        if val not in covered[ax]
    ]
    if not uncovered_pairs:
        return []

    # (axis, value) -> positions in uncovered_pairs. This is a list because
    # a repeated axis name in ``axes`` produces duplicate pairs. It lets each
    # candidate be scored by looking up its own values instead of scanning
    # every uncovered pair.
    pair_positions = defaultdict(list)
    for ui, pair in enumerate(uncovered_pairs):
        pair_positions[pair].append(ui)

    additional = []
    chosen = set()  # same contents as ``additional``, for set arithmetic
    uncovered_set = set(range(len(uncovered_pairs)))
    while uncovered_set and len(additional) < budget_remaining:
        # Candidate pool: unselected items carrying at least one uncovered
        # value.
        candidates = set()
        for ui in uncovered_set:
            ax, val = uncovered_pairs[ui]
            candidates.update(axis_value_indices[ax][val])
        candidates -= selected_set
        candidates -= chosen
        if not candidates:
            break

        # Sample a subset to bound per-round cost when the pool is large.
        candidate_list = list(candidates)
        if len(candidate_list) > max_candidates:
            rng.shuffle(candidate_list)
            candidate_list = candidate_list[:max_candidates]

        # Score candidates by how many uncovered pairs they hit. The strict
        # ">" keeps the first best candidate on ties.
        best_idx = -1
        best_covers = set()
        for ci in candidate_list:
            item = feasible_set[ci]
            covers = {
                ui
                for ax in axes
                if ax in item
                for ui in pair_positions.get((ax, str(item[ax])), ())
                if ui in uncovered_set
            }
            if len(covers) > len(best_covers):
                best_idx = ci
                best_covers = covers
        if best_idx < 0:
            break

        additional.append(best_idx)
        chosen.add(best_idx)
        uncovered_set -= best_covers
    return additional