[rocm-libraries] ROCm/rocm-libraries#7311 (commit 79d8cae)

[CK Tile Engine] Daily tier sampling for tile engine GEMM  (#7311)

Summary
- Replace uniform random instance sampling (random.shuffle) with
scrambled Sobol + Latin Hypercube + maximin space-filling
sampling, per the Tile Engine Benchmark Sampling RFC
- Add op-weighted budget allocation via a new
TILE_ENGINE_SAMPLING_TIER=daily CMake knob that auto-distributes 8,000
instances across ops in proportion to the weights registered in
op_weights.json
  - Emit chosen_instances.json manifests for reproducibility tracking
- Consolidate 5 copies of sampling logic into single _apply_sampling()
method on the base class
Jenkinsfile changes
- Replace the per-op -D *_MAX_INSTANCES=250 flags with a single -D
TILE_ENGINE_SAMPLING_TIER=daily in the gfx942/gfx950/gfx1201 stages. The
budget is auto-distributed across ops (8000 total per GPU target).

---------

Co-authored-by: Claude Sonnet 4 <noreply@anthropic.com>
This commit is contained in:
Thrupti Raj Lakshmana Gowda
2026-05-21 02:17:42 -05:00
committed by GitHub
parent 275629fe34
commit c31fc4df52
23 changed files with 1367 additions and 11 deletions

11
Jenkinsfile vendored
View File

@@ -1870,9 +1870,10 @@ pipeline {
-D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \
-D GEMM_PRESHUFFLE_LAYOUT="rcr" \
-D GROUPED_GEMM_DATATYPE="fp8;fp16" \
-D GROUPED_GEMM_LAYOUT="rcr;rrr;crr;ccr" .. && \
-D GROUPED_GEMM_LAYOUT="rcr;rrr;crr;ccr" \
-D TILE_ENGINE_SAMPLING_TIER=daily .. && \
ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all benchmark_gemm_streamk_all benchmark_grouped_gemm_all && \
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json gemm_universal_results.json && \
python3 ../tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/grouped_gemm/grouped_gemm_benchmark.py . --problem-sizes "1024,1024,1024" --group-counts 8 --warmup 5 --repeat 5 --verbose --json grouped_gemm_results.json """
@@ -1901,7 +1902,8 @@ pipeline {
-D GEMM_MULTI_D_DATATYPE="fp16" \
-D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \
-D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \
-D GEMM_PRESHUFFLE_LAYOUT="rcr" .. && \
-D GEMM_PRESHUFFLE_LAYOUT="rcr" \
-D TILE_ENGINE_SAMPLING_TIER=daily .. && \
ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all && \
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
@@ -1927,7 +1929,8 @@ pipeline {
-D CMAKE_BUILD_TYPE=Release \
-D GPU_TARGETS="gfx1201" \
-D GEMM_UNIVERSAL_DATATYPE="fp16" \
-D GEMM_UNIVERSAL_LAYOUT="rcr;rrr;crr;ccr" .. && \
-D GEMM_UNIVERSAL_LAYOUT="rcr;rrr;crr;ccr" \
-D TILE_ENGINE_SAMPLING_TIER=daily .. && \
ninja -j${nthreads()} benchmark_gemm_universal_all && \
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """
}

View File

@@ -6,6 +6,11 @@ include_directories(BEFORE
${CMAKE_CURRENT_LIST_DIR}/ops
)
# Tile Engine sampling tier. Besides the values named in the help text, the
# GEMM budget-allocation logic also accepts "weekly", which it maps to a budget
# of 0 and therefore treats as "no cap". A non-empty value is forwarded to the
# instance builders as --tier.
set(TILE_ENGINE_SAMPLING_TIER "" CACHE STRING
"Sampling tier: 'daily' (8000 budget) or integer budget (empty = no cap)")
# Optional fixed sampling seed; a non-empty value is forwarded to the instance
# builders as --seed.
set(TILE_ENGINE_SAMPLING_SEED "" CACHE STRING
"Override sampling seed (empty = daily rotation)")
add_subdirectory(ops/fmha EXCLUDE_FROM_ALL)
add_subdirectory(ops/gemm EXCLUDE_FROM_ALL)
add_subdirectory(ops/gemm_streamk EXCLUDE_FROM_ALL)

View File

@@ -1,7 +1,68 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# Budget allocation when TILE_ENGINE_SAMPLING_TIER is set.
#
# Translates the tier into a total instance budget, runs
# `python -m sampling.allocate_budget` to split that budget across the active
# GEMM ops, and seeds <OP>_MAX_INSTANCES for every op that has no explicit
# value of its own.
if(NOT "${TILE_ENGINE_SAMPLING_TIER}" STREQUAL "")
  # Map tier name to budget
  if("${TILE_ENGINE_SAMPLING_TIER}" STREQUAL "daily")
    set(_te_budget 8000)
  elseif("${TILE_ENGINE_SAMPLING_TIER}" STREQUAL "weekly")
    # A budget of 0 skips the allocation below, i.e. no cap is applied.
    set(_te_budget 0)
  elseif("${TILE_ENGINE_SAMPLING_TIER}" MATCHES "^[0-9]+$")
    set(_te_budget "${TILE_ENGINE_SAMPLING_TIER}")
  else()
    # Fail loudly: a misspelled tier would otherwise compare as "not greater
    # than 0" and silently build the full, uncapped instance set.
    message(FATAL_ERROR
      "TILE_ENGINE_SAMPLING_TIER must be 'daily', 'weekly' or a non-negative "
      "integer budget, got '${TILE_ENGINE_SAMPLING_TIER}'")
  endif()

  if(_te_budget GREATER 0)
    # Detect active ops from their DATATYPE variables.
    # NOTE(review): these variables are declared as cache entries by the op
    # subdirectories added at the bottom of this file, so on a first configure
    # only values passed with -D are visible here — confirm that is intended.
    set(_active_ops "")
    foreach(_op gemm_universal gemm_multi_d gemm_preshuffle grouped_gemm)
      string(TOUPPER "${_op}" _OP_UPPER)
      if(NOT "${${_OP_UPPER}_DATATYPE}" STREQUAL "")
        list(APPEND _active_ops ${_op})
      endif()
    endforeach()

    if(_active_ops)
      string(REPLACE ";" "," _active_ops_csv "${_active_ops}")
      set(_alloc_dir "${CMAKE_CURRENT_BINARY_DIR}/sampling_alloc")
      file(MAKE_DIRECTORY "${_alloc_dir}")

      # Paths and the interpreter are quoted so that locations containing
      # spaces or semicolons are passed as single arguments.
      execute_process(
        COMMAND "${CMAKE_COMMAND}" -E env
          "PYTHONPATH=${CMAKE_CURRENT_LIST_DIR}/../../"
          "${Python3_EXECUTABLE}" -m sampling.allocate_budget
          --total-budget "${_te_budget}"
          --active-ops "${_active_ops_csv}"
          --output-dir "${_alloc_dir}"
        WORKING_DIRECTORY "${CMAKE_CURRENT_LIST_DIR}/../../"
        RESULT_VARIABLE _alloc_ret
        OUTPUT_VARIABLE _alloc_output
        ERROR_VARIABLE _alloc_error
      )
      if(NOT _alloc_ret EQUAL 0)
        message(FATAL_ERROR "Budget allocation failed: ${_alloc_error}")
      endif()
      message(STATUS "Sampling budget allocation:\n${_alloc_output}")

      # Read per-op allocations (only if not already overridden).
      # NOTE(review): the value is stored as a normal variable that the op
      # subdirectory later shadows with set(... CACHE ...). Under policy
      # CMP0126 OLD (cmake_minimum_required below 3.21) that call can remove
      # the normal variable on a first configure — confirm the policy level.
      foreach(_op gemm_universal gemm_multi_d gemm_preshuffle grouped_gemm)
        string(TOUPPER "${_op}" _OP_UPPER)
        if("${${_OP_UPPER}_MAX_INSTANCES}" STREQUAL "")
          if(EXISTS "${_alloc_dir}/${_op}_budget.txt")
            file(READ "${_alloc_dir}/${_op}_budget.txt" _budget_val)
            string(STRIP "${_budget_val}" _budget_val)
            set(${_OP_UPPER}_MAX_INSTANCES "${_budget_val}")
            message(STATUS " ${_op}: ${_budget_val} instances (from budget allocation)")
          endif()
        else()
          message(STATUS " ${_op}: ${${_OP_UPPER}_MAX_INSTANCES} instances (explicit override)")
        endif()
      endforeach()
    endif()
  endif()
endif()

# Each op directory may be added only once; adding the same source directory
# twice makes CMake fail because the binary directory is already in use.
add_subdirectory(gemm_universal EXCLUDE_FROM_ALL)
add_subdirectory(gemm_multi_d EXCLUDE_FROM_ALL)
add_subdirectory(gemm_preshuffle EXCLUDE_FROM_ALL)
add_subdirectory(grouped_gemm EXCLUDE_FROM_ALL)

View File

@@ -43,6 +43,10 @@ class GemmKernelBuilder:
datatype,
layout,
config_json=None,
max_instances=None,
seed=None,
tier=None,
manifest_path=None,
):
self.kernel_name_prefix = kernel_name_prefix
self.working_path = Path(working_path)
@@ -50,6 +54,10 @@ class GemmKernelBuilder:
self.datatype = datatype
self.layout = layout
self.config_json = config_json
self.max_instances = max_instances
self.seed = seed
self.tier = tier
self.manifest_path = manifest_path
# Create working directory if it doesn't exist
self.working_path.mkdir(parents=True, exist_ok=True)
@@ -59,6 +67,74 @@ class GemmKernelBuilder:
with open(config_json, "r") as f:
self.config = json.load(f)
def _apply_sampling(self, kernel_list):
"""Apply RFC Sobol+LHS+maximin sampling. Returns sampled subset."""
if self.max_instances is None or len(kernel_list) <= self.max_instances:
return kernel_list
import sys
sampling_parent = os.path.join(os.path.dirname(__file__), "..", "..")
if sampling_parent not in sys.path:
sys.path.insert(0, sampling_parent)
from sampling.sampler import sample_feasible_set
from sampling.seed import make_seed
from sampling.feasible_set import GEMM_AXES
effective_seed = make_seed(
self.seed, self.gpu_target, self.datatype, self.layout
)
flat_items = []
for k in kernel_list:
flat = dict(k["tile_config"])
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent = k[
"trait_combo"
]
flat.update(
{
"pipeline": pipeline,
"epilogue": epilogue,
"scheduler": scheduler,
"pad_m": pad_m,
"pad_n": pad_n,
"pad_k": pad_k,
"persistent": persistent,
}
)
flat_items.append(flat)
selected, method, selected_indices = sample_feasible_set(
flat_items,
self.max_instances,
effective_seed,
GEMM_AXES,
)
kernel_list = [kernel_list[i] for i in selected_indices]
if self.manifest_path:
from sampling.manifest import write_manifest
write_manifest(
selected,
self.manifest_path,
self.kernel_name_prefix,
self.datatype,
self.layout,
self.gpu_target,
effective_seed,
self.tier or "daily",
method,
)
print(
f"Sampled {len(kernel_list)} from feasible set "
f"(budget={self.max_instances}, seed={effective_seed}, method={method})"
)
return kernel_list
def _list_kernels(self):
"""Write kernel list to file for CMake to read (with comprehensive validation)"""
# Get configurations using comprehensive validation
@@ -96,6 +172,9 @@ class GemmKernelBuilder:
}
)
# Apply RFC-compliant sampling (Sobol + LHS + maximin)
kernel_list = self._apply_sampling(kernel_list)
# Write kernel count
with open(
self.working_path / f"{self.kernel_name_prefix}_kernel_count.txt", "w"

View File

@@ -5,6 +5,7 @@ set(GEMM_MULTI_D_DATATYPE "fp16" CACHE STRING "List of datatypes for GEMM Multi
set(GEMM_MULTI_D_LAYOUT "rcrr;rrrr;crrr;ccrr" CACHE STRING "List of layout for GEMM Multi D (semicolon-separated)")
set(GEMM_MULTI_D_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
set(GEMM_MULTI_D_ELEMENTWISE_FUNCTION "mul" CACHE STRING "Elementwise function")
set(GEMM_MULTI_D_MAX_INSTANCES "" CACHE STRING "Max kernel instances per (dtype, layout) combo (empty = no cap)")
option(ENABLE_CCACHE_GEMM_MULTI_D "Enable ccache for GEMM Multi D ops compilation" OFF)
@@ -175,6 +176,19 @@ function(build_individual_gemm_multi_d_targets datatype layout)
--gpu_target ${GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL}
--list_kernels ")
# Build optional args for instance builder
set(extra_list_args "")
if(NOT "${GEMM_MULTI_D_MAX_INSTANCES}" STREQUAL "")
list(APPEND extra_list_args --max-instances ${GEMM_MULTI_D_MAX_INSTANCES})
endif()
if(NOT "${TILE_ENGINE_SAMPLING_TIER}" STREQUAL "")
list(APPEND extra_list_args --tier ${TILE_ENGINE_SAMPLING_TIER})
list(APPEND extra_list_args --manifest-path ${working_path})
endif()
if(NOT "${TILE_ENGINE_SAMPLING_SEED}" STREQUAL "")
list(APPEND extra_list_args --seed ${TILE_ENGINE_SAMPLING_SEED})
endif()
# First, just list the kernels (fast operation)
message(VERBOSE " Listing kernel configurations...")
execute_process(
@@ -185,7 +199,8 @@ function(build_individual_gemm_multi_d_targets datatype layout)
--elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION}
--config_json ${json_blob}
--gpu_target ${GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL}
--list_kernels
--list_kernels
${extra_list_args}
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}
RESULT_VARIABLE ret
OUTPUT_VARIABLE list_output
@@ -305,6 +320,20 @@ else()
add_custom_target(benchmark_gemm_multi_d_${scheduler}_scheduler)
endforeach()
# Divide MAX_INSTANCES budget across all active (dtype, layout) combos so that
# sampling fires per-combo rather than being a single cap larger than any combo's
# feasible set (which would make sampling a no-op for most combos).
if(NOT "${GEMM_MULTI_D_MAX_INSTANCES}" STREQUAL "")
list(LENGTH GEMM_MULTI_D_DATATYPE _gmd_n_dt)
list(LENGTH GEMM_MULTI_D_LAYOUT _gmd_n_lay)
math(EXPR _gmd_n_combos "${_gmd_n_dt} * ${_gmd_n_lay}")
if(_gmd_n_combos GREATER 0)
math(EXPR GEMM_MULTI_D_MAX_INSTANCES
"${GEMM_MULTI_D_MAX_INSTANCES} / ${_gmd_n_combos}")
message(STATUS " gemm_multi_d: per-combo budget = ${GEMM_MULTI_D_MAX_INSTANCES} (${_gmd_n_combos} combos)")
endif()
endif()
# Build individual targets for each datatype/layout combination
foreach(dt IN LISTS GEMM_MULTI_D_DATATYPE)
foreach(l IN LISTS GEMM_MULTI_D_LAYOUT)

View File

@@ -37,9 +37,22 @@ class GemmMultiDKernelBuilder(GemmKernelBuilder):
layout,
elementwise_function,
config_json=None,
max_instances=None,
seed=None,
tier=None,
manifest_path=None,
):
super().__init__(
kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
kernel_name_prefix,
working_path,
gpu_target,
datatype,
layout,
config_json,
max_instances=max_instances,
seed=seed,
tier=tier,
manifest_path=manifest_path,
)
self.elementwise_function = elementwise_function
@@ -71,6 +84,15 @@ class GemmMultiDKernelBuilder(GemmKernelBuilder):
)
)
# Apply RFC-compliant sampling (Sobol + LHS + maximin)
if self.max_instances is not None and len(work_items) > self.max_instances:
kernel_dicts = [
{"tile_config": item[0], "trait_combo": item[1], "_work_item": item}
for item in work_items
]
sampled = self._apply_sampling(kernel_dicts)
work_items = [k["_work_item"] for k in sampled]
print(
f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
)
@@ -223,6 +245,28 @@ def main():
action="store_true",
help="List kernel configurations without generating files",
)
parser.add_argument(
"--max-instances",
type=int,
default=None,
help="Cap on number of kernel instances per (dtype, layout) combo",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="RNG seed for deterministic sampling; if omitted, derived from today's date",
)
parser.add_argument(
"--tier",
default=None,
help="Sampling tier (daily/weekly)",
)
parser.add_argument(
"--manifest-path",
default=None,
help="Directory for chosen_instances.json",
)
args = parser.parse_args()
@@ -271,6 +315,10 @@ def main():
args.layout,
args.elementwise_function,
args.config_json,
max_instances=args.max_instances,
seed=args.seed,
tier=args.tier,
manifest_path=args.manifest_path,
)
if args.list_kernels:

View File

@@ -4,6 +4,7 @@
set(GEMM_PRESHUFFLE_DATATYPE "fp16;fp8;bf16;bf8" CACHE STRING "List of datatypes for GEMM Preshuffle (semicolon-separated)")
set(GEMM_PRESHUFFLE_LAYOUT "rcr" CACHE STRING "List of layout for GEMM Preshuffle (semicolon-separated)")
set(GEMM_PRESHUFFLE_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
set(GEMM_PRESHUFFLE_MAX_INSTANCES "" CACHE STRING "Max kernel instances per (dtype, layout) combo (empty = no cap)")
option(ENABLE_CCACHE_GEMM_PRESHUFFLE "Enable ccache for GEMM Preshuffle ops compilation" OFF)
# Store the directory path for use in functions
@@ -163,6 +164,19 @@ function(build_individual_gemm_preshuffle_targets datatype layout)
# Create working directory first
file(MAKE_DIRECTORY ${working_path})
# Build optional args for instance builder
set(extra_list_args "")
if(NOT "${GEMM_PRESHUFFLE_MAX_INSTANCES}" STREQUAL "")
list(APPEND extra_list_args --max-instances ${GEMM_PRESHUFFLE_MAX_INSTANCES})
endif()
if(NOT "${TILE_ENGINE_SAMPLING_TIER}" STREQUAL "")
list(APPEND extra_list_args --tier ${TILE_ENGINE_SAMPLING_TIER})
list(APPEND extra_list_args --manifest-path ${working_path})
endif()
if(NOT "${TILE_ENGINE_SAMPLING_SEED}" STREQUAL "")
list(APPEND extra_list_args --seed ${TILE_ENGINE_SAMPLING_SEED})
endif()
# First, just list the kernels (fast operation)
message(VERBOSE " Listing kernel configurations...")
message(VERBOSE " GPU Targets: ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}")
@@ -174,6 +188,7 @@ function(build_individual_gemm_preshuffle_targets datatype layout)
--layout ${layout}
--config_json ${json_blob}
--list_kernels
${extra_list_args}
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}
RESULT_VARIABLE ret
OUTPUT_VARIABLE list_output
@@ -293,6 +308,20 @@ else()
add_custom_target(benchmark_gemm_preshuffle_${scheduler}_scheduler)
endforeach()
# Divide MAX_INSTANCES budget across all active (dtype, layout) combos so that
# sampling fires per-combo rather than being a single cap larger than any combo's
# feasible set (which would make sampling a no-op for most combos).
if(NOT "${GEMM_PRESHUFFLE_MAX_INSTANCES}" STREQUAL "")
list(LENGTH GEMM_PRESHUFFLE_DATATYPE _gp_n_dt)
list(LENGTH GEMM_PRESHUFFLE_LAYOUT _gp_n_lay)
math(EXPR _gp_n_combos "${_gp_n_dt} * ${_gp_n_lay}")
if(_gp_n_combos GREATER 0)
math(EXPR GEMM_PRESHUFFLE_MAX_INSTANCES
"${GEMM_PRESHUFFLE_MAX_INSTANCES} / ${_gp_n_combos}")
message(STATUS " gemm_preshuffle: per-combo budget = ${GEMM_PRESHUFFLE_MAX_INSTANCES} (${_gp_n_combos} combos)")
endif()
endif()
# Build individual targets for each datatype/layout combination
foreach(dt IN LISTS GEMM_PRESHUFFLE_DATATYPE)
foreach(l IN LISTS GEMM_PRESHUFFLE_LAYOUT)

View File

@@ -36,9 +36,22 @@ class GemmPreshuffleKernelBuilder(GemmKernelBuilder):
datatype,
layout,
config_json=None,
max_instances=None,
seed=None,
tier=None,
manifest_path=None,
):
super().__init__(
kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
kernel_name_prefix,
working_path,
gpu_target,
datatype,
layout,
config_json,
max_instances=max_instances,
seed=seed,
tier=tier,
manifest_path=manifest_path,
)
def _generate_all_individual(self, num_workers=None):
@@ -68,6 +81,15 @@ class GemmPreshuffleKernelBuilder(GemmKernelBuilder):
)
)
# Apply RFC-compliant sampling (Sobol + LHS + maximin)
if self.max_instances is not None and len(work_items) > self.max_instances:
kernel_dicts = [
{"tile_config": item[0], "trait_combo": item[1], "_work_item": item}
for item in work_items
]
sampled = self._apply_sampling(kernel_dicts)
work_items = [k["_work_item"] for k in sampled]
print(
f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
)
@@ -212,6 +234,28 @@ def main():
action="store_true",
help="List kernel configurations without generating files",
)
parser.add_argument(
"--max-instances",
type=int,
default=None,
help="Cap on number of kernel instances per (dtype, layout) combo",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="RNG seed for deterministic sampling; if omitted, derived from today's date",
)
parser.add_argument(
"--tier",
default=None,
help="Sampling tier (daily/weekly)",
)
parser.add_argument(
"--manifest-path",
default=None,
help="Directory for chosen_instances.json",
)
args = parser.parse_args()
@@ -239,6 +283,10 @@ def main():
args.datatype,
args.layout,
args.config_json,
max_instances=args.max_instances,
seed=args.seed,
tier=args.tier,
manifest_path=args.manifest_path,
)
if args.list_kernels:

View File

@@ -4,6 +4,7 @@
set(GEMM_UNIVERSAL_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM Universal (semicolon-separated)")
set(GEMM_UNIVERSAL_LAYOUT "rcr;rrr;crr;ccr" CACHE STRING "List of layout for GEMM Universal (semicolon-separated)")
set(GEMM_UNIVERSAL_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
set(GEMM_UNIVERSAL_MAX_INSTANCES "" CACHE STRING "Max kernel instances per (dtype, layout) combo (empty = no cap)")
option(ENABLE_CCACHE_GEMM_UNIVERSAL "Enable ccache for GEMM Universal ops compilation" OFF)
# Store the directory path for use in functions
@@ -171,6 +172,19 @@ function(build_individual_gemm_universal_targets datatype layout)
--gpu_target ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}
--list_kernels ")
# Build optional args for instance builder
set(extra_list_args "")
if(NOT "${GEMM_UNIVERSAL_MAX_INSTANCES}" STREQUAL "")
list(APPEND extra_list_args --max-instances ${GEMM_UNIVERSAL_MAX_INSTANCES})
endif()
if(NOT "${TILE_ENGINE_SAMPLING_TIER}" STREQUAL "")
list(APPEND extra_list_args --tier ${TILE_ENGINE_SAMPLING_TIER})
list(APPEND extra_list_args --manifest-path ${working_path})
endif()
if(NOT "${TILE_ENGINE_SAMPLING_SEED}" STREQUAL "")
list(APPEND extra_list_args --seed ${TILE_ENGINE_SAMPLING_SEED})
endif()
# First, just list the kernels (fast operation)
message(VERBOSE " Listing kernel configurations...")
execute_process(
@@ -180,7 +194,8 @@ function(build_individual_gemm_universal_targets datatype layout)
--layout ${layout}
--config_json ${json_blob}
--gpu_target ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}
--list_kernels
--list_kernels
${extra_list_args}
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}
RESULT_VARIABLE ret
OUTPUT_VARIABLE list_output
@@ -300,6 +315,20 @@ else()
add_custom_target(benchmark_gemm_universal_${scheduler}_scheduler)
endforeach()
# Divide MAX_INSTANCES budget across all active (dtype, layout) combos so that
# sampling fires per-combo rather than being a single cap larger than any combo's
# feasible set (which would make sampling a no-op for most combos).
if(NOT "${GEMM_UNIVERSAL_MAX_INSTANCES}" STREQUAL "")
list(LENGTH GEMM_UNIVERSAL_DATATYPE _gu_n_dt)
list(LENGTH GEMM_UNIVERSAL_LAYOUT _gu_n_lay)
math(EXPR _gu_n_combos "${_gu_n_dt} * ${_gu_n_lay}")
if(_gu_n_combos GREATER 0)
math(EXPR GEMM_UNIVERSAL_MAX_INSTANCES
"${GEMM_UNIVERSAL_MAX_INSTANCES} / ${_gu_n_combos}")
message(STATUS " gemm_universal: per-combo budget = ${GEMM_UNIVERSAL_MAX_INSTANCES} (${_gu_n_combos} combos)")
endif()
endif()
# Build individual targets for each datatype/layout combination
foreach(dt IN LISTS GEMM_UNIVERSAL_DATATYPE)
foreach(l IN LISTS GEMM_UNIVERSAL_LAYOUT)

View File

@@ -36,9 +36,22 @@ class GemmUniversalKernelBuilder(GemmKernelBuilder):
datatype,
layout,
config_json=None,
max_instances=None,
seed=None,
tier=None,
manifest_path=None,
):
super().__init__(
kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
kernel_name_prefix,
working_path,
gpu_target,
datatype,
layout,
config_json,
max_instances=max_instances,
seed=seed,
tier=tier,
manifest_path=manifest_path,
)
def _generate_all_individual(self, num_workers=None):
@@ -67,6 +80,16 @@ class GemmUniversalKernelBuilder(GemmKernelBuilder):
self.config_json,
)
)
# Apply RFC-compliant sampling (Sobol + LHS + maximin)
if self.max_instances is not None and len(work_items) > self.max_instances:
kernel_dicts = [
{"tile_config": item[0], "trait_combo": item[1], "_work_item": item}
for item in work_items
]
sampled = self._apply_sampling(kernel_dicts)
work_items = [k["_work_item"] for k in sampled]
print(
f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
)
@@ -210,6 +233,28 @@ def main():
action="store_true",
help="List kernel configurations without generating files",
)
parser.add_argument(
"--max-instances",
type=int,
default=None,
help="Cap on number of kernel instances per (dtype, layout) combo",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="RNG seed for deterministic sampling; if omitted, derived from today's date",
)
parser.add_argument(
"--tier",
default=None,
help="Sampling tier (daily/weekly)",
)
parser.add_argument(
"--manifest-path",
default=None,
help="Directory for chosen_instances.json",
)
args = parser.parse_args()
@@ -236,6 +281,10 @@ def main():
args.datatype,
args.layout,
args.config_json,
max_instances=args.max_instances,
seed=args.seed,
tier=args.tier,
manifest_path=args.manifest_path,
)
if args.list_kernels:

View File

@@ -4,6 +4,7 @@
set(GROUPED_GEMM_DATATYPE "fp16;fp8;bf16;bf8" CACHE STRING "List of datatypes for Grouped GEMM (semicolon-separated)")
set(GROUPED_GEMM_LAYOUT "rcr;rrr;crr;ccr" CACHE STRING "List of layout for Grouped GEMM (semicolon-separated)")
set(GROUPED_GEMM_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
set(GROUPED_GEMM_MAX_INSTANCES "" CACHE STRING "Max kernel instances per (dtype, layout) combo (empty = no cap)")
option(ENABLE_CCACHE_GROUPED_GEMM "Enable ccache for Grouped GEMM ops compilation" OFF)
# Store the directory path for use in functions
@@ -171,6 +172,19 @@ function(build_individual_grouped_gemm_targets datatype layout)
--gpu_target ${GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL}
--list_kernels ")
# Build optional args for instance builder
set(extra_list_args "")
if(NOT "${GROUPED_GEMM_MAX_INSTANCES}" STREQUAL "")
list(APPEND extra_list_args --max-instances ${GROUPED_GEMM_MAX_INSTANCES})
endif()
if(NOT "${TILE_ENGINE_SAMPLING_TIER}" STREQUAL "")
list(APPEND extra_list_args --tier ${TILE_ENGINE_SAMPLING_TIER})
list(APPEND extra_list_args --manifest-path ${working_path})
endif()
if(NOT "${TILE_ENGINE_SAMPLING_SEED}" STREQUAL "")
list(APPEND extra_list_args --seed ${TILE_ENGINE_SAMPLING_SEED})
endif()
# First, just list the kernels (fast operation)
message(VERBOSE " Listing kernel configurations...")
execute_process(
@@ -181,6 +195,7 @@ function(build_individual_grouped_gemm_targets datatype layout)
--config_json ${json_blob}
--gpu_target ${GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL}
--list_kernels
${extra_list_args}
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}
RESULT_VARIABLE ret
OUTPUT_VARIABLE list_output

View File

@@ -36,9 +36,22 @@ class GroupedGemmKernelBuilder(GemmKernelBuilder):
datatype,
layout,
config_json=None,
max_instances=None,
seed=None,
tier=None,
manifest_path=None,
):
super().__init__(
kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
kernel_name_prefix,
working_path,
gpu_target,
datatype,
layout,
config_json,
max_instances=max_instances,
seed=seed,
tier=tier,
manifest_path=manifest_path,
)
def _generate_all_individual(self, num_workers=None):
@@ -67,6 +80,16 @@ class GroupedGemmKernelBuilder(GemmKernelBuilder):
self.config_json,
)
)
# Apply RFC-compliant sampling (Sobol + LHS + maximin)
if self.max_instances is not None and len(work_items) > self.max_instances:
kernel_dicts = [
{"tile_config": item[0], "trait_combo": item[1], "_work_item": item}
for item in work_items
]
sampled = self._apply_sampling(kernel_dicts)
work_items = [k["_work_item"] for k in sampled]
print(
f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
)
@@ -209,6 +232,28 @@ def main():
action="store_true",
help="List kernel configurations without generating files",
)
parser.add_argument(
"--max-instances",
type=int,
default=None,
help="Cap on number of kernel instances per (dtype, layout) combo",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="RNG seed for deterministic sampling; if omitted, derived from today's date",
)
parser.add_argument(
"--tier",
default=None,
help="Sampling tier (daily/weekly)",
)
parser.add_argument(
"--manifest-path",
default=None,
help="Directory for chosen_instances.json",
)
args = parser.parse_args()
@@ -244,6 +289,10 @@ def main():
args.datatype,
args.layout,
args.config_json,
max_instances=args.max_instances,
seed=args.seed,
tier=args.tier,
manifest_path=args.manifest_path,
)
if args.list_kernels:

View File

@@ -0,0 +1,10 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
from sampling.sampler import sample_feasible_set as sample_feasible_set
from sampling.seed import make_seed as make_seed
from sampling.budget import allocate_budget as allocate_budget
from sampling.budget import load_op_weights as load_op_weights
from sampling.manifest import write_manifest as write_manifest
from sampling.feasible_set import GEMM_AXES as GEMM_AXES
from sampling.feasible_set import normalize_axis_values as normalize_axis_values

View File

@@ -0,0 +1,102 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""CLI entry point for budget allocation, called by CMake at configure time.
Usage:
python -m sampling.allocate_budget \
--total-budget 8000 \
--active-ops "gemm_universal,gemm_multi_d,gemm_preshuffle,grouped_gemm" \
--output-dir /build/sampling_alloc \
[--weights-file /path/to/op_weights.json]
Writes per-op budget files (e.g. gemm_universal_budget.txt) containing a single integer.
"""
import argparse
import json
import sys
from pathlib import Path
def _setup_path():
_this_dir = Path(__file__).resolve().parent
_tile_engine_dir = _this_dir.parent
if str(_tile_engine_dir) not in sys.path:
sys.path.insert(0, str(_tile_engine_dir))
_setup_path()
from sampling.budget import allocate_budget # noqa: E402
from sampling.budget import load_op_weights # noqa: E402
def main():
    """Parse the CLI arguments, allocate the budget and write the results.

    For every allocated op a ``<op>_budget.txt`` file holding a single integer
    is written into ``--output-dir``, together with a
    ``sampling_allocations.json`` summary. A human-readable summary is printed
    to stdout. Exits with status 1 when ``--active-ops`` names no op.
    """
    parser = argparse.ArgumentParser(description="Allocate instance budget across ops")
    parser.add_argument(
        "--total-budget",
        type=int,
        required=True,
        help="Total instance budget (e.g. 8000 for daily tier)",
    )
    parser.add_argument(
        "--active-ops",
        type=str,
        required=True,
        help="Comma or semicolon-separated list of active op names",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Directory to write per-op budget files",
    )
    parser.add_argument(
        "--weights-file",
        type=str,
        default=None,
        help="Path to op_weights.json (default: built-in)",
    )
    args = parser.parse_args()

    # Treat ";" like "," so both separator styles are accepted, and drop
    # tokens that are empty after trimming.
    active_ops = []
    for token in args.active_ops.replace(";", ",").split(","):
        name = token.strip()
        if name:
            active_ops.append(name)

    if not active_ops:
        print("ERROR: No active ops specified", file=sys.stderr)
        sys.exit(1)

    weights = load_op_weights(args.weights_file)
    alloc = allocate_budget(args.total_budget, active_ops, weights, strict=True)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # One small text file per op, holding just the allocated integer.
    for op_name, op_budget in alloc.items():
        (output_dir / f"{op_name}_budget.txt").write_text(str(op_budget))

    # Combined allocation metadata.
    meta = {
        "total_budget": args.total_budget,
        "active_ops": active_ops,
        "allocations": alloc,
        "weights_used": {op: weights.get(op, 0.0) for op in active_ops},
    }
    (output_dir / "sampling_allocations.json").write_text(json.dumps(meta, indent=2))

    # Summary on stdout, ops in sorted order.
    print(f"Budget allocation (total={args.total_budget}):")
    for op_name, op_budget in sorted(alloc.items()):
        print(f" {op_name}: {op_budget}")
    print(f" Sum: {sum(alloc.values())}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,78 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""Op-weighted budget allocation.
Distributes a total instance budget across active ops proportional to
their registered weights. Implements RFC section 3.4.
"""
import json
from pathlib import Path
_DEFAULT_WEIGHTS_FILE = Path(__file__).parent / "op_weights.json"
def load_op_weights(weights_file=None):
    """Read the op-weight table from a JSON file.

    Args:
        weights_file: Path of the JSON file to read. When falsy, the module's
            default weights file is used instead.

    Returns:
        The value stored under the top-level ``"weights"`` key of the file.
    """
    source = _DEFAULT_WEIGHTS_FILE if not weights_file else Path(weights_file)
    with open(source) as handle:
        document = json.load(handle)
    return document["weights"]
def allocate_budget(total_budget, active_ops, weights, strict=True):
    """Distribute total_budget across active_ops proportional to weights.

    Args:
        total_budget: Total instance budget (e.g. 8000).
        active_ops: List of active op names. Repeated names are counted once,
            in first-seen order.
        weights: Dict mapping op name to weight.
        strict: If True, raise ValueError for unweighted active ops. If
            False, such ops are treated as having weight 0.

    Returns:
        Dict mapping op name to allocated budget (int). For a non-negative
        total_budget and non-negative weights the allocations are intended to
        sum exactly to total_budget.

    Raises:
        ValueError: If ``strict`` is true and an active op has no entry in
            ``weights``.
    """
    if not active_ops:
        return {}

    # Collapse duplicates up front. The result dict has one key per op, so a
    # repeated name would otherwise inflate the divisor of the equal split and
    # make the allocations fall short of total_budget.
    ops = list(dict.fromkeys(active_ops))

    # Check all active ops have weights
    missing = [op for op in ops if op not in weights]
    if missing and strict:
        raise ValueError(
            f"Active ops without registered weights: {missing}. "
            f"Add them to op_weights.json before running with sampling enabled."
        )

    # Compute weight sum for active ops only
    active_weights = {op: weights.get(op, 0.0) for op in ops}
    total_weight = sum(active_weights.values())

    if total_weight <= 0:
        # Equal distribution fallback: the first `remainder` ops get one extra.
        per_op, remainder = divmod(total_budget, len(ops))
        return {
            op: per_op + (1 if position < remainder else 0)
            for position, op in enumerate(ops)
        }

    # Proportional allocation, truncated to an integer per op
    alloc = {
        op: int(total_budget * active_weights[op] / total_weight) for op in ops
    }

    # Hand the units lost to truncation to the highest-weight ops, round-robin
    remainder = total_budget - sum(alloc.values())
    by_weight = sorted(ops, key=lambda op: active_weights[op], reverse=True)
    for step in range(remainder):
        alloc[by_weight[step % len(by_weight)]] += 1
    return alloc

View File

@@ -0,0 +1,82 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# Ordered names of the parameter axes of a GEMM kernel instance.
# normalize_axis_values() falls back to this list when no axes are given.
# NOTE(review): the tile_*/warp_* names presumably correspond to keys of a
# kernel's tile configuration — confirm against the instance builders.
GEMM_AXES = [
"tile_m",
"tile_n",
"tile_k",
"warp_m",
"warp_n",
"warp_k",
"warp_tile_m",
"warp_tile_n",
"warp_tile_k",
"pipeline",
"epilogue",
"scheduler",
"pad_m",
"pad_n",
"pad_k",
"persistent",
]
# Axes that normalize_axis_values() treats as categorical: their values are
# converted to strings and mapped to indices. Every axis not listed here is
# converted to float and treated as numeric.
CATEGORICAL_AXES = {
"pipeline",
"epilogue",
"scheduler",
"pad_m",
"pad_n",
"pad_k",
"persistent",
}
def normalize_axis_values(feasible_set, axes=None):
    """Collect the per-axis metadata needed to scale points into [0, 1].

    Args:
        feasible_set: List of parameter dicts.
        axes: Axis names to describe; GEMM_AXES is used when None.

    Returns:
        Dict keyed by axis name. Axes that no item carries are left out.
        - Numeric axis: {"type": "numeric", "min": v, "max": v, "range": v},
          where "range" is 1.0 if every value is the same.
        - Categorical axis: {"type": "categorical", "values": sorted list of
          str values, "map": value -> index, "count": number of values}.
    """
    axis_names = GEMM_AXES if axes is None else axes
    meta = {}
    for name in axis_names:
        observed = [row[name] for row in feasible_set if name in row]
        if not observed:
            continue

        if name in CATEGORICAL_AXES:
            # Categories are compared by their string form.
            labels = sorted({str(value) for value in observed})
            meta[name] = {
                "type": "categorical",
                "values": labels,
                "map": {label: position for position, label in enumerate(labels)},
                "count": len(labels),
            }
            continue

        as_float = [float(value) for value in observed]
        lowest = min(as_float)
        highest = max(as_float)
        # A constant axis gets range 1.0 so later division is safe.
        span = highest - lowest if highest != lowest else 1.0
        meta[name] = {
            "type": "numeric",
            "min": lowest,
            "max": highest,
            "range": span,
        }
    return meta
def normalize_point(item, axes, meta):
    """Map one parameter dict to a coordinate list, one entry per axis.

    Args:
        item: Parameter dict for a single instance.
        axes: Axis names, in the order the coordinates should appear.
        meta: Per-axis metadata in the shape produced by normalize_axis_values.

    Returns:
        List of floats. An axis missing from ``meta`` or from ``item``
        contributes 0.0, as does a categorical value not present in the map.
    """

    def _coordinate(axis):
        if axis not in meta or axis not in item:
            return 0.0
        info = meta[axis]
        if info["type"] == "numeric":
            return (float(item[axis]) - info["min"]) / info["range"]
        # Categorical: index scaled by (count - 1); the divisor never drops below 1.
        return info["map"].get(str(item[axis]), 0) / max(info["count"] - 1, 1)

    return [_coordinate(axis) for axis in axes]

108
tile_engine/sampling/lhs.py Normal file
View File

@@ -0,0 +1,108 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""Latin Hypercube Sampling padding for marginal coverage.
Ensures every distinct value on every parameter axis appears at least once
in the selected sample, using a greedy set-cover heuristic.
"""
from collections import defaultdict
def lhs_pad(selected_indices, feasible_set, axes, budget_remaining, rng):
    """Pick extra indices so that every (axis, value) pair gets covered.

    Greedy set cover: on each round the candidate that covers the most
    still-uncovered (axis, value) pairs is added, until everything is covered
    or ``budget_remaining`` additions have been made. Values are compared by
    their ``str()`` form.

    Args:
        selected_indices: Already-selected feasible-set indices.
        feasible_set: Full list of parameter dicts.
        axes: List of axis names to ensure coverage for.
        budget_remaining: Max additional indices to add.
        rng: random.Random instance; used to subsample the candidate pool
            when it holds more than 500 indices.

    Returns:
        List of additional indices to include (possibly empty). Never
        contains an index that is already in ``selected_indices``.
    """
    if budget_remaining <= 0:
        return []

    # Indices that may not be picked again: the caller's selection plus
    # anything added below.
    taken = set(selected_indices)

    # (axis, str(value)) -> feasible indices carrying that value.
    pair_indices = defaultdict(set)
    for i, item in enumerate(feasible_set):
        for ax in axes:
            if ax in item:
                pair_indices[(ax, str(item[ax]))].add(i)

    # Pairs the current selection already covers.
    covered = set()
    for idx in taken:
        item = feasible_set[idx]
        for ax in axes:
            if ax in item:
                covered.add((ax, str(item[ax])))

    uncovered = set(pair_indices) - covered
    additional = []

    while uncovered and len(additional) < budget_remaining:
        # Only indices that carry at least one uncovered pair are useful.
        candidates = set()
        for pair in uncovered:
            candidates.update(pair_indices[pair])
        candidates -= taken
        if not candidates:
            break

        # Sorted so tie-breaking (lowest index wins) and the shuffle below do
        # not depend on set iteration order.
        candidate_list = sorted(candidates)
        if len(candidate_list) > 500:
            # Bound the per-round work when the pool is large.
            rng.shuffle(candidate_list)
            candidate_list = candidate_list[:500]

        best_idx = -1
        best_covers = set()
        for ci in candidate_list:
            item = feasible_set[ci]
            # One lookup per axis instead of scanning every uncovered pair.
            covers = {(ax, str(item[ax])) for ax in axes if ax in item} & uncovered
            if len(covers) > len(best_covers):
                best_idx = ci
                best_covers = covers

        if best_idx < 0:
            break

        additional.append(best_idx)
        taken.add(best_idx)
        uncovered -= best_covers

    return additional

View File

@@ -0,0 +1,124 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""Manifest writer for chosen_instances.json.
Each tier run emits a manifest recording the selected instances, their
parameters, the sampling method, seed, and metadata for reproducibility.
"""
import hashlib
import json
import subprocess
from pathlib import Path
def _instance_hash(params):
"""Compute a 16-hex-char fingerprint of tile+trait parameters."""
canonical = json.dumps(params, sort_keys=True, default=str)
return hashlib.sha256(canonical.encode()).hexdigest()[:16]
def _get_git_sha():
"""Get current git HEAD SHA, or 'unknown' if not in a git repo."""
try:
result = subprocess.run(
["git", "rev-parse", "HEAD"],
capture_output=True,
text=True,
timeout=5,
)
if result.returncode == 0:
return result.stdout.strip()
except (subprocess.TimeoutExpired, FileNotFoundError):
pass
return "unknown"
def write_manifest(
    instances, output_path, op, datatype, layout, gpu_target, seed, tier, sampler_method
):
    """Write the chosen-instances manifest for one op/datatype/layout.

    Args:
        instances: List of parameter dicts (flat tile+trait keys).
        output_path: Directory to write the manifest into; created if missing.
        op: Op name (e.g. "gemm_universal").
        datatype: Data type string (e.g. "fp16").
        layout: Layout string (e.g. "rcr").
        gpu_target: GPU target (e.g. "gfx942").
        seed: Integer seed used for sampling.
        tier: Tier name (e.g. "daily").
        sampler_method: Sampling method string (e.g. "sobol+lhs+maximin").

    Returns:
        Path to the written file, named
        ``chosen_instances_{op}_{datatype}_{layout}.json`` inside ``output_path``.
    """
    target_dir = Path(output_path)
    target_dir.mkdir(parents=True, exist_ok=True)

    # Looked up once and stamped on every entry and on the manifest header.
    head_sha = _get_git_sha()

    # Tile parameters first, then trait parameters; only keys actually
    # present on an instance are copied into its entry.
    recorded_keys = (
        "tile_m",
        "tile_n",
        "tile_k",
        "warp_m",
        "warp_n",
        "warp_k",
        "warp_tile_m",
        "warp_tile_n",
        "warp_tile_k",
        "pipeline",
        "epilogue",
        "scheduler",
        "pad_m",
        "pad_n",
        "pad_k",
        "persistent",
    )

    def _describe(params):
        entry = {
            "instance_hash": _instance_hash(params),
            "op": op,
            "dtype": datatype,
            "layout": layout,
            "arch": gpu_target,
        }
        entry.update((key, params[key]) for key in recorded_keys if key in params)
        entry.update(
            sampler_method=sampler_method,
            seed=seed,
            tier=tier,
            git_sha=head_sha,
        )
        return entry

    manifest_entries = [_describe(params) for params in instances]

    manifest = {
        "version": "1.0",
        "op": op,
        "dtype": datatype,
        "layout": layout,
        "arch": gpu_target,
        "seed": seed,
        "tier": tier,
        "sampler_method": sampler_method,
        "git_sha": head_sha,
        "instance_count": len(manifest_entries),
        "instances": manifest_entries,
    }

    manifest_file = target_dir / f"chosen_instances_{op}_{datatype}_{layout}.json"
    with open(manifest_file, "w") as handle:
        # default=str keeps non-JSON values (if any) serialisable.
        json.dump(manifest, handle, indent=2, default=str)
    return manifest_file

View File

@@ -0,0 +1,135 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""Maximin simulated annealing post-pass.
Improves minimum pairwise distance in the selected subset by swapping
points with unselected candidates. RFC specifies 200 iterations.
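Caches each selected point's nearest-neighbour distance to avoid an O(n^2)
recomputation per iteration. After a swap, only the swapped-in point and the
points whose nearest neighbour was evicted are rescanned; every other point
needs a single distance check against the new point.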
"""
import math
def _manhattan_distance(a, b):
return sum(abs(x - y) for x, y in zip(a, b))
def maximin_anneal(
    selected_indices, feasible_set, normalized_coords, iterations=200, rng=None
):
    """Try to raise the minimum pairwise L1 distance of a selection.

    Each iteration takes the selected point with the smallest
    nearest-neighbour distance, draws one random unselected candidate, and
    swaps them if the candidate's nearest-neighbour distance is larger.
    A worse swap may still be accepted with probability exp(delta / t),
    where t falls linearly from 1.0 towards 0.01 over the run.

    Args:
        selected_indices: List of indices into feasible_set; modified in place.
        feasible_set: Full list of parameter dicts; only its length is used.
        normalized_coords: One coordinate vector per feasible-set item.
        iterations: Number of annealing iterations (default 200).
        rng: random.Random instance; Random(42) is used when None.

    Returns:
        The same ``selected_indices`` list. It is returned untouched when it
        has fewer than 3 entries or when nothing is left unselected.
    """
    import random as random_mod

    if rng is None:
        rng = random_mod.Random(42)

    count = len(selected_indices)
    if count < 3:
        return selected_indices

    pool = list(set(range(len(feasible_set))) - set(selected_indices))
    if not pool:
        return selected_indices

    def _gap(p, q):
        # L1 distance between two coordinate vectors.
        return sum(abs(x - y) for x, y in zip(p, q))

    coords = [normalized_coords[i] for i in selected_indices]

    def _rescan(pos, fallback):
        # Nearest neighbour of coords[pos] among all other selected points.
        best, who = float("inf"), fallback
        for other in range(count):
            if other == pos:
                continue
            gap = _gap(coords[pos], coords[other])
            if gap < best:
                best, who = gap, other
        return best, who

    # nearest[i]: distance from point i to its closest selected neighbour;
    # partner[i]: position of that neighbour.
    nearest = [float("inf")] * count
    partner = [0] * count
    for a in range(count):
        for b in range(a + 1, count):
            gap = _gap(coords[a], coords[b])
            if gap < nearest[a]:
                nearest[a] = gap
                partner[a] = b
            if gap < nearest[b]:
                nearest[b] = gap
                partner[b] = a

    for step in range(iterations):
        temperature = 1.0 - (step / iterations) * 0.99

        # One end of the currently closest pair is the eviction target.
        worst = min(range(count), key=nearest.__getitem__)
        current_min = nearest[worst]
        evicted = selected_indices[worst]

        slot = rng.randint(0, len(pool) - 1)
        incoming = pool[slot]
        incoming_coord = normalized_coords[incoming]

        # How isolated the candidate would be if it replaced the target.
        incoming_min = float("inf")
        for k in range(count):
            if k != worst:
                gap = _gap(incoming_coord, coords[k])
                if gap < incoming_min:
                    incoming_min = gap

        gain = incoming_min - current_min
        accept = gain > 0
        if not accept and temperature > 0.001:
            try:
                # Threshold is computed before drawing so a failing exp()
                # does not consume a random number.
                threshold = math.exp(gain / temperature)
                accept = rng.random() < threshold
            except (OverflowError, ValueError):
                accept = False

        if not accept:
            continue

        pool[slot] = evicted
        selected_indices[worst] = incoming
        coords[worst] = incoming_coord

        for k in range(count):
            if k == worst:
                # The new point needs its cache entry built from scratch.
                nearest[k], partner[k] = _rescan(k, 0)
            elif partner[k] == worst:
                # This point's nearest neighbour was the one just evicted.
                nearest[k], partner[k] = _rescan(k, partner[k])
            else:
                # Cache still valid unless the new point is even closer.
                gap = _gap(coords[k], coords[worst])
                if gap < nearest[k]:
                    nearest[k] = gap
                    partner[k] = worst

    return selected_indices

View File

@@ -0,0 +1,10 @@
{
"version": "1.0",
"description": "Op weights for Daily Tier budget allocation (RFC section 3.4)",
"weights": {
"gemm_universal": 0.35,
"gemm_preshuffle": 0.30,
"gemm_multi_d": 0.20,
"grouped_gemm": 0.15
}
}

View File

@@ -0,0 +1,103 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""Main sampling orchestrator: Sobol -> LHS pad -> maximin refine.
Implements the Daily Tier sampling pipeline from the RFC:
1. Sobol base draw (90% of budget)
2. LHS padding for marginal axis coverage (10% reserve)
3. Maximin simulated annealing post-pass (200 iterations)
"""
import random
from sampling.sobol import SobolSequence
from sampling.lhs import lhs_pad
from sampling.maximin import maximin_anneal
from sampling.feasible_set import GEMM_AXES, normalize_axis_values, normalize_point
def sample_feasible_set(feasible_set, budget, seed, axes=None, maximin_iterations=200):
    """Select `budget` items from `feasible_set` using Sobol + LHS + maximin.

    Pipeline:
      1. Sobol base draw for ``budget - budget // 10`` slots.
      2. LHS padding: up to the slots still free are spent on points that
         cover axis values the base draw missed.
      3. Any slot LHS did not need is filled from the rest of the Sobol
         stream, then from a seeded shuffle of the unused indices.
      4. Maximin simulated annealing over the full selection.

    Args:
        feasible_set: List of parameter dicts (each with tile/trait keys).
        budget: Maximum number of items to select.
        seed: Integer seed for deterministic selection.
        axes: List of axis names (defaults to GEMM_AXES).
        maximin_iterations: SA iterations for maximin pass (default 200).

    Returns:
        Tuple of (selected_items, sampler_method_string, selected_indices).
        ``sampler_method_string`` is "empty" for an empty feasible set, "full"
        when the budget covers every item, and "sobol+lhs+maximin" otherwise.
        ``selected_indices`` are sorted positions into ``feasible_set``.
    """
    n = len(feasible_set)
    if axes is None:
        axes = GEMM_AXES

    # The empty case must be tested first: with n == 0 any non-negative
    # budget also satisfies ``budget >= n``.
    if n == 0:
        return [], "empty", []
    if budget >= n:
        return list(feasible_set), "full", list(range(n))
    if budget <= 0:
        return [], "sobol+lhs+maximin", []

    rng = random.Random(seed)
    selected_indices = []
    seen = set()

    # One 1-D Sobol stream, consumed lazily so the top-up after LHS carries
    # on from where the base draw stopped.
    sobol = SobolSequence(d=1, scramble=True, seed=seed)
    sobol_points = iter(sobol.generate(min(budget * 4, n * 2)))

    def _draw_sobol(limit):
        # Map Sobol points to feasible-set indices until `limit` are selected
        # or the stream runs dry; duplicates are skipped.
        while len(selected_indices) < limit:
            pt = next(sobol_points, None)
            if pt is None:
                return
            idx = min(int(pt[0] * n), n - 1)
            if idx not in seen:
                seen.add(idx)
                selected_indices.append(idx)

    # Phase 1: Sobol base selection, holding back 10% of the budget so the
    # LHS phase has slots to spend. Filling the whole budget here would leave
    # lhs_pad with a remaining budget of 0 and turn it into a no-op.
    lhs_reserve = budget // 10
    _draw_sobol(budget - lhs_reserve)

    # Phase 2: LHS padding — add points that cover uncovered axis values.
    available_axes = [ax for ax in axes if any(ax in item for item in feasible_set)]
    lhs_additions = lhs_pad(
        selected_indices,
        feasible_set,
        available_axes,
        max(0, budget - len(selected_indices)),
        rng,
    )
    for idx in lhs_additions:
        if idx not in seen and len(selected_indices) < budget:
            seen.add(idx)
            selected_indices.append(idx)

    # Top-up: slots LHS did not need go back to Sobol, then to a seeded
    # shuffle if the Sobol stream did not yield enough distinct indices.
    _draw_sobol(budget)
    if len(selected_indices) < budget:
        remaining = sorted(set(range(n)) - seen)
        rng.shuffle(remaining)
        for idx in remaining[: budget - len(selected_indices)]:
            seen.add(idx)
            selected_indices.append(idx)

    # Phase 3: Maximin simulated annealing
    meta = normalize_axis_values(feasible_set, available_axes)
    all_coords = [normalize_point(item, available_axes, meta) for item in feasible_set]
    selected_indices = maximin_anneal(
        selected_indices,
        feasible_set,
        all_coords,
        iterations=maximin_iterations,
        rng=rng,
    )

    # Sort by original index for deterministic output order
    selected_indices.sort()
    selected_items = [feasible_set[i] for i in selected_indices]
    return selected_items, "sobol+lhs+maximin", selected_indices

View File

@@ -0,0 +1,24 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
import hashlib
import datetime
def daily_seed(date=None, extra=""):
    """Return the low 32 bits of sha256("YYYY-MM-DD[:extra]").

    Args:
        date: Object with ``isoformat()`` (e.g. datetime.date); today's date
            from ``datetime.date.today()`` is used when None.
        extra: Optional suffix; when non-empty it is appended after a colon.

    Returns:
        Integer in [0, 0xFFFFFFFF].
    """
    day = datetime.date.today() if date is None else date
    pieces = [day.isoformat()]
    if extra:
        pieces.append(f"{extra}")
    material = ":".join(pieces)
    # The last four digest bytes are the low 32 bits of the digest read as
    # one big-endian integer.
    tail = hashlib.sha256(material.encode()).digest()[-4:]
    return int.from_bytes(tail, "big")
def make_seed(explicit_seed=None, gpu_target="", datatype="", layout=""):
    """Return the sampling seed for a run.

    An explicit seed (anything other than None, including 0) is returned
    unchanged. Otherwise the seed comes from daily_seed(), with the non-empty
    values among gpu_target, datatype and layout joined by ':' as extra
    material.
    """
    if explicit_seed is None:
        tags = [part for part in (gpu_target, datatype, layout) if part]
        return daily_seed(extra=":".join(tags))
    return explicit_seed

View File

@@ -0,0 +1,136 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""Scrambled Sobol sequence generator: scipy-backed when available, pure Python otherwise.
Uses Joe-Kuo direction numbers for up to 21 dimensions. At daily-tier
budgets (~2000-3000 points from ~10,000-25,000 feasible), pure Python
runs well under a second.
"""
import random
# Joe-Kuo direction numbers for dimensions 2-21.
# Each row: (degree_of_primitive_poly, coefficients_of_poly, [initial_direction_numbers])
# Dimension 1 uses the Van der Corput sequence (bit-reversal).
# Source: Joe & Kuo (2010), https://web.maths.unsw.edu.au/~fkuo/sobol/joe-kuo-old.1111
#
# Row r (0-based) is the entry read by _compute_direction_numbers(r + 1);
# that function handles index 0 (the Van der Corput dimension) itself, so
# it has no row here. The trailing "dim N" comments give the 1-based
# dimension of the row they sit on.
# NOTE(review): the values have not been cross-checked against the cited
# table in this file — confirm against the source before relying on them.
_JOE_KUO_PARAMS = [
    (1, 0, [1]),
    (2, 1, [1, 1]),
    (3, 1, [1, 3, 1]),
    (3, 2, [1, 3, 3]),
    (4, 1, [1, 1, 1, 1]),  # dim 6
    (4, 4, [1, 1, 3, 3]),
    (5, 2, [1, 3, 5, 13, 7]),  # dim 8
    (5, 4, [1, 1, 5, 5, 17]),
    (5, 7, [1, 1, 5, 5, 5]),  # dim 10
    (5, 11, [1, 1, 7, 11, 19]),
    (5, 13, [1, 1, 5, 1, 1]),
    (5, 14, [1, 1, 1, 3, 11]),
    (6, 1, [1, 3, 5, 5, 31, 45]),  # dim 14
    (6, 13, [1, 3, 3, 9, 7, 25]),
    (6, 16, [1, 3, 1, 15, 17, 63]),
    (7, 19, [1, 1, 5, 13, 11, 3, 15]),  # dim 17
    (7, 22, [1, 3, 1, 7, 3, 23, 79]),
    (7, 25, [1, 3, 7, 9, 31, 29, 17]),
    (7, 37, [1, 1, 3, 15, 29, 15, 41]),  # dim 20
    (7, 41, [1, 3, 1, 7, 3, 23, 79]),  # dim 21 (repeat of 18 for safety)
]
# Width, in bits, of each direction number and of the integer state that is
# divided by 2**_BITS to produce a coordinate.
_BITS = 32
def _compute_direction_numbers(dim_index):
    """Build the list of _BITS direction numbers for one dimension.

    Args:
        dim_index: 0-based dimension. Index 0 is the Van der Corput dimension
            (a single bit walking down from the top position); index k > 0
            uses row k - 1 of _JOE_KUO_PARAMS.

    Returns:
        List of _BITS integers.
    """
    top_bit = _BITS - 1
    if dim_index == 0:
        return [1 << (top_bit - bit) for bit in range(_BITS)]

    degree, poly, seeds = _JOE_KUO_PARAMS[dim_index - 1]
    directions = [0] * _BITS

    # The first `degree` entries come straight from the table, shifted up to
    # the top bits; a missing table value counts as 1.
    for bit in range(min(degree, _BITS)):
        initial = seeds[bit] if bit < len(seeds) else 1
        directions[bit] = initial << (top_bit - bit)

    # The rest follow the recurrence driven by the polynomial coefficient bits.
    for bit in range(degree, _BITS):
        base = directions[bit - degree]
        value = base ^ (base >> degree)
        for back in range(1, degree):
            if (poly >> (degree - 1 - back)) & 1:
                value ^= directions[bit - back]
        directions[bit] = value
    return directions
class SobolSequence:
    """Scrambled Sobol sequence generator.

    Uses scipy.stats.qmc.Sobol when scipy can be imported; otherwise falls
    back to the pure-Python generator defined in this module, where
    scrambling is a per-dimension random XOR shift of the integer state.
    """

    def __init__(self, d, scramble=True, seed=0):
        """Set up a generator for ``d`` dimensions.

        Args:
            d: Number of dimensions; more than 21 raises ValueError.
            scramble: Whether to scramble the sequence.
            seed: Seed for the scrambling.
        """
        self.d = d
        self.scramble = scramble
        self.seed = seed
        self._use_scipy = False
        if d > 21:
            raise ValueError(f"Sobol dimension {d} exceeds maximum 21")
        try:
            from scipy.stats.qmc import Sobol as ScipySobol

            self._scipy_sobol = ScipySobol(d=d, scramble=scramble, seed=seed)
            self._use_scipy = True
        except ImportError:
            self._init_pure_python()

    def _init_pure_python(self):
        """Prepare direction numbers and scramble shifts for the fallback path."""
        self._direction_numbers = [
            _compute_direction_numbers(dim) for dim in range(self.d)
        ]
        self._scramble_shifts = []
        if self.scramble:
            shift_rng = random.Random(self.seed)
            upper = (1 << _BITS) - 1
            self._scramble_shifts = [
                shift_rng.randint(0, upper) for _ in range(self.d)
            ]

    def generate(self, n):
        """Generate n points in [0, 1)^d, returned as a list of lists."""
        if self._use_scipy:
            import math

            # Draw the next power of two at or above n (at least 2 points
            # when n > 0), then keep the first n.
            exponent = 0
            if n > 0:
                exponent = max(1, math.ceil(math.log2(n)))
            drawn = self._scipy_sobol.random_base2(exponent)
            return drawn[:n].tolist()

        scale = 1 << _BITS
        state = [0] * self.d
        if self.scramble:
            for dim in range(self.d):
                state[dim] = self._scramble_shifts[dim]

        points = []
        for i in range(n):
            if i > 0:
                # XOR in the direction number chosen by the rightmost zero
                # bit of the previous point's index.
                c = _rightmost_zero_bit(i - 1)
                if c < _BITS:
                    for dim in range(self.d):
                        state[dim] ^= self._direction_numbers[dim][c]
            points.append([value / scale for value in state])
        return points
def _rightmost_zero_bit(n):
"""Find position of rightmost zero bit."""
pos = 0
while n & 1:
n >>= 1
pos += 1
return pos