mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-11 08:48:45 +00:00
[rocm-libraries] ROCm/rocm-libraries#7311 (commit 79d8cae)
[CK Tile Engine] Daily tier sampling for tile engine GEMM (#7311) Summary - Replace uniform random instance sampling (random.shuffle) with scrambled Sobol + Latin Hypercube + maximin space-filling sampling, per the Tile Engine Benchmark Sampling RFC - Add op-weighted budget allocation via new TILE_ENGINE_SAMPLING_TIER=daily CMake knob that auto-distributes 8,000 instances across ops proportional to registered weights in op_weights.json - Emit chosen_instances.json manifests for reproducibility tracking - Consolidate 5 copies of sampling logic into single _apply_sampling() method on the base class Jenkinsfile changes Replace per-op -D *_MAX_INSTANCES=250 with single -D TILE_ENGINE_SAMPLING_TIER=daily in gfx942/gfx950/gfx1201 stages. Budget auto-distributes (8000 total per GPU target). --------- Co-authored-by: Claude Sonnet 4 <noreply@anthropic.com>
This commit is contained in:
committed by
GitHub
parent
275629fe34
commit
c31fc4df52
11
Jenkinsfile
vendored
11
Jenkinsfile
vendored
@@ -1870,9 +1870,10 @@ pipeline {
|
||||
-D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \
|
||||
-D GEMM_PRESHUFFLE_LAYOUT="rcr" \
|
||||
-D GROUPED_GEMM_DATATYPE="fp8;fp16" \
|
||||
-D GROUPED_GEMM_LAYOUT="rcr;rrr;crr;ccr" .. && \
|
||||
-D GROUPED_GEMM_LAYOUT="rcr;rrr;crr;ccr" \
|
||||
-D TILE_ENGINE_SAMPLING_TIER=daily .. && \
|
||||
ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all benchmark_gemm_streamk_all benchmark_grouped_gemm_all && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json gemm_universal_results.json && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
python3 ../tile_engine/ops/gemm/grouped_gemm/grouped_gemm_benchmark.py . --problem-sizes "1024,1024,1024" --group-counts 8 --warmup 5 --repeat 5 --verbose --json grouped_gemm_results.json """
|
||||
@@ -1901,7 +1902,8 @@ pipeline {
|
||||
-D GEMM_MULTI_D_DATATYPE="fp16" \
|
||||
-D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \
|
||||
-D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \
|
||||
-D GEMM_PRESHUFFLE_LAYOUT="rcr" .. && \
|
||||
-D GEMM_PRESHUFFLE_LAYOUT="rcr" \
|
||||
-D TILE_ENGINE_SAMPLING_TIER=daily .. && \
|
||||
ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
@@ -1927,7 +1929,8 @@ pipeline {
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx1201" \
|
||||
-D GEMM_UNIVERSAL_DATATYPE="fp16" \
|
||||
-D GEMM_UNIVERSAL_LAYOUT="rcr;rrr;crr;ccr" .. && \
|
||||
-D GEMM_UNIVERSAL_LAYOUT="rcr;rrr;crr;ccr" \
|
||||
-D TILE_ENGINE_SAMPLING_TIER=daily .. && \
|
||||
ninja -j${nthreads()} benchmark_gemm_universal_all && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """
|
||||
}
|
||||
|
||||
@@ -6,6 +6,11 @@ include_directories(BEFORE
|
||||
${CMAKE_CURRENT_LIST_DIR}/ops
|
||||
)
|
||||
|
||||
set(TILE_ENGINE_SAMPLING_TIER "" CACHE STRING
|
||||
"Sampling tier: 'daily' (8000 budget) or integer budget (empty = no cap)")
|
||||
set(TILE_ENGINE_SAMPLING_SEED "" CACHE STRING
|
||||
"Override sampling seed (empty = daily rotation)")
|
||||
|
||||
add_subdirectory(ops/fmha EXCLUDE_FROM_ALL)
|
||||
add_subdirectory(ops/gemm EXCLUDE_FROM_ALL)
|
||||
add_subdirectory(ops/gemm_streamk EXCLUDE_FROM_ALL)
|
||||
|
||||
@@ -1,7 +1,68 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# Budget allocation when TILE_ENGINE_SAMPLING_TIER is set
|
||||
if(NOT "${TILE_ENGINE_SAMPLING_TIER}" STREQUAL "")
|
||||
# Map tier name to budget
|
||||
if("${TILE_ENGINE_SAMPLING_TIER}" STREQUAL "daily")
|
||||
set(_te_budget 8000)
|
||||
elseif("${TILE_ENGINE_SAMPLING_TIER}" STREQUAL "weekly")
|
||||
set(_te_budget 0)
|
||||
else()
|
||||
set(_te_budget ${TILE_ENGINE_SAMPLING_TIER})
|
||||
endif()
|
||||
|
||||
if(_te_budget GREATER 0)
|
||||
# Detect active ops from their DATATYPE variables
|
||||
set(_active_ops "")
|
||||
foreach(_op gemm_universal gemm_multi_d gemm_preshuffle grouped_gemm)
|
||||
string(TOUPPER ${_op} _OP_UPPER)
|
||||
if(NOT "${${_OP_UPPER}_DATATYPE}" STREQUAL "")
|
||||
list(APPEND _active_ops ${_op})
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
if(_active_ops)
|
||||
string(REPLACE ";" "," _active_ops_csv "${_active_ops}")
|
||||
set(_alloc_dir "${CMAKE_CURRENT_BINARY_DIR}/sampling_alloc")
|
||||
file(MAKE_DIRECTORY ${_alloc_dir})
|
||||
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_COMMAND} -E env
|
||||
"PYTHONPATH=${CMAKE_CURRENT_LIST_DIR}/../../"
|
||||
${Python3_EXECUTABLE} -m sampling.allocate_budget
|
||||
--total-budget ${_te_budget}
|
||||
--active-ops "${_active_ops_csv}"
|
||||
--output-dir ${_alloc_dir}
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/../../
|
||||
RESULT_VARIABLE _alloc_ret
|
||||
OUTPUT_VARIABLE _alloc_output
|
||||
ERROR_VARIABLE _alloc_error
|
||||
)
|
||||
if(NOT _alloc_ret EQUAL 0)
|
||||
message(FATAL_ERROR "Budget allocation failed: ${_alloc_error}")
|
||||
endif()
|
||||
message(STATUS "Sampling budget allocation:\n${_alloc_output}")
|
||||
|
||||
# Read per-op allocations (only if not already overridden)
|
||||
foreach(_op gemm_universal gemm_multi_d gemm_preshuffle grouped_gemm)
|
||||
string(TOUPPER ${_op} _OP_UPPER)
|
||||
if("${${_OP_UPPER}_MAX_INSTANCES}" STREQUAL "")
|
||||
if(EXISTS "${_alloc_dir}/${_op}_budget.txt")
|
||||
file(READ "${_alloc_dir}/${_op}_budget.txt" _budget_val)
|
||||
string(STRIP "${_budget_val}" _budget_val)
|
||||
set(${_OP_UPPER}_MAX_INSTANCES ${_budget_val})
|
||||
message(STATUS " ${_op}: ${_budget_val} instances (from budget allocation)")
|
||||
endif()
|
||||
else()
|
||||
message(STATUS " ${_op}: ${${_OP_UPPER}_MAX_INSTANCES} instances (explicit override)")
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_subdirectory(gemm_universal EXCLUDE_FROM_ALL)
|
||||
add_subdirectory(gemm_multi_d EXCLUDE_FROM_ALL)
|
||||
add_subdirectory(gemm_preshuffle EXCLUDE_FROM_ALL)
|
||||
add_subdirectory(grouped_gemm EXCLUDE_FROM_ALL)
|
||||
add_subdirectory(grouped_gemm EXCLUDE_FROM_ALL)
|
||||
|
||||
@@ -43,6 +43,10 @@ class GemmKernelBuilder:
|
||||
datatype,
|
||||
layout,
|
||||
config_json=None,
|
||||
max_instances=None,
|
||||
seed=None,
|
||||
tier=None,
|
||||
manifest_path=None,
|
||||
):
|
||||
self.kernel_name_prefix = kernel_name_prefix
|
||||
self.working_path = Path(working_path)
|
||||
@@ -50,6 +54,10 @@ class GemmKernelBuilder:
|
||||
self.datatype = datatype
|
||||
self.layout = layout
|
||||
self.config_json = config_json
|
||||
self.max_instances = max_instances
|
||||
self.seed = seed
|
||||
self.tier = tier
|
||||
self.manifest_path = manifest_path
|
||||
|
||||
# Create working directory if it doesn't exist
|
||||
self.working_path.mkdir(parents=True, exist_ok=True)
|
||||
@@ -59,6 +67,74 @@ class GemmKernelBuilder:
|
||||
with open(config_json, "r") as f:
|
||||
self.config = json.load(f)
|
||||
|
||||
def _apply_sampling(self, kernel_list):
|
||||
"""Apply RFC Sobol+LHS+maximin sampling. Returns sampled subset."""
|
||||
if self.max_instances is None or len(kernel_list) <= self.max_instances:
|
||||
return kernel_list
|
||||
|
||||
import sys
|
||||
|
||||
sampling_parent = os.path.join(os.path.dirname(__file__), "..", "..")
|
||||
if sampling_parent not in sys.path:
|
||||
sys.path.insert(0, sampling_parent)
|
||||
|
||||
from sampling.sampler import sample_feasible_set
|
||||
from sampling.seed import make_seed
|
||||
from sampling.feasible_set import GEMM_AXES
|
||||
|
||||
effective_seed = make_seed(
|
||||
self.seed, self.gpu_target, self.datatype, self.layout
|
||||
)
|
||||
|
||||
flat_items = []
|
||||
for k in kernel_list:
|
||||
flat = dict(k["tile_config"])
|
||||
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent = k[
|
||||
"trait_combo"
|
||||
]
|
||||
flat.update(
|
||||
{
|
||||
"pipeline": pipeline,
|
||||
"epilogue": epilogue,
|
||||
"scheduler": scheduler,
|
||||
"pad_m": pad_m,
|
||||
"pad_n": pad_n,
|
||||
"pad_k": pad_k,
|
||||
"persistent": persistent,
|
||||
}
|
||||
)
|
||||
flat_items.append(flat)
|
||||
|
||||
selected, method, selected_indices = sample_feasible_set(
|
||||
flat_items,
|
||||
self.max_instances,
|
||||
effective_seed,
|
||||
GEMM_AXES,
|
||||
)
|
||||
|
||||
kernel_list = [kernel_list[i] for i in selected_indices]
|
||||
|
||||
if self.manifest_path:
|
||||
from sampling.manifest import write_manifest
|
||||
|
||||
write_manifest(
|
||||
selected,
|
||||
self.manifest_path,
|
||||
self.kernel_name_prefix,
|
||||
self.datatype,
|
||||
self.layout,
|
||||
self.gpu_target,
|
||||
effective_seed,
|
||||
self.tier or "daily",
|
||||
method,
|
||||
)
|
||||
|
||||
print(
|
||||
f"Sampled {len(kernel_list)} from feasible set "
|
||||
f"(budget={self.max_instances}, seed={effective_seed}, method={method})"
|
||||
)
|
||||
return kernel_list
|
||||
|
||||
def _list_kernels(self):
|
||||
"""Write kernel list to file for CMake to read (with comprehensive validation)"""
|
||||
# Get configurations using comprehensive validation
|
||||
@@ -96,6 +172,9 @@ class GemmKernelBuilder:
|
||||
}
|
||||
)
|
||||
|
||||
# Apply RFC-compliant sampling (Sobol + LHS + maximin)
|
||||
kernel_list = self._apply_sampling(kernel_list)
|
||||
|
||||
# Write kernel count
|
||||
with open(
|
||||
self.working_path / f"{self.kernel_name_prefix}_kernel_count.txt", "w"
|
||||
|
||||
@@ -5,6 +5,7 @@ set(GEMM_MULTI_D_DATATYPE "fp16" CACHE STRING "List of datatypes for GEMM Multi
|
||||
set(GEMM_MULTI_D_LAYOUT "rcrr;rrrr;crrr;ccrr" CACHE STRING "List of layout for GEMM Multi D (semicolon-separated)")
|
||||
set(GEMM_MULTI_D_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
|
||||
set(GEMM_MULTI_D_ELEMENTWISE_FUNCTION "mul" CACHE STRING "Elementwise function")
|
||||
set(GEMM_MULTI_D_MAX_INSTANCES "" CACHE STRING "Max kernel instances per (dtype, layout) combo (empty = no cap)")
|
||||
|
||||
option(ENABLE_CCACHE_GEMM_MULTI_D "Enable ccache for GEMM Multi D ops compilation" OFF)
|
||||
|
||||
@@ -175,6 +176,19 @@ function(build_individual_gemm_multi_d_targets datatype layout)
|
||||
--gpu_target ${GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL}
|
||||
--list_kernels ")
|
||||
|
||||
# Build optional args for instance builder
|
||||
set(extra_list_args "")
|
||||
if(NOT "${GEMM_MULTI_D_MAX_INSTANCES}" STREQUAL "")
|
||||
list(APPEND extra_list_args --max-instances ${GEMM_MULTI_D_MAX_INSTANCES})
|
||||
endif()
|
||||
if(NOT "${TILE_ENGINE_SAMPLING_TIER}" STREQUAL "")
|
||||
list(APPEND extra_list_args --tier ${TILE_ENGINE_SAMPLING_TIER})
|
||||
list(APPEND extra_list_args --manifest-path ${working_path})
|
||||
endif()
|
||||
if(NOT "${TILE_ENGINE_SAMPLING_SEED}" STREQUAL "")
|
||||
list(APPEND extra_list_args --seed ${TILE_ENGINE_SAMPLING_SEED})
|
||||
endif()
|
||||
|
||||
# First, just list the kernels (fast operation)
|
||||
message(VERBOSE " Listing kernel configurations...")
|
||||
execute_process(
|
||||
@@ -185,7 +199,8 @@ function(build_individual_gemm_multi_d_targets datatype layout)
|
||||
--elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION}
|
||||
--config_json ${json_blob}
|
||||
--gpu_target ${GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL}
|
||||
--list_kernels
|
||||
--list_kernels
|
||||
${extra_list_args}
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}
|
||||
RESULT_VARIABLE ret
|
||||
OUTPUT_VARIABLE list_output
|
||||
@@ -305,6 +320,20 @@ else()
|
||||
add_custom_target(benchmark_gemm_multi_d_${scheduler}_scheduler)
|
||||
endforeach()
|
||||
|
||||
# Divide MAX_INSTANCES budget across all active (dtype, layout) combos so that
|
||||
# sampling fires per-combo rather than being a single cap larger than any combo's
|
||||
# feasible set (which would make sampling a no-op for most combos).
|
||||
if(NOT "${GEMM_MULTI_D_MAX_INSTANCES}" STREQUAL "")
|
||||
list(LENGTH GEMM_MULTI_D_DATATYPE _gmd_n_dt)
|
||||
list(LENGTH GEMM_MULTI_D_LAYOUT _gmd_n_lay)
|
||||
math(EXPR _gmd_n_combos "${_gmd_n_dt} * ${_gmd_n_lay}")
|
||||
if(_gmd_n_combos GREATER 0)
|
||||
math(EXPR GEMM_MULTI_D_MAX_INSTANCES
|
||||
"${GEMM_MULTI_D_MAX_INSTANCES} / ${_gmd_n_combos}")
|
||||
message(STATUS " gemm_multi_d: per-combo budget = ${GEMM_MULTI_D_MAX_INSTANCES} (${_gmd_n_combos} combos)")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Build individual targets for each datatype/layout combination
|
||||
foreach(dt IN LISTS GEMM_MULTI_D_DATATYPE)
|
||||
foreach(l IN LISTS GEMM_MULTI_D_LAYOUT)
|
||||
|
||||
@@ -37,9 +37,22 @@ class GemmMultiDKernelBuilder(GemmKernelBuilder):
|
||||
layout,
|
||||
elementwise_function,
|
||||
config_json=None,
|
||||
max_instances=None,
|
||||
seed=None,
|
||||
tier=None,
|
||||
manifest_path=None,
|
||||
):
|
||||
super().__init__(
|
||||
kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
|
||||
kernel_name_prefix,
|
||||
working_path,
|
||||
gpu_target,
|
||||
datatype,
|
||||
layout,
|
||||
config_json,
|
||||
max_instances=max_instances,
|
||||
seed=seed,
|
||||
tier=tier,
|
||||
manifest_path=manifest_path,
|
||||
)
|
||||
self.elementwise_function = elementwise_function
|
||||
|
||||
@@ -71,6 +84,15 @@ class GemmMultiDKernelBuilder(GemmKernelBuilder):
|
||||
)
|
||||
)
|
||||
|
||||
# Apply RFC-compliant sampling (Sobol + LHS + maximin)
|
||||
if self.max_instances is not None and len(work_items) > self.max_instances:
|
||||
kernel_dicts = [
|
||||
{"tile_config": item[0], "trait_combo": item[1], "_work_item": item}
|
||||
for item in work_items
|
||||
]
|
||||
sampled = self._apply_sampling(kernel_dicts)
|
||||
work_items = [k["_work_item"] for k in sampled]
|
||||
|
||||
print(
|
||||
f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
|
||||
)
|
||||
@@ -223,6 +245,28 @@ def main():
|
||||
action="store_true",
|
||||
help="List kernel configurations without generating files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-instances",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Cap on number of kernel instances per (dtype, layout) combo",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="RNG seed for deterministic sampling; if omitted, derived from today's date",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tier",
|
||||
default=None,
|
||||
help="Sampling tier (daily/weekly)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--manifest-path",
|
||||
default=None,
|
||||
help="Directory for chosen_instances.json",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -271,6 +315,10 @@ def main():
|
||||
args.layout,
|
||||
args.elementwise_function,
|
||||
args.config_json,
|
||||
max_instances=args.max_instances,
|
||||
seed=args.seed,
|
||||
tier=args.tier,
|
||||
manifest_path=args.manifest_path,
|
||||
)
|
||||
|
||||
if args.list_kernels:
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
set(GEMM_PRESHUFFLE_DATATYPE "fp16;fp8;bf16;bf8" CACHE STRING "List of datatypes for GEMM Preshuffle (semicolon-separated)")
|
||||
set(GEMM_PRESHUFFLE_LAYOUT "rcr" CACHE STRING "List of layout for GEMM Preshuffle (semicolon-separated)")
|
||||
set(GEMM_PRESHUFFLE_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
|
||||
set(GEMM_PRESHUFFLE_MAX_INSTANCES "" CACHE STRING "Max kernel instances per (dtype, layout) combo (empty = no cap)")
|
||||
option(ENABLE_CCACHE_GEMM_PRESHUFFLE "Enable ccache for GEMM Preshuffle ops compilation" OFF)
|
||||
|
||||
# Store the directory path for use in functions
|
||||
@@ -163,6 +164,19 @@ function(build_individual_gemm_preshuffle_targets datatype layout)
|
||||
# Create working directory first
|
||||
file(MAKE_DIRECTORY ${working_path})
|
||||
|
||||
# Build optional args for instance builder
|
||||
set(extra_list_args "")
|
||||
if(NOT "${GEMM_PRESHUFFLE_MAX_INSTANCES}" STREQUAL "")
|
||||
list(APPEND extra_list_args --max-instances ${GEMM_PRESHUFFLE_MAX_INSTANCES})
|
||||
endif()
|
||||
if(NOT "${TILE_ENGINE_SAMPLING_TIER}" STREQUAL "")
|
||||
list(APPEND extra_list_args --tier ${TILE_ENGINE_SAMPLING_TIER})
|
||||
list(APPEND extra_list_args --manifest-path ${working_path})
|
||||
endif()
|
||||
if(NOT "${TILE_ENGINE_SAMPLING_SEED}" STREQUAL "")
|
||||
list(APPEND extra_list_args --seed ${TILE_ENGINE_SAMPLING_SEED})
|
||||
endif()
|
||||
|
||||
# First, just list the kernels (fast operation)
|
||||
message(VERBOSE " Listing kernel configurations...")
|
||||
message(VERBOSE " GPU Targets: ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}")
|
||||
@@ -174,6 +188,7 @@ function(build_individual_gemm_preshuffle_targets datatype layout)
|
||||
--layout ${layout}
|
||||
--config_json ${json_blob}
|
||||
--list_kernels
|
||||
${extra_list_args}
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}
|
||||
RESULT_VARIABLE ret
|
||||
OUTPUT_VARIABLE list_output
|
||||
@@ -293,6 +308,20 @@ else()
|
||||
add_custom_target(benchmark_gemm_preshuffle_${scheduler}_scheduler)
|
||||
endforeach()
|
||||
|
||||
# Divide MAX_INSTANCES budget across all active (dtype, layout) combos so that
|
||||
# sampling fires per-combo rather than being a single cap larger than any combo's
|
||||
# feasible set (which would make sampling a no-op for most combos).
|
||||
if(NOT "${GEMM_PRESHUFFLE_MAX_INSTANCES}" STREQUAL "")
|
||||
list(LENGTH GEMM_PRESHUFFLE_DATATYPE _gp_n_dt)
|
||||
list(LENGTH GEMM_PRESHUFFLE_LAYOUT _gp_n_lay)
|
||||
math(EXPR _gp_n_combos "${_gp_n_dt} * ${_gp_n_lay}")
|
||||
if(_gp_n_combos GREATER 0)
|
||||
math(EXPR GEMM_PRESHUFFLE_MAX_INSTANCES
|
||||
"${GEMM_PRESHUFFLE_MAX_INSTANCES} / ${_gp_n_combos}")
|
||||
message(STATUS " gemm_preshuffle: per-combo budget = ${GEMM_PRESHUFFLE_MAX_INSTANCES} (${_gp_n_combos} combos)")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Build individual targets for each datatype/layout combination
|
||||
foreach(dt IN LISTS GEMM_PRESHUFFLE_DATATYPE)
|
||||
foreach(l IN LISTS GEMM_PRESHUFFLE_LAYOUT)
|
||||
|
||||
@@ -36,9 +36,22 @@ class GemmPreshuffleKernelBuilder(GemmKernelBuilder):
|
||||
datatype,
|
||||
layout,
|
||||
config_json=None,
|
||||
max_instances=None,
|
||||
seed=None,
|
||||
tier=None,
|
||||
manifest_path=None,
|
||||
):
|
||||
super().__init__(
|
||||
kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
|
||||
kernel_name_prefix,
|
||||
working_path,
|
||||
gpu_target,
|
||||
datatype,
|
||||
layout,
|
||||
config_json,
|
||||
max_instances=max_instances,
|
||||
seed=seed,
|
||||
tier=tier,
|
||||
manifest_path=manifest_path,
|
||||
)
|
||||
|
||||
def _generate_all_individual(self, num_workers=None):
|
||||
@@ -68,6 +81,15 @@ class GemmPreshuffleKernelBuilder(GemmKernelBuilder):
|
||||
)
|
||||
)
|
||||
|
||||
# Apply RFC-compliant sampling (Sobol + LHS + maximin)
|
||||
if self.max_instances is not None and len(work_items) > self.max_instances:
|
||||
kernel_dicts = [
|
||||
{"tile_config": item[0], "trait_combo": item[1], "_work_item": item}
|
||||
for item in work_items
|
||||
]
|
||||
sampled = self._apply_sampling(kernel_dicts)
|
||||
work_items = [k["_work_item"] for k in sampled]
|
||||
|
||||
print(
|
||||
f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
|
||||
)
|
||||
@@ -212,6 +234,28 @@ def main():
|
||||
action="store_true",
|
||||
help="List kernel configurations without generating files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-instances",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Cap on number of kernel instances per (dtype, layout) combo",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="RNG seed for deterministic sampling; if omitted, derived from today's date",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tier",
|
||||
default=None,
|
||||
help="Sampling tier (daily/weekly)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--manifest-path",
|
||||
default=None,
|
||||
help="Directory for chosen_instances.json",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -239,6 +283,10 @@ def main():
|
||||
args.datatype,
|
||||
args.layout,
|
||||
args.config_json,
|
||||
max_instances=args.max_instances,
|
||||
seed=args.seed,
|
||||
tier=args.tier,
|
||||
manifest_path=args.manifest_path,
|
||||
)
|
||||
|
||||
if args.list_kernels:
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
set(GEMM_UNIVERSAL_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM Universal (semicolon-separated)")
|
||||
set(GEMM_UNIVERSAL_LAYOUT "rcr;rrr;crr;ccr" CACHE STRING "List of layout for GEMM Universal (semicolon-separated)")
|
||||
set(GEMM_UNIVERSAL_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
|
||||
set(GEMM_UNIVERSAL_MAX_INSTANCES "" CACHE STRING "Max kernel instances per (dtype, layout) combo (empty = no cap)")
|
||||
option(ENABLE_CCACHE_GEMM_UNIVERSAL "Enable ccache for GEMM Universal ops compilation" OFF)
|
||||
|
||||
# Store the directory path for use in functions
|
||||
@@ -171,6 +172,19 @@ function(build_individual_gemm_universal_targets datatype layout)
|
||||
--gpu_target ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}
|
||||
--list_kernels ")
|
||||
|
||||
# Build optional args for instance builder
|
||||
set(extra_list_args "")
|
||||
if(NOT "${GEMM_UNIVERSAL_MAX_INSTANCES}" STREQUAL "")
|
||||
list(APPEND extra_list_args --max-instances ${GEMM_UNIVERSAL_MAX_INSTANCES})
|
||||
endif()
|
||||
if(NOT "${TILE_ENGINE_SAMPLING_TIER}" STREQUAL "")
|
||||
list(APPEND extra_list_args --tier ${TILE_ENGINE_SAMPLING_TIER})
|
||||
list(APPEND extra_list_args --manifest-path ${working_path})
|
||||
endif()
|
||||
if(NOT "${TILE_ENGINE_SAMPLING_SEED}" STREQUAL "")
|
||||
list(APPEND extra_list_args --seed ${TILE_ENGINE_SAMPLING_SEED})
|
||||
endif()
|
||||
|
||||
# First, just list the kernels (fast operation)
|
||||
message(VERBOSE " Listing kernel configurations...")
|
||||
execute_process(
|
||||
@@ -180,7 +194,8 @@ function(build_individual_gemm_universal_targets datatype layout)
|
||||
--layout ${layout}
|
||||
--config_json ${json_blob}
|
||||
--gpu_target ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}
|
||||
--list_kernels
|
||||
--list_kernels
|
||||
${extra_list_args}
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}
|
||||
RESULT_VARIABLE ret
|
||||
OUTPUT_VARIABLE list_output
|
||||
@@ -300,6 +315,20 @@ else()
|
||||
add_custom_target(benchmark_gemm_universal_${scheduler}_scheduler)
|
||||
endforeach()
|
||||
|
||||
# Divide MAX_INSTANCES budget across all active (dtype, layout) combos so that
|
||||
# sampling fires per-combo rather than being a single cap larger than any combo's
|
||||
# feasible set (which would make sampling a no-op for most combos).
|
||||
if(NOT "${GEMM_UNIVERSAL_MAX_INSTANCES}" STREQUAL "")
|
||||
list(LENGTH GEMM_UNIVERSAL_DATATYPE _gu_n_dt)
|
||||
list(LENGTH GEMM_UNIVERSAL_LAYOUT _gu_n_lay)
|
||||
math(EXPR _gu_n_combos "${_gu_n_dt} * ${_gu_n_lay}")
|
||||
if(_gu_n_combos GREATER 0)
|
||||
math(EXPR GEMM_UNIVERSAL_MAX_INSTANCES
|
||||
"${GEMM_UNIVERSAL_MAX_INSTANCES} / ${_gu_n_combos}")
|
||||
message(STATUS " gemm_universal: per-combo budget = ${GEMM_UNIVERSAL_MAX_INSTANCES} (${_gu_n_combos} combos)")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Build individual targets for each datatype/layout combination
|
||||
foreach(dt IN LISTS GEMM_UNIVERSAL_DATATYPE)
|
||||
foreach(l IN LISTS GEMM_UNIVERSAL_LAYOUT)
|
||||
|
||||
@@ -36,9 +36,22 @@ class GemmUniversalKernelBuilder(GemmKernelBuilder):
|
||||
datatype,
|
||||
layout,
|
||||
config_json=None,
|
||||
max_instances=None,
|
||||
seed=None,
|
||||
tier=None,
|
||||
manifest_path=None,
|
||||
):
|
||||
super().__init__(
|
||||
kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
|
||||
kernel_name_prefix,
|
||||
working_path,
|
||||
gpu_target,
|
||||
datatype,
|
||||
layout,
|
||||
config_json,
|
||||
max_instances=max_instances,
|
||||
seed=seed,
|
||||
tier=tier,
|
||||
manifest_path=manifest_path,
|
||||
)
|
||||
|
||||
def _generate_all_individual(self, num_workers=None):
|
||||
@@ -67,6 +80,16 @@ class GemmUniversalKernelBuilder(GemmKernelBuilder):
|
||||
self.config_json,
|
||||
)
|
||||
)
|
||||
|
||||
# Apply RFC-compliant sampling (Sobol + LHS + maximin)
|
||||
if self.max_instances is not None and len(work_items) > self.max_instances:
|
||||
kernel_dicts = [
|
||||
{"tile_config": item[0], "trait_combo": item[1], "_work_item": item}
|
||||
for item in work_items
|
||||
]
|
||||
sampled = self._apply_sampling(kernel_dicts)
|
||||
work_items = [k["_work_item"] for k in sampled]
|
||||
|
||||
print(
|
||||
f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
|
||||
)
|
||||
@@ -210,6 +233,28 @@ def main():
|
||||
action="store_true",
|
||||
help="List kernel configurations without generating files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-instances",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Cap on number of kernel instances per (dtype, layout) combo",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="RNG seed for deterministic sampling; if omitted, derived from today's date",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tier",
|
||||
default=None,
|
||||
help="Sampling tier (daily/weekly)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--manifest-path",
|
||||
default=None,
|
||||
help="Directory for chosen_instances.json",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -236,6 +281,10 @@ def main():
|
||||
args.datatype,
|
||||
args.layout,
|
||||
args.config_json,
|
||||
max_instances=args.max_instances,
|
||||
seed=args.seed,
|
||||
tier=args.tier,
|
||||
manifest_path=args.manifest_path,
|
||||
)
|
||||
|
||||
if args.list_kernels:
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
set(GROUPED_GEMM_DATATYPE "fp16;fp8;bf16;bf8" CACHE STRING "List of datatypes for Grouped GEMM (semicolon-separated)")
|
||||
set(GROUPED_GEMM_LAYOUT "rcr;rrr;crr;ccr" CACHE STRING "List of layout for Grouped GEMM (semicolon-separated)")
|
||||
set(GROUPED_GEMM_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
|
||||
set(GROUPED_GEMM_MAX_INSTANCES "" CACHE STRING "Max kernel instances per (dtype, layout) combo (empty = no cap)")
|
||||
option(ENABLE_CCACHE_GROUPED_GEMM "Enable ccache for Grouped GEMM ops compilation" OFF)
|
||||
|
||||
# Store the directory path for use in functions
|
||||
@@ -171,6 +172,19 @@ function(build_individual_grouped_gemm_targets datatype layout)
|
||||
--gpu_target ${GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL}
|
||||
--list_kernels ")
|
||||
|
||||
# Build optional args for instance builder
|
||||
set(extra_list_args "")
|
||||
if(NOT "${GROUPED_GEMM_MAX_INSTANCES}" STREQUAL "")
|
||||
list(APPEND extra_list_args --max-instances ${GROUPED_GEMM_MAX_INSTANCES})
|
||||
endif()
|
||||
if(NOT "${TILE_ENGINE_SAMPLING_TIER}" STREQUAL "")
|
||||
list(APPEND extra_list_args --tier ${TILE_ENGINE_SAMPLING_TIER})
|
||||
list(APPEND extra_list_args --manifest-path ${working_path})
|
||||
endif()
|
||||
if(NOT "${TILE_ENGINE_SAMPLING_SEED}" STREQUAL "")
|
||||
list(APPEND extra_list_args --seed ${TILE_ENGINE_SAMPLING_SEED})
|
||||
endif()
|
||||
|
||||
# First, just list the kernels (fast operation)
|
||||
message(VERBOSE " Listing kernel configurations...")
|
||||
execute_process(
|
||||
@@ -181,6 +195,7 @@ function(build_individual_grouped_gemm_targets datatype layout)
|
||||
--config_json ${json_blob}
|
||||
--gpu_target ${GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL}
|
||||
--list_kernels
|
||||
${extra_list_args}
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}
|
||||
RESULT_VARIABLE ret
|
||||
OUTPUT_VARIABLE list_output
|
||||
|
||||
@@ -36,9 +36,22 @@ class GroupedGemmKernelBuilder(GemmKernelBuilder):
|
||||
datatype,
|
||||
layout,
|
||||
config_json=None,
|
||||
max_instances=None,
|
||||
seed=None,
|
||||
tier=None,
|
||||
manifest_path=None,
|
||||
):
|
||||
super().__init__(
|
||||
kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
|
||||
kernel_name_prefix,
|
||||
working_path,
|
||||
gpu_target,
|
||||
datatype,
|
||||
layout,
|
||||
config_json,
|
||||
max_instances=max_instances,
|
||||
seed=seed,
|
||||
tier=tier,
|
||||
manifest_path=manifest_path,
|
||||
)
|
||||
|
||||
def _generate_all_individual(self, num_workers=None):
|
||||
@@ -67,6 +80,16 @@ class GroupedGemmKernelBuilder(GemmKernelBuilder):
|
||||
self.config_json,
|
||||
)
|
||||
)
|
||||
|
||||
# Apply RFC-compliant sampling (Sobol + LHS + maximin)
|
||||
if self.max_instances is not None and len(work_items) > self.max_instances:
|
||||
kernel_dicts = [
|
||||
{"tile_config": item[0], "trait_combo": item[1], "_work_item": item}
|
||||
for item in work_items
|
||||
]
|
||||
sampled = self._apply_sampling(kernel_dicts)
|
||||
work_items = [k["_work_item"] for k in sampled]
|
||||
|
||||
print(
|
||||
f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
|
||||
)
|
||||
@@ -209,6 +232,28 @@ def main():
|
||||
action="store_true",
|
||||
help="List kernel configurations without generating files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-instances",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Cap on number of kernel instances per (dtype, layout) combo",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="RNG seed for deterministic sampling; if omitted, derived from today's date",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tier",
|
||||
default=None,
|
||||
help="Sampling tier (daily/weekly)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--manifest-path",
|
||||
default=None,
|
||||
help="Directory for chosen_instances.json",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -244,6 +289,10 @@ def main():
|
||||
args.datatype,
|
||||
args.layout,
|
||||
args.config_json,
|
||||
max_instances=args.max_instances,
|
||||
seed=args.seed,
|
||||
tier=args.tier,
|
||||
manifest_path=args.manifest_path,
|
||||
)
|
||||
|
||||
if args.list_kernels:
|
||||
|
||||
10
tile_engine/sampling/__init__.py
Normal file
10
tile_engine/sampling/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from sampling.sampler import sample_feasible_set as sample_feasible_set
|
||||
from sampling.seed import make_seed as make_seed
|
||||
from sampling.budget import allocate_budget as allocate_budget
|
||||
from sampling.budget import load_op_weights as load_op_weights
|
||||
from sampling.manifest import write_manifest as write_manifest
|
||||
from sampling.feasible_set import GEMM_AXES as GEMM_AXES
|
||||
from sampling.feasible_set import normalize_axis_values as normalize_axis_values
|
||||
102
tile_engine/sampling/allocate_budget.py
Normal file
102
tile_engine/sampling/allocate_budget.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""CLI entry point for budget allocation, called by CMake at configure time.
|
||||
|
||||
Usage:
|
||||
python -m sampling.allocate_budget \
|
||||
--total-budget 8000 \
|
||||
--active-ops "gemm_universal,gemm_multi_d,gemm_preshuffle,grouped_gemm" \
|
||||
--output-dir /build/sampling_alloc \
|
||||
[--weights-file /path/to/op_weights.json]
|
||||
|
||||
Writes per-op budget files (e.g. gemm_universal_budget.txt) containing a single integer.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _setup_path():
|
||||
_this_dir = Path(__file__).resolve().parent
|
||||
_tile_engine_dir = _this_dir.parent
|
||||
if str(_tile_engine_dir) not in sys.path:
|
||||
sys.path.insert(0, str(_tile_engine_dir))
|
||||
|
||||
|
||||
_setup_path()
|
||||
|
||||
from sampling.budget import allocate_budget # noqa: E402
|
||||
from sampling.budget import load_op_weights # noqa: E402
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Allocate instance budget across ops")
|
||||
parser.add_argument(
|
||||
"--total-budget",
|
||||
type=int,
|
||||
required=True,
|
||||
help="Total instance budget (e.g. 8000 for daily tier)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--active-ops",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Comma or semicolon-separated list of active op names",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Directory to write per-op budget files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--weights-file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to op_weights.json (default: built-in)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse active ops (support both comma and semicolon separators)
|
||||
active_ops = [
|
||||
op.strip() for op in args.active_ops.replace(";", ",").split(",") if op.strip()
|
||||
]
|
||||
|
||||
if not active_ops:
|
||||
print("ERROR: No active ops specified", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
weights = load_op_weights(args.weights_file)
|
||||
alloc = allocate_budget(args.total_budget, active_ops, weights, strict=True)
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Write per-op budget files
|
||||
for op, budget in alloc.items():
|
||||
budget_file = output_dir / f"{op}_budget.txt"
|
||||
budget_file.write_text(str(budget))
|
||||
|
||||
# Write combined allocation metadata
|
||||
meta = {
|
||||
"total_budget": args.total_budget,
|
||||
"active_ops": active_ops,
|
||||
"allocations": alloc,
|
||||
"weights_used": {op: weights.get(op, 0.0) for op in active_ops},
|
||||
}
|
||||
meta_file = output_dir / "sampling_allocations.json"
|
||||
with open(meta_file, "w") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
|
||||
# Print summary
|
||||
print(f"Budget allocation (total={args.total_budget}):")
|
||||
for op, budget in sorted(alloc.items()):
|
||||
print(f" {op}: {budget}")
|
||||
print(f" Sum: {sum(alloc.values())}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
78
tile_engine/sampling/budget.py
Normal file
78
tile_engine/sampling/budget.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""Op-weighted budget allocation.
|
||||
|
||||
Distributes a total instance budget across active ops proportional to
|
||||
their registered weights. Implements RFC section 3.4.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
_DEFAULT_WEIGHTS_FILE = Path(__file__).parent / "op_weights.json"
|
||||
|
||||
|
||||
def load_op_weights(weights_file=None):
|
||||
"""Load op weights from JSON file.
|
||||
|
||||
Returns:
|
||||
Dict mapping op name to weight (float).
|
||||
"""
|
||||
path = Path(weights_file) if weights_file else _DEFAULT_WEIGHTS_FILE
|
||||
with open(path) as f:
|
||||
data = json.load(f)
|
||||
return data["weights"]
|
||||
|
||||
|
||||
def allocate_budget(total_budget, active_ops, weights, strict=True):
|
||||
"""Distribute total_budget across active_ops proportional to weights.
|
||||
|
||||
Args:
|
||||
total_budget: Total instance budget (e.g. 8000).
|
||||
active_ops: List of active op names.
|
||||
weights: Dict mapping op name to weight.
|
||||
strict: If True, raise ValueError for unweighted active ops.
|
||||
|
||||
Returns:
|
||||
Dict mapping op name to allocated budget (int).
|
||||
Sum of allocations exactly equals total_budget.
|
||||
"""
|
||||
if not active_ops:
|
||||
return {}
|
||||
|
||||
# Check all active ops have weights
|
||||
missing = [op for op in active_ops if op not in weights]
|
||||
if missing and strict:
|
||||
raise ValueError(
|
||||
f"Active ops without registered weights: {missing}. "
|
||||
f"Add them to op_weights.json before running with sampling enabled."
|
||||
)
|
||||
|
||||
# Compute weight sum for active ops only
|
||||
active_weights = {op: weights.get(op, 0.0) for op in active_ops}
|
||||
total_weight = sum(active_weights.values())
|
||||
|
||||
if total_weight <= 0:
|
||||
# Equal distribution fallback
|
||||
per_op = total_budget // len(active_ops)
|
||||
alloc = {op: per_op for op in active_ops}
|
||||
remainder = total_budget - sum(alloc.values())
|
||||
for i, op in enumerate(active_ops):
|
||||
if i < remainder:
|
||||
alloc[op] += 1
|
||||
return alloc
|
||||
|
||||
# Proportional allocation with floor
|
||||
alloc = {}
|
||||
for op in active_ops:
|
||||
alloc[op] = int(total_budget * active_weights[op] / total_weight)
|
||||
|
||||
# Distribute remainder to highest-weight ops
|
||||
remainder = total_budget - sum(alloc.values())
|
||||
sorted_ops = sorted(active_ops, key=lambda op: active_weights[op], reverse=True)
|
||||
for i in range(remainder):
|
||||
alloc[sorted_ops[i % len(sorted_ops)]] += 1
|
||||
|
||||
return alloc
|
||||
82
tile_engine/sampling/feasible_set.py
Normal file
82
tile_engine/sampling/feasible_set.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
GEMM_AXES = [
|
||||
"tile_m",
|
||||
"tile_n",
|
||||
"tile_k",
|
||||
"warp_m",
|
||||
"warp_n",
|
||||
"warp_k",
|
||||
"warp_tile_m",
|
||||
"warp_tile_n",
|
||||
"warp_tile_k",
|
||||
"pipeline",
|
||||
"epilogue",
|
||||
"scheduler",
|
||||
"pad_m",
|
||||
"pad_n",
|
||||
"pad_k",
|
||||
"persistent",
|
||||
]
|
||||
|
||||
CATEGORICAL_AXES = {
|
||||
"pipeline",
|
||||
"epilogue",
|
||||
"scheduler",
|
||||
"pad_m",
|
||||
"pad_n",
|
||||
"pad_k",
|
||||
"persistent",
|
||||
}
|
||||
|
||||
|
||||
def normalize_axis_values(feasible_set, axes=None):
|
||||
"""Compute normalization metadata for each axis.
|
||||
|
||||
Returns dict mapping axis name to:
|
||||
- For numeric axes: {"type": "numeric", "min": v, "max": v, "range": v}
|
||||
- For categorical axes: {"type": "categorical", "values": sorted list, "map": value->index}
|
||||
"""
|
||||
if axes is None:
|
||||
axes = GEMM_AXES
|
||||
|
||||
meta = {}
|
||||
for ax in axes:
|
||||
values = [item[ax] for item in feasible_set if ax in item]
|
||||
if not values:
|
||||
continue
|
||||
|
||||
if ax in CATEGORICAL_AXES:
|
||||
unique = sorted(set(str(v) for v in values))
|
||||
meta[ax] = {
|
||||
"type": "categorical",
|
||||
"values": unique,
|
||||
"map": {v: i for i, v in enumerate(unique)},
|
||||
"count": len(unique),
|
||||
}
|
||||
else:
|
||||
num_values = [float(v) for v in values]
|
||||
mn, mx = min(num_values), max(num_values)
|
||||
meta[ax] = {
|
||||
"type": "numeric",
|
||||
"min": mn,
|
||||
"max": mx,
|
||||
"range": mx - mn if mx != mn else 1.0,
|
||||
}
|
||||
return meta
|
||||
|
||||
|
||||
def normalize_point(item, axes, meta):
|
||||
"""Normalize a single point to [0, 1] per axis."""
|
||||
coords = []
|
||||
for ax in axes:
|
||||
if ax not in meta or ax not in item:
|
||||
coords.append(0.0)
|
||||
continue
|
||||
m = meta[ax]
|
||||
if m["type"] == "numeric":
|
||||
coords.append((float(item[ax]) - m["min"]) / m["range"])
|
||||
else:
|
||||
coords.append(m["map"].get(str(item[ax]), 0) / max(m["count"] - 1, 1))
|
||||
return coords
|
||||
108
tile_engine/sampling/lhs.py
Normal file
108
tile_engine/sampling/lhs.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""Latin Hypercube Sampling padding for marginal coverage.
|
||||
|
||||
Ensures every distinct value on every parameter axis appears at least once
|
||||
in the selected sample, using a greedy set-cover heuristic.
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def lhs_pad(selected_indices, feasible_set, axes, budget_remaining, rng):
|
||||
"""Add indices to guarantee marginal coverage on all axes.
|
||||
|
||||
Args:
|
||||
selected_indices: Already-selected feasible-set indices.
|
||||
feasible_set: Full list of parameter dicts.
|
||||
axes: List of axis names to ensure coverage for.
|
||||
budget_remaining: Max additional indices to add.
|
||||
rng: random.Random instance for tie-breaking.
|
||||
|
||||
Returns:
|
||||
List of additional indices to include.
|
||||
"""
|
||||
if budget_remaining <= 0:
|
||||
return []
|
||||
|
||||
selected_set = set(selected_indices)
|
||||
|
||||
# Build per-axis coverage maps: axis -> value -> set of feasible indices with that value
|
||||
axis_value_indices = {}
|
||||
for ax in axes:
|
||||
value_map = defaultdict(set)
|
||||
for i, item in enumerate(feasible_set):
|
||||
if ax in item:
|
||||
value_map[str(item[ax])].add(i)
|
||||
axis_value_indices[ax] = value_map
|
||||
|
||||
# Find which axis values are already covered
|
||||
covered = {}
|
||||
for ax in axes:
|
||||
covered[ax] = set()
|
||||
for idx in selected_set:
|
||||
if ax in feasible_set[idx]:
|
||||
covered[ax].add(str(feasible_set[idx][ax]))
|
||||
|
||||
# Find uncovered axis values
|
||||
uncovered_pairs = [] # (axis, value) pairs not yet covered
|
||||
for ax in axes:
|
||||
for val in axis_value_indices[ax]:
|
||||
if val not in covered[ax]:
|
||||
uncovered_pairs.append((ax, val))
|
||||
|
||||
if not uncovered_pairs:
|
||||
return []
|
||||
|
||||
# Greedy set-cover: pick indices that cover the most uncovered (axis, value) pairs
|
||||
additional = []
|
||||
uncovered_set = set(range(len(uncovered_pairs)))
|
||||
|
||||
while uncovered_set and len(additional) < budget_remaining:
|
||||
# For each candidate index, count how many uncovered pairs it covers
|
||||
best_idx = -1
|
||||
best_count = 0
|
||||
best_covers = set()
|
||||
|
||||
# Build candidate pool: indices that appear in at least one uncovered pair's index set
|
||||
candidates = set()
|
||||
for ui in uncovered_set:
|
||||
ax, val = uncovered_pairs[ui]
|
||||
candidates.update(axis_value_indices[ax][val])
|
||||
candidates -= selected_set
|
||||
candidates -= set(additional)
|
||||
|
||||
if not candidates:
|
||||
break
|
||||
|
||||
# Sample a subset to avoid O(N*U) when both are large
|
||||
candidate_list = list(candidates)
|
||||
if len(candidate_list) > 500:
|
||||
rng.shuffle(candidate_list)
|
||||
candidate_list = candidate_list[:500]
|
||||
|
||||
for ci in candidate_list:
|
||||
item = feasible_set[ci]
|
||||
covers = set()
|
||||
for ui in uncovered_set:
|
||||
ax, val = uncovered_pairs[ui]
|
||||
if ax in item and str(item[ax]) == val:
|
||||
covers.add(ui)
|
||||
if len(covers) > best_count:
|
||||
best_count = len(covers)
|
||||
best_idx = ci
|
||||
best_covers = covers
|
||||
|
||||
if best_idx < 0:
|
||||
break
|
||||
|
||||
additional.append(best_idx)
|
||||
uncovered_set -= best_covers
|
||||
# Update covered sets
|
||||
item = feasible_set[best_idx]
|
||||
for ax in axes:
|
||||
if ax in item:
|
||||
covered[ax].add(str(item[ax]))
|
||||
|
||||
return additional
|
||||
124
tile_engine/sampling/manifest.py
Normal file
124
tile_engine/sampling/manifest.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""Manifest writer for chosen_instances.json.
|
||||
|
||||
Each tier run emits a manifest recording the selected instances, their
|
||||
parameters, the sampling method, seed, and metadata for reproducibility.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _instance_hash(params):
|
||||
"""Compute a 16-hex-char fingerprint of tile+trait parameters."""
|
||||
canonical = json.dumps(params, sort_keys=True, default=str)
|
||||
return hashlib.sha256(canonical.encode()).hexdigest()[:16]
|
||||
|
||||
|
||||
def _get_git_sha():
|
||||
"""Get current git HEAD SHA, or 'unknown' if not in a git repo."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "rev-parse", "HEAD"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
return result.stdout.strip()
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
pass
|
||||
return "unknown"
|
||||
|
||||
|
||||
def write_manifest(
|
||||
instances, output_path, op, datatype, layout, gpu_target, seed, tier, sampler_method
|
||||
):
|
||||
"""Write chosen_instances.json manifest.
|
||||
|
||||
Args:
|
||||
instances: List of parameter dicts (flat tile+trait keys).
|
||||
output_path: Directory to write the manifest into.
|
||||
op: Op name (e.g. "gemm_universal").
|
||||
datatype: Data type string (e.g. "fp16").
|
||||
layout: Layout string (e.g. "rcr").
|
||||
gpu_target: GPU target (e.g. "gfx942").
|
||||
seed: Integer seed used for sampling.
|
||||
tier: Tier name (e.g. "daily").
|
||||
sampler_method: Sampling method string (e.g. "sobol+lhs+maximin").
|
||||
|
||||
Returns:
|
||||
Path to the written manifest file.
|
||||
"""
|
||||
output_dir = Path(output_path)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
git_sha = _get_git_sha()
|
||||
|
||||
manifest_entries = []
|
||||
for params in instances:
|
||||
entry = {
|
||||
"instance_hash": _instance_hash(params),
|
||||
"op": op,
|
||||
"dtype": datatype,
|
||||
"layout": layout,
|
||||
"arch": gpu_target,
|
||||
}
|
||||
# Add tile parameters
|
||||
for key in [
|
||||
"tile_m",
|
||||
"tile_n",
|
||||
"tile_k",
|
||||
"warp_m",
|
||||
"warp_n",
|
||||
"warp_k",
|
||||
"warp_tile_m",
|
||||
"warp_tile_n",
|
||||
"warp_tile_k",
|
||||
]:
|
||||
if key in params:
|
||||
entry[key] = params[key]
|
||||
|
||||
# Add trait parameters
|
||||
for key in [
|
||||
"pipeline",
|
||||
"epilogue",
|
||||
"scheduler",
|
||||
"pad_m",
|
||||
"pad_n",
|
||||
"pad_k",
|
||||
"persistent",
|
||||
]:
|
||||
if key in params:
|
||||
entry[key] = params[key]
|
||||
|
||||
entry["sampler_method"] = sampler_method
|
||||
entry["seed"] = seed
|
||||
entry["tier"] = tier
|
||||
entry["git_sha"] = git_sha
|
||||
|
||||
manifest_entries.append(entry)
|
||||
|
||||
manifest = {
|
||||
"version": "1.0",
|
||||
"op": op,
|
||||
"dtype": datatype,
|
||||
"layout": layout,
|
||||
"arch": gpu_target,
|
||||
"seed": seed,
|
||||
"tier": tier,
|
||||
"sampler_method": sampler_method,
|
||||
"git_sha": git_sha,
|
||||
"instance_count": len(manifest_entries),
|
||||
"instances": manifest_entries,
|
||||
}
|
||||
|
||||
manifest_file = output_dir / f"chosen_instances_{op}_{datatype}_{layout}.json"
|
||||
with open(manifest_file, "w") as f:
|
||||
json.dump(manifest, f, indent=2, default=str)
|
||||
|
||||
return manifest_file
|
||||
135
tile_engine/sampling/maximin.py
Normal file
135
tile_engine/sampling/maximin.py
Normal file
@@ -0,0 +1,135 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""Maximin simulated annealing post-pass.
|
||||
|
||||
Improves minimum pairwise distance in the selected subset by swapping
|
||||
points with unselected candidates. RFC specifies 200 iterations.
|
||||
|
||||
Uses a cached pairwise distance matrix to avoid O(n^2) recomputation
|
||||
per iteration. Updates are O(n) per swap.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
|
||||
def _manhattan_distance(a, b):
|
||||
return sum(abs(x - y) for x, y in zip(a, b))
|
||||
|
||||
|
||||
def maximin_anneal(
|
||||
selected_indices, feasible_set, normalized_coords, iterations=200, rng=None
|
||||
):
|
||||
"""Improve minimum pairwise distance via simulated annealing.
|
||||
|
||||
Args:
|
||||
selected_indices: List of indices into feasible_set (will be modified in-place).
|
||||
feasible_set: Full list of parameter dicts (not modified).
|
||||
normalized_coords: List of normalized coordinate vectors, one per feasible-set item.
|
||||
iterations: Number of SA iterations (default 200 per RFC).
|
||||
rng: random.Random instance.
|
||||
|
||||
Returns:
|
||||
Modified selected_indices list.
|
||||
"""
|
||||
import random as random_mod
|
||||
|
||||
if rng is None:
|
||||
rng = random_mod.Random(42)
|
||||
|
||||
n = len(selected_indices)
|
||||
if n < 3:
|
||||
return selected_indices
|
||||
|
||||
all_indices = set(range(len(feasible_set)))
|
||||
selected_set = set(selected_indices)
|
||||
unselected = list(all_indices - selected_set)
|
||||
|
||||
if not unselected:
|
||||
return selected_indices
|
||||
|
||||
sel_coords = [normalized_coords[i] for i in selected_indices]
|
||||
|
||||
# Build per-point minimum distance cache: for each point, store its
|
||||
# minimum distance to any other selected point and the index of that neighbor
|
||||
min_dists = [float("inf")] * n
|
||||
min_neighbors = [0] * n
|
||||
|
||||
for i in range(n):
|
||||
for j in range(i + 1, n):
|
||||
d = _manhattan_distance(sel_coords[i], sel_coords[j])
|
||||
if d < min_dists[i]:
|
||||
min_dists[i] = d
|
||||
min_neighbors[i] = j
|
||||
if d < min_dists[j]:
|
||||
min_dists[j] = d
|
||||
min_neighbors[j] = i
|
||||
|
||||
for iteration in range(iterations):
|
||||
t = 1.0 - (iteration / iterations) * 0.99
|
||||
|
||||
# Find the point with the globally smallest min_dist (half of closest pair)
|
||||
victim_pos = min(range(n), key=lambda i: min_dists[i])
|
||||
old_min_dist = min_dists[victim_pos]
|
||||
victim_idx = selected_indices[victim_pos]
|
||||
|
||||
# Pick a random unselected candidate
|
||||
candidate_pos = rng.randint(0, len(unselected) - 1)
|
||||
candidate_idx = unselected[candidate_pos]
|
||||
candidate_coord = normalized_coords[candidate_idx]
|
||||
|
||||
# Compute candidate's min distance to all other selected points
|
||||
new_cand_min = float("inf")
|
||||
for k in range(n):
|
||||
if k == victim_pos:
|
||||
continue
|
||||
d = _manhattan_distance(candidate_coord, sel_coords[k])
|
||||
if d < new_cand_min:
|
||||
new_cand_min = d
|
||||
|
||||
delta = new_cand_min - old_min_dist
|
||||
accept = delta > 0
|
||||
if not accept and t > 0.001:
|
||||
try:
|
||||
prob = math.exp(delta / t)
|
||||
accept = rng.random() < prob
|
||||
except (OverflowError, ValueError):
|
||||
accept = False
|
||||
|
||||
if accept:
|
||||
unselected[candidate_pos] = victim_idx
|
||||
selected_indices[victim_pos] = candidate_idx
|
||||
sel_coords[victim_pos] = candidate_coord
|
||||
|
||||
# Recompute min_dists for the swapped position and any point
|
||||
# whose nearest neighbor was the victim
|
||||
for k in range(n):
|
||||
if k == victim_pos:
|
||||
# Recompute for the new point
|
||||
min_dists[k] = float("inf")
|
||||
min_neighbors[k] = 0
|
||||
for j in range(n):
|
||||
if j == k:
|
||||
continue
|
||||
d = _manhattan_distance(sel_coords[k], sel_coords[j])
|
||||
if d < min_dists[k]:
|
||||
min_dists[k] = d
|
||||
min_neighbors[k] = j
|
||||
elif min_neighbors[k] == victim_pos:
|
||||
# Nearest neighbor was replaced — full recompute for this point
|
||||
min_dists[k] = float("inf")
|
||||
for j in range(n):
|
||||
if j == k:
|
||||
continue
|
||||
d = _manhattan_distance(sel_coords[k], sel_coords[j])
|
||||
if d < min_dists[k]:
|
||||
min_dists[k] = d
|
||||
min_neighbors[k] = j
|
||||
else:
|
||||
# Check if the new point is closer than current minimum
|
||||
d = _manhattan_distance(sel_coords[k], sel_coords[victim_pos])
|
||||
if d < min_dists[k]:
|
||||
min_dists[k] = d
|
||||
min_neighbors[k] = victim_pos
|
||||
|
||||
return selected_indices
|
||||
10
tile_engine/sampling/op_weights.json
Normal file
10
tile_engine/sampling/op_weights.json
Normal file
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"version": "1.0",
|
||||
"description": "Op weights for Daily Tier budget allocation (RFC section 3.4)",
|
||||
"weights": {
|
||||
"gemm_universal": 0.35,
|
||||
"gemm_preshuffle": 0.30,
|
||||
"gemm_multi_d": 0.20,
|
||||
"grouped_gemm": 0.15
|
||||
}
|
||||
}
|
||||
103
tile_engine/sampling/sampler.py
Normal file
103
tile_engine/sampling/sampler.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""Main sampling orchestrator: Sobol -> LHS pad -> maximin refine.
|
||||
|
||||
Implements the Daily Tier sampling pipeline from the RFC:
|
||||
1. Sobol base draw (90% of budget)
|
||||
2. LHS padding for marginal axis coverage (10% reserve)
|
||||
3. Maximin simulated annealing post-pass (200 iterations)
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
from sampling.sobol import SobolSequence
|
||||
from sampling.lhs import lhs_pad
|
||||
from sampling.maximin import maximin_anneal
|
||||
from sampling.feasible_set import GEMM_AXES, normalize_axis_values, normalize_point
|
||||
|
||||
|
||||
def sample_feasible_set(feasible_set, budget, seed, axes=None, maximin_iterations=200):
|
||||
"""Select `budget` items from `feasible_set` using Sobol + LHS + maximin.
|
||||
|
||||
Args:
|
||||
feasible_set: List of parameter dicts (each with tile/trait keys).
|
||||
budget: Maximum number of items to select.
|
||||
seed: Integer seed for deterministic selection.
|
||||
axes: List of axis names (defaults to GEMM_AXES).
|
||||
maximin_iterations: SA iterations for maximin pass (default 200).
|
||||
|
||||
Returns:
|
||||
Tuple of (selected_items, sampler_method_string).
|
||||
"""
|
||||
n = len(feasible_set)
|
||||
if axes is None:
|
||||
axes = GEMM_AXES
|
||||
|
||||
if budget >= n:
|
||||
return list(feasible_set), "full", list(range(n))
|
||||
|
||||
if n == 0:
|
||||
return [], "empty", []
|
||||
|
||||
rng = random.Random(seed)
|
||||
|
||||
# Phase 1: Sobol base selection (fill as much as possible from Sobol)
|
||||
sobol = SobolSequence(d=1, scramble=True, seed=seed)
|
||||
raw_points = sobol.generate(min(budget * 4, n * 2))
|
||||
|
||||
selected_indices = []
|
||||
seen = set()
|
||||
for pt in raw_points:
|
||||
idx = min(int(pt[0] * n), n - 1)
|
||||
if idx not in seen:
|
||||
seen.add(idx)
|
||||
selected_indices.append(idx)
|
||||
if len(selected_indices) >= budget:
|
||||
break
|
||||
|
||||
# If Sobol didn't produce enough unique points, fill with RNG
|
||||
if len(selected_indices) < budget:
|
||||
remaining = list(set(range(n)) - seen)
|
||||
rng.shuffle(remaining)
|
||||
for idx in remaining:
|
||||
if len(selected_indices) >= budget:
|
||||
break
|
||||
selected_indices.append(idx)
|
||||
seen.add(idx)
|
||||
|
||||
# Phase 2: LHS padding — swap in points that cover uncovered axis values
|
||||
available_axes = [ax for ax in axes if any(ax in item for item in feasible_set)]
|
||||
lhs_additions = lhs_pad(
|
||||
selected_indices,
|
||||
feasible_set,
|
||||
available_axes,
|
||||
max(0, budget - len(selected_indices)),
|
||||
rng,
|
||||
)
|
||||
for idx in lhs_additions:
|
||||
if idx not in seen:
|
||||
seen.add(idx)
|
||||
selected_indices.append(idx)
|
||||
|
||||
# Trim to budget
|
||||
if len(selected_indices) > budget:
|
||||
selected_indices = selected_indices[:budget]
|
||||
|
||||
# Phase 3: Maximin simulated annealing
|
||||
meta = normalize_axis_values(feasible_set, available_axes)
|
||||
all_coords = [normalize_point(item, available_axes, meta) for item in feasible_set]
|
||||
|
||||
selected_indices = maximin_anneal(
|
||||
selected_indices,
|
||||
feasible_set,
|
||||
all_coords,
|
||||
iterations=maximin_iterations,
|
||||
rng=rng,
|
||||
)
|
||||
|
||||
# Sort by original index for deterministic output order
|
||||
selected_indices.sort()
|
||||
|
||||
selected_items = [feasible_set[i] for i in selected_indices]
|
||||
return selected_items, "sobol+lhs+maximin", selected_indices
|
||||
24
tile_engine/sampling/seed.py
Normal file
24
tile_engine/sampling/seed.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import hashlib
|
||||
import datetime
|
||||
|
||||
|
||||
def daily_seed(date=None, extra=""):
|
||||
"""Return sha256(YYYY-MM-DD[:extra]) & 0xFFFFFFFF."""
|
||||
if date is None:
|
||||
date = datetime.date.today()
|
||||
material = date.isoformat()
|
||||
if extra:
|
||||
material += f":{extra}"
|
||||
return int(hashlib.sha256(material.encode()).hexdigest(), 16) & 0xFFFFFFFF
|
||||
|
||||
|
||||
def make_seed(explicit_seed=None, gpu_target="", datatype="", layout=""):
|
||||
"""If explicit_seed is given, return it. Otherwise compute daily seed
|
||||
with gpu_target:datatype:layout as extra material."""
|
||||
if explicit_seed is not None:
|
||||
return explicit_seed
|
||||
extra = ":".join(filter(None, [gpu_target, datatype, layout]))
|
||||
return daily_seed(extra=extra)
|
||||
136
tile_engine/sampling/sobol.py
Normal file
136
tile_engine/sampling/sobol.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""Pure-Python scrambled Sobol sequence generator with scipy fallback.
|
||||
|
||||
Uses Joe-Kuo direction numbers for up to 21 dimensions. At daily-tier
|
||||
budgets (~2000-3000 points from ~10,000-25,000 feasible), pure Python
|
||||
runs well under a second.
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
# Joe-Kuo direction numbers for dimensions 2-21.
|
||||
# Each row: (degree_of_primitive_poly, coefficients_of_poly, [initial_direction_numbers])
|
||||
# Dimension 1 uses the Van der Corput sequence (bit-reversal).
|
||||
# Source: Joe & Kuo (2010), https://web.maths.unsw.edu.au/~fkuo/sobol/joe-kuo-old.1111
|
||||
_JOE_KUO_PARAMS = [
|
||||
(1, 0, [1]),
|
||||
(2, 1, [1, 1]),
|
||||
(3, 1, [1, 3, 1]),
|
||||
(3, 2, [1, 3, 3]),
|
||||
(4, 1, [1, 1, 1, 1]), # dim 6
|
||||
(4, 4, [1, 1, 3, 3]),
|
||||
(5, 2, [1, 3, 5, 13, 7]), # dim 8
|
||||
(5, 4, [1, 1, 5, 5, 17]),
|
||||
(5, 7, [1, 1, 5, 5, 5]), # dim 10
|
||||
(5, 11, [1, 1, 7, 11, 19]),
|
||||
(5, 13, [1, 1, 5, 1, 1]),
|
||||
(5, 14, [1, 1, 1, 3, 11]),
|
||||
(6, 1, [1, 3, 5, 5, 31, 45]), # dim 14
|
||||
(6, 13, [1, 3, 3, 9, 7, 25]),
|
||||
(6, 16, [1, 3, 1, 15, 17, 63]),
|
||||
(7, 19, [1, 1, 5, 13, 11, 3, 15]), # dim 17
|
||||
(7, 22, [1, 3, 1, 7, 3, 23, 79]),
|
||||
(7, 25, [1, 3, 7, 9, 31, 29, 17]),
|
||||
(7, 37, [1, 1, 3, 15, 29, 15, 41]), # dim 20
|
||||
(7, 41, [1, 3, 1, 7, 3, 23, 79]), # dim 21 (repeat of 18 for safety)
|
||||
]
|
||||
|
||||
_BITS = 32
|
||||
|
||||
|
||||
def _compute_direction_numbers(dim_index):
|
||||
"""Compute 32-bit direction numbers for a given dimension (0-indexed, dim 0 = Van der Corput)."""
|
||||
if dim_index == 0:
|
||||
return [1 << (_BITS - 1 - i) for i in range(_BITS)]
|
||||
|
||||
params = _JOE_KUO_PARAMS[dim_index - 1]
|
||||
s = params[0]
|
||||
a = params[1]
|
||||
m_init = params[2]
|
||||
|
||||
v = [0] * _BITS
|
||||
for i in range(min(s, _BITS)):
|
||||
if i < len(m_init):
|
||||
v[i] = m_init[i] << (_BITS - 1 - i)
|
||||
else:
|
||||
v[i] = 1 << (_BITS - 1 - i)
|
||||
|
||||
for i in range(s, _BITS):
|
||||
v[i] = v[i - s] ^ (v[i - s] >> s)
|
||||
for j in range(1, s):
|
||||
if (a >> (s - 1 - j)) & 1:
|
||||
v[i] ^= v[i - j]
|
||||
|
||||
return v
|
||||
|
||||
|
||||
class SobolSequence:
|
||||
"""Scrambled Sobol sequence generator.
|
||||
|
||||
Falls back to scipy.stats.qmc.Sobol when available for better scrambling quality.
|
||||
"""
|
||||
|
||||
def __init__(self, d, scramble=True, seed=0):
|
||||
self.d = d
|
||||
self.scramble = scramble
|
||||
self.seed = seed
|
||||
self._use_scipy = False
|
||||
|
||||
if d > 21:
|
||||
raise ValueError(f"Sobol dimension {d} exceeds maximum 21")
|
||||
|
||||
try:
|
||||
from scipy.stats.qmc import Sobol as ScipySobol
|
||||
|
||||
self._scipy_sobol = ScipySobol(d=d, scramble=scramble, seed=seed)
|
||||
self._use_scipy = True
|
||||
except ImportError:
|
||||
self._direction_numbers = [_compute_direction_numbers(i) for i in range(d)]
|
||||
self._scramble_shifts = []
|
||||
if scramble:
|
||||
rng = random.Random(seed)
|
||||
self._scramble_shifts = [
|
||||
rng.randint(0, (1 << _BITS) - 1) for _ in range(d)
|
||||
]
|
||||
|
||||
def generate(self, n):
|
||||
"""Generate n points in [0, 1)^d."""
|
||||
if self._use_scipy:
|
||||
import math
|
||||
|
||||
m = max(1, math.ceil(math.log2(n))) if n > 0 else 0
|
||||
points = self._scipy_sobol.random_base2(m)
|
||||
return points[:n].tolist()
|
||||
|
||||
points = []
|
||||
x = [0] * self.d
|
||||
|
||||
for i in range(n):
|
||||
if i == 0:
|
||||
point = [0.0] * self.d
|
||||
if self.scramble:
|
||||
for dim in range(self.d):
|
||||
x[dim] = self._scramble_shifts[dim]
|
||||
point[dim] = x[dim] / (1 << _BITS)
|
||||
points.append(point)
|
||||
else:
|
||||
c = _rightmost_zero_bit(i - 1)
|
||||
point = [0.0] * self.d
|
||||
for dim in range(self.d):
|
||||
if c < _BITS:
|
||||
x[dim] ^= self._direction_numbers[dim][c]
|
||||
point[dim] = x[dim] / (1 << _BITS)
|
||||
points.append(point)
|
||||
|
||||
return points
|
||||
|
||||
|
||||
def _rightmost_zero_bit(n):
|
||||
"""Find position of rightmost zero bit."""
|
||||
pos = 0
|
||||
while n & 1:
|
||||
n >>= 1
|
||||
pos += 1
|
||||
return pos
|
||||
Reference in New Issue
Block a user