mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327) ## Motivation The ML heuristic in dispatcher does not support the grouped-conv operator yet. In this PR, support for fwd, bwd-data, and bwd-weight grouped-conv kernels has been added. A tile_engine utility has also been added to compile and run any selected kernel configuration through dispatcher infrastructure. ## Technical Details 1. Tile engine utility is added to benchmark each shape with all the possible kernel+tile_size combinations here - [https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py](url) 2. New LGBM regressor models for grouped conv are added to the models directory. We have 3 separate models for fwd, bwd-data, and bwd-weights [https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models](url) 3. Implemented lazy GPU initialization (dispatcher/python) - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults - **Solution**: Mirror FMHA pattern - defer GPU initialization until first run() - **Changes**: - setup_multiple_grouped_conv_dispatchers() returns List[Path], not loaded libs - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL - Added _ensure_initialized() method for lazy GPU loading - GPU context created only on first run() call - **Benefit**: Parallel compilation now works without GPU conflicts 4. 
Addressed a few miscellaneous issues such as: - Fixed BF16->FP16 naming bug in the dispatcher wrapper - Added new tile sizes, and comp_v5 pipeline to the arch spec to expand the kernel selection - Added automatic padding support for unsupported shapes in dispatcher runner - Created a single source of truth between tile_engine and dispatcher about the architecture and tile_size details - Built a validation script to compare oracle_best vs ml_heuristic results ## Test Plan 1. Validated fwd, bwd-data, and bwd-weight kernels with both known and unseen data sets with up to 300 problems. 2. Ensured that test cases are added in both dispatcher and tile_engine to validate the heuristic. ## Test Result Results on unseen shapes validated on gfx950 #### Forward Pass Model - **Training Data**: 48,845 measurements across 1,372 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **93.05%** - Median Efficiency: **96.8%** - P10 Efficiency: **79.9%** #### Backward Data Gradient (bwd_data) Model - **Training Data**: 18,773 measurements across 891 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **93.8%** - Median Efficiency: **96.5%** - P10 Efficiency: **82.9%** #### Backward Weight Gradient (bwd_weight) Model - **Training Data**: 34,900 measurements across 1,508 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **96.1%** - Median Efficiency: **99.2%** - P10 Efficiency: **89.4%** ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
369 lines
10 KiB
Python
369 lines
10 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""
|
|
Test that the C++ extract_features() in ml_heuristic.hpp produces identical
|
|
values to the Python GemmUniversalFeatureEngine.extract().
|
|
|
|
This test uses ctypes to call the C++ feature extraction compiled into a
|
|
small shared library, then compares against Python output. If compilation
|
|
fails (no HIP/ROCm), it falls back to verifying the Python feature engine
|
|
against manually computed expected values for specific test cases.
|
|
"""
|
|
|
|
import math
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
from feature_engine import (
|
|
GemmUniversalFeatureEngine,
|
|
PIPELINE_MAP,
|
|
SCHEDULER_MAP,
|
|
EPILOGUE_MAP,
|
|
LAYOUT_MAP,
|
|
)
|
|
|
|
|
|
def _compute_features_manually(
    M,
    N,
    K,
    split_k,
    dtype,
    layout,
    tile_m,
    tile_n,
    tile_k,
    warp_m,
    warp_n,
    warp_k,
    warp_tile_m,
    warp_tile_n,
    warp_tile_k,
    pipeline,
    scheduler,
    epilogue,
    pad_m,
    pad_n,
    pad_k,
    persistent,
    hw,
):
    """Independently recompute the 72-entry feature vector.

    Mirrors GemmUniversalFeatureEngine.extract() formula by formula so the
    parity test can detect any drift in the Python engine (and, by proxy,
    the C++ extractor that implements the same logic).
    """

    def _log2_clamped(value):
        # log2 with the argument clamped to >= 1 so degenerate sizes stay finite.
        return math.log2(max(value, 1))

    def _tile_occupancy(dim, tile):
        # Used fraction of the final (possibly partial) tile; exact fits score 1.0.
        if tile <= 0:
            return 1.0
        rem = dim % tile
        return 1.0 if rem == 0 else rem / tile

    # Bytes per element for the supported dtypes; unknown dtypes fall back to 1.
    elem_bytes = {"fp8": 1.0, "fp16": 2.0, "bf16": 2.0, "fp32": 4.0}.get(dtype, 1.0)

    log2_m = _log2_clamped(M)
    log2_n = _log2_clamped(N)
    log2_k = _log2_clamped(K)
    log2_mnk = _log2_clamped(M * N * K)

    # Arithmetic intensity: 2*M*N*K FLOPs over the bytes moved for A, B and C.
    moved_bytes = (M * K + K * N + M * N) * elem_bytes
    arith_intensity = (2.0 * M * N * K) / max(moved_bytes, 1)

    # LDS footprint of one A tile plus one B tile.
    # NOTE(review): compv4 is capped at 32 KiB here (presumably it
    # double-buffers) — confirm against the feature engine if this changes.
    lds_bytes = (tile_m * tile_k + tile_n * tile_k) * elem_bytes
    lds_limit = 32768 if pipeline == "compv4" else hw["lds_capacity"]

    # Tile-grid extents along each problem dimension.
    tiles_m = math.ceil(M / max(tile_m, 1))
    tiles_n = math.ceil(N / max(tile_n, 1))
    tiles_k = math.ceil(K / max(tile_k, 1))

    # Problem-to-tile ratios and "problem smaller than one tile" indicators.
    ratio_m = M / max(tile_m, 1)
    ratio_n = N / max(tile_n, 1)
    ratio_k = K / max(tile_k, 1)
    m_under_tile = float(M < tile_m)
    n_under_tile = float(N < tile_n)
    k_under_tile = float(K < tile_k)
    any_under_tile = float((M < tile_m) or (N < tile_n) or (K < tile_k))

    # Per-axis padding flags: required, satisfied, and violated.
    pad_req_m = float(tile_m > 0 and M % tile_m != 0)
    pad_req_n = float(tile_n > 0 and N % tile_n != 0)
    pad_req_k = float(tile_k > 0 and K % tile_k != 0)
    pad_sat_m = float(pad_req_m and pad_m)
    pad_sat_n = float(pad_req_n and pad_n)
    pad_sat_k = float(pad_req_k and pad_k)
    pad_viol_m = float(pad_req_m and not pad_m)
    pad_viol_n = float(pad_req_n and not pad_n)
    pad_viol_k = float(pad_req_k and not pad_k)
    pad_viol_any = float(pad_viol_m or pad_viol_n or pad_viol_k)

    # Index comments must match the engine's feature ordering exactly.
    return [
        M,  # 0
        N,  # 1
        K,  # 2
        split_k,  # 3
        log2_m,  # 4
        log2_n,  # 5
        log2_k,  # 6
        log2_mnk,  # 7
        arith_intensity,  # 8
        M / max(N, 1),  # 9 (aspect_ratio_mn)
        M / max(K, 1),  # 10 (aspect_ratio_mk)
        N / max(K, 1),  # 11 (aspect_ratio_nk)
        LAYOUT_MAP.get(layout, 0),  # 12
        tile_m,  # 13
        tile_n,  # 14
        tile_k,  # 15
        warp_m,  # 16
        warp_n,  # 17
        warp_k,  # 18
        warp_tile_m,  # 19
        warp_tile_n,  # 20
        warp_tile_k,  # 21
        PIPELINE_MAP.get(pipeline, 0),  # 22
        SCHEDULER_MAP.get(scheduler, 0),  # 23
        EPILOGUE_MAP.get(epilogue, 0),  # 24
        float(pad_m),  # 25
        float(pad_n),  # 26
        float(pad_k),  # 27
        float(persistent),  # 28
        warp_m * warp_n * warp_k,  # 29 (num_warps)
        tile_m * tile_n * tile_k,  # 30 (tile_volume)
        tile_m * tile_n,  # 31 (tile_mn)
        lds_bytes,  # 32 (lds_usage_estimate)
        lds_bytes / max(lds_limit, 1),  # 33 (lds_usage_ratio)
        tiles_m,  # 34 (num_tiles_m)
        tiles_n,  # 35 (num_tiles_n)
        tiles_k,  # 36 (num_tiles_k)
        tiles_m * tiles_n,  # 37 (total_output_tiles)
        _tile_occupancy(M, tile_m),  # 38 (tile_eff_m)
        _tile_occupancy(N, tile_n),  # 39 (tile_eff_n)
        _tile_occupancy(K, tile_k),  # 40 (tile_eff_k)
        _tile_occupancy(M, tile_m)
        * _tile_occupancy(N, tile_n)
        * _tile_occupancy(K, tile_k),  # 41 (overall_tile_efficiency)
        tiles_m * tiles_n / max(hw["num_cus"], 1),  # 42 (cu_utilization)
        ratio_m,  # 43
        ratio_n,  # 44
        ratio_k,  # 45
        m_under_tile,  # 46
        n_under_tile,  # 47
        k_under_tile,  # 48
        any_under_tile,  # 49
        pad_req_m,  # 50
        pad_req_n,  # 51
        pad_req_k,  # 52
        pad_sat_m,  # 53
        pad_sat_n,  # 54
        pad_sat_k,  # 55
        pad_viol_m,  # 56
        pad_viol_n,  # 57
        pad_viol_k,  # 58
        pad_viol_any,  # 59
        hw["num_cus"],  # 60
        hw["simds_per_cu"],  # 61
        hw["num_cus"] * hw["simds_per_cu"],  # 62 (total_simds)
        hw["shader_engines"],  # 63
        hw["max_clock_mhz"],  # 64
        hw["max_waves_per_cu"],  # 65
        hw["wavefront_size"],  # 66
        hw["lds_capacity"],  # 67
        hw["l1_cache_kb"],  # 68
        hw["l2_cache_kb"],  # 69
        hw["l3_cache_kb"],  # 70
        hw["num_xcd"],  # 71
    ]
|
|
|
|
|
|
# Hand-picked (problem, kernel) pairs exercising distinct feature regimes:
#   case 0 — 1024^3 fp8 GEMM, rcr layout; all dims divide the tiles exactly
#            (no padding needed, no partial tiles).
#   case 1 — skinny M=1 fp8 GEMM where M < tile_m; all pad flags and the
#            persistent flag are set, compv4 pipeline (reduced LDS cap path).
#   case 2 — large fp16 GEMM, rrr layout, mem pipeline; exact tiling again
#            but with a different dtype/layout/pipeline encoding.
TEST_CASES = [
    {
        "problem": {
            "m": 1024,
            "n": 1024,
            "k": 1024,
            "split_k": 1,
            "dtype": "fp8",
            "layout": "rcr",
        },
        "kernel": {
            "tile_m": 128,
            "tile_n": 128,
            "tile_k": 64,
            "warp_m": 2,
            "warp_n": 2,
            "warp_k": 1,
            "warp_tile_m": 32,
            "warp_tile_n": 32,
            "warp_tile_k": 16,
            "pipeline": "compv3",
            "scheduler": "intrawave",
            "epilogue": "cshuffle",
            "pad_m": False,
            "pad_n": False,
            "pad_k": False,
            "persistent": False,
        },
    },
    {
        "problem": {
            "m": 1,
            "n": 4096,
            "k": 4096,
            "split_k": 1,
            "dtype": "fp8",
            "layout": "rcr",
        },
        "kernel": {
            "tile_m": 64,
            "tile_n": 64,
            "tile_k": 128,
            "warp_m": 1,
            "warp_n": 4,
            "warp_k": 1,
            "warp_tile_m": 16,
            "warp_tile_n": 16,
            "warp_tile_k": 128,
            "pipeline": "compv4",
            "scheduler": "interwave",
            "epilogue": "default",
            "pad_m": True,
            "pad_n": True,
            "pad_k": True,
            "persistent": True,
        },
    },
    {
        "problem": {
            "m": 20480,
            "n": 7168,
            "k": 256,
            "split_k": 1,
            "dtype": "fp16",
            "layout": "rrr",
        },
        "kernel": {
            "tile_m": 256,
            "tile_n": 256,
            "tile_k": 32,
            "warp_m": 4,
            "warp_n": 1,
            "warp_k": 1,
            "warp_tile_m": 32,
            "warp_tile_n": 32,
            "warp_tile_k": 16,
            "pipeline": "mem",
            "scheduler": "interwave",
            "epilogue": "cshuffle",
            "pad_m": False,
            "pad_n": False,
            "pad_k": False,
            "persistent": False,
        },
    },
]
|
|
|
|
# Hardware description fed both to GemmUniversalFeatureEngine(**HW) and to
# _compute_features_manually().  Cache sizes are in KiB per the key names;
# lds_capacity appears to be in bytes (65536 == 64 KiB).
# NOTE(review): values look like a gfx950-class GPU (256 CUs, 8 XCDs) —
# confirm before reusing elsewhere.
HW = {
    "num_cus": 256,
    "simds_per_cu": 4,
    "shader_engines": 32,
    "max_clock_mhz": 2400,
    "max_waves_per_cu": 32,
    "wavefront_size": 64,
    "lds_capacity": 65536,
    "l1_cache_kb": 32,
    "l2_cache_kb": 4096,
    "l3_cache_kb": 262144,
    "num_xcd": 8,
}
|
|
|
|
|
|
class TestFeatureParity:
    """Check the Python feature engine against an independent re-computation.

    The C++ extract_features() implements the same formulas, so agreement
    with the manual reference is the parity proxy used here.
    """

    @pytest.fixture
    def fe(self):
        # Fresh engine configured with the shared hardware description.
        return GemmUniversalFeatureEngine(**HW)

    @pytest.mark.parametrize("case_idx", range(len(TEST_CASES)))
    def test_python_matches_manual(self, fe, case_idx):
        case = TEST_CASES[case_idx]
        prob = case["problem"]
        kern = case["kernel"]

        py_features = fe.extract(prob, kern)

        # Keyword arguments make the problem/kernel field mapping explicit.
        manual = _compute_features_manually(
            M=prob["m"],
            N=prob["n"],
            K=prob["k"],
            split_k=prob["split_k"],
            dtype=prob["dtype"],
            layout=prob["layout"],
            tile_m=kern["tile_m"],
            tile_n=kern["tile_n"],
            tile_k=kern["tile_k"],
            warp_m=kern["warp_m"],
            warp_n=kern["warp_n"],
            warp_k=kern["warp_k"],
            warp_tile_m=kern["warp_tile_m"],
            warp_tile_n=kern["warp_tile_n"],
            warp_tile_k=kern["warp_tile_k"],
            pipeline=kern["pipeline"],
            scheduler=kern["scheduler"],
            epilogue=kern["epilogue"],
            pad_m=kern["pad_m"],
            pad_n=kern["pad_n"],
            pad_k=kern["pad_k"],
            persistent=kern["persistent"],
            hw=HW,
        )

        manual_arr = np.array(manual, dtype=np.float64)
        assert len(py_features) == len(manual_arr) == 72

        names = fe.get_feature_names()
        for i, (got, want) in enumerate(zip(py_features, manual_arr)):
            assert got == pytest.approx(want, rel=1e-10, abs=1e-15), (
                f"Feature {i} ({names[i]}): Python={got}, Manual={want}"
            )

    def test_feature_count(self, fe):
        # The engine must expose exactly the 72 names the models were trained on.
        assert len(fe.get_feature_names()) == 72

    def test_encoding_maps_match_cpp(self):
        """The C++ encode_* functions must use the same mapping as Python.

        PIPELINE_MAP was extended for grouped-conv suffix-aware kernels with
        ``basic_v1`` and ``compv6``; the original GEMM ids (0-4) are
        preserved so existing GEMM models keep loading unchanged.
        """
        expected_pipelines = {
            "compv3": 0,
            "compv4": 1,
            "compv5": 2,
            "mem": 3,
            "preshufflev2": 4,
            "basic_v1": 5,
            "compv6": 6,
        }
        assert PIPELINE_MAP == expected_pipelines
        assert SCHEDULER_MAP == {"intrawave": 0, "interwave": 1}
        assert EPILOGUE_MAP == {"default": 0, "cshuffle": 1}
        assert LAYOUT_MAP == {"rcr": 0, "rrr": 1, "crr": 2, "ccr": 3}
|
|
|
|
|
|
# Allow running this file directly; delegates to pytest's CLI entry point.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|