# [CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327)

## Motivation

The ML heuristic in the dispatcher does not yet support the grouped-conv operator. This PR adds support for fwd, bwd-data, and bwd-weight grouped-conv kernels. A tile_engine utility has also been added to compile and run any selected kernel configuration through the dispatcher infrastructure.

## Technical Details

1. A tile_engine utility is added to benchmark each shape with all possible kernel + tile_size combinations: https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py
2. New LGBM regressor models for grouped conv are added to the models directory, with three separate models for fwd, bwd-data, and bwd-weight (a hypothetical usage sketch follows below): https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models
3. Implemented lazy GPU initialization (dispatcher/python); see the sketch below.
   - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults
   - **Solution**: Mirror the FMHA pattern: defer GPU initialization until the first run()
   - **Changes**:
     - setup_multiple_grouped_conv_dispatchers() returns List[Path], not loaded libs
     - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL
     - Added an _ensure_initialized() method for lazy GPU loading
     - GPU context is created only on the first run() call
   - **Benefit**: Parallel compilation now works without GPU conflicts
4. Addressed a few miscellaneous issues:
   - Fixed a BF16->FP16 naming bug in the dispatcher wrapper
   - Added new tile sizes and the comp_v5 pipeline to the arch spec to expand kernel selection
   - Added automatic padding support for unsupported shapes in the dispatcher runner
   - Created a single source of truth between tile_engine and the dispatcher for architecture and tile_size details
   - Built a validation script to compare oracle_best vs. ml_heuristic results (metric sketch below)

## Test Plan

1. Validated fwd, bwd-data, and bwd-weight kernels with both known and unseen data sets of up to 300 problems.
2. Ensured that test cases are added in both dispatcher and tile_engine to validate the heuristic.

## Test Result

Results for unseen shapes, validated on gfx950.

#### Forward Pass Model

- **Training Data**: 48,845 measurements across 1,372 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.05%**
  - Median Efficiency: **96.8%**
  - P10 Efficiency: **79.9%**

#### Backward Data Gradient (bwd_data) Model

- **Training Data**: 18,773 measurements across 891 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.8%**
  - Median Efficiency: **96.5%**
  - P10 Efficiency: **82.9%**

#### Backward Weight Gradient (bwd_weight) Model

- **Training Data**: 34,900 measurements across 1,508 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **96.1%**
  - Median Efficiency: **99.2%**
  - P10 Efficiency: **89.4%**

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
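For illustration, here is a minimal sketch of how an LGBM regressor like those above could rank candidate kernel configurations for a problem shape. The model filename and feature layout are hypothetical placeholders, not the dispatcher's actual schema:

```python
# Hypothetical sketch: ranking kernel configs with a LightGBM regressor.
# The model path and feature layout are placeholders, not the real schema.
import numpy as np
import lightgbm as lgb

def rank_configs(booster, problem_features, config_features_list):
    """Score each (problem, config) feature row; return config indices best-first."""
    rows = np.asarray(
        [list(problem_features) + list(cfg) for cfg in config_features_list],
        dtype=float,
    )
    scores = booster.predict(rows)  # one predicted performance score per config
    return np.argsort(scores)[::-1]  # highest predicted score first

# Usage (placeholder model file):
# booster = lgb.Booster(model_file="fwd_grouped_conv_model.txt")
# order = rank_configs(booster, problem_feats, candidate_config_feats)
```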
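The lazy-initialization change in item 3 follows the pattern sketched below; this is a minimal illustration of the idea, with the actual kernel entry-point plumbing elided:

```python
# Minimal sketch of the lazy GPU initialization pattern from item 3.
import ctypes
from pathlib import Path

class GpuGroupedConvRunner:
    """Defers loading the compiled library (and thus GPU context creation)
    until the first run() call, so instances can be constructed before
    ProcessPoolExecutor fork()s workers without inheriting a GPU context."""

    def __init__(self, library_path):
        self.library_path = Path(library_path)
        self._lib = None  # no ctypes.CDLL here anymore

    def _ensure_initialized(self):
        # First call loads the shared library; the GPU context is created here.
        if self._lib is None:
            self._lib = ctypes.CDLL(str(self.library_path))

    def run(self, *args, **kwargs):
        self._ensure_initialized()  # GPU touched only on first run()
        # ... invoke the kernel entry point in self._lib ...
```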
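The efficiency figures above compare the ML pick against the per-problem oracle best. A sketch of the statistics, assuming per-problem kernel times where lower is better (this metric definition is an assumption, not taken from the validation script itself):

```python
# Sketch: efficiency of the ML heuristic's pick vs. the oracle-best kernel.
# Assumes efficiency = oracle_time / ml_time, so 1.0 means an oracle-best pick.
import numpy as np

def efficiency_stats(oracle_ms, ml_ms):
    eff = np.asarray(oracle_ms, dtype=float) / np.asarray(ml_ms, dtype=float)
    return {
        "mean": eff.mean(),
        "median": np.median(eff),
        "p10": np.percentile(eff, 10),  # 90% of problems meet or beat this
    }

# Example: picks running at 100%, 95%, and 80% of oracle speed.
print(efficiency_stats([1.0, 1.9, 0.8], [1.0, 2.0, 1.0]))
# -> {'mean': 0.9166..., 'median': 0.95, 'p10': 0.83}
```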
283 lines
8.3 KiB
Python
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
Example 06: Registry, Heuristic Selection & JSON Export

Declares multiple kernel configurations with different tile sizes,
builds a registry, demonstrates heuristic runtime kernel selection,
JSON round-trip, and GPU execution.

Usage:
    python3 06_registry_json.py
    python3 06_registry_json.py --workers 4
"""

import sys
import time
import argparse
import numpy as np
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))

from grouped_conv_utils import (
    GroupedConvKernelConfig,
    GroupedConvProblem,
    GroupedConvRegistry,
    detect_gpu_arch,
)


def conv_heuristic(problem):
    """Rank tile-size candidates by output spatial size: prefer larger
    tiles when Ho*Wo is large, smaller tiles otherwise."""
    spatial = problem.Ho * problem.Wo
    if spatial > 400:
        return ["256", "128", "64"]
    return ["64", "128", "256"]

def main():
    parser = argparse.ArgumentParser(description="Registry, Heuristic & JSON")
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"])
    parser.add_argument("--workers", type=int, default=0)
    args = parser.parse_args()

    arch = args.arch
    print("=" * 70)
    print("Example 06: Registry, Heuristic Selection & JSON Export")
    print("=" * 70)
    print(f"\n Arch: {arch}, Dtype: {args.dtype}")

    # Step 1: Declare kernels with full explicit parameters
    print("\n--- Step 1: Declare Kernels + Build Registry ---")
    reg = GroupedConvRegistry("conv_tiles")

    # All tiles satisfy: tile_m == wave_m * warp_tile_m (TileGemmShape).
    # Small problem-M is handled by kPadM=True (default).

    # Large tile: 128x128x64, wave 4x4x1, warp 32x32x16, compv3
    reg.add(
        GroupedConvKernelConfig(
            variant="forward",
            ndim_spatial=2,
            arch=arch,
            dtype=args.dtype,
            tile_m=128,  # = wave_m(4) * warp_tile_m(32)
            tile_n=128,
            tile_k=64,
            wave_m=4,
            wave_n=4,
            wave_k=1,
            warp_tile_m=32,
            warp_tile_n=32,
            warp_tile_k=16,
            pipeline="compv3",
            scheduler="intrawave",
            epilogue="cshuffle",
            vector_size_a=4,
            vector_size_b=8,
            vector_size_c=8,
            block_per_cu=1,
            num_wave_groups=1,
            num_groups_to_merge=1,
        )
    )
    # Medium tile: 64x128x64, wave 2x2x1, warp 32x32x16, compv4 (LDS 24 KiB <= 32 KiB)
    reg.add(
        GroupedConvKernelConfig(
            variant="forward",
            ndim_spatial=2,
            arch=arch,
            dtype=args.dtype,
            tile_m=64,  # = wave_m(2) * warp_tile_m(32)
            tile_n=128,
            tile_k=64,
            wave_m=2,
            wave_n=2,
            wave_k=1,
            warp_tile_m=32,
            warp_tile_n=32,
            warp_tile_k=16,
            pipeline="compv4",
            scheduler="intrawave",
            epilogue="cshuffle",
            vector_size_a=4,
            vector_size_b=8,
            vector_size_c=8,
            block_per_cu=1,
            num_wave_groups=1,
            num_groups_to_merge=1,
            double_smem_buffer=True,  # required by compv4 pipeline
        )
    )
    # Small tile: 16x64x128, wave 1x4x1, warp 16x16x32, compv3
    reg.add(
        GroupedConvKernelConfig(
            variant="forward",
            ndim_spatial=2,
            arch=arch,
            dtype=args.dtype,
            tile_m=16,  # = wave_m(1) * warp_tile_m(16)
            tile_n=64,
            tile_k=128,
            wave_m=1,
            wave_n=4,
            wave_k=1,
            warp_tile_m=16,
            warp_tile_n=16,
            warp_tile_k=32,
            pipeline="compv3",
            scheduler="intrawave",
            epilogue="cshuffle",
            vector_size_a=4,
            vector_size_b=8,
            vector_size_c=8,
            block_per_cu=1,
            num_wave_groups=1,
            num_groups_to_merge=1,
        )
    )
    reg.print_registry()

    # Step 2: Heuristic kernel selection
    print("\n--- Step 2: Heuristic Kernel Selection ---")
    problems = [
        (
            "small_7x7",
            GroupedConvProblem(
                N=1,
                C=512,
                K=512,
                Hi=7,
                Wi=7,
                Y=3,
                X=3,
                pad_h=1,
                pad_w=1,
                direction="forward",
            ),
        ),
        (
            "medium_14x14",
            GroupedConvProblem(
                N=1,
                C=256,
                K=256,
                Hi=14,
                Wi=14,
                Y=3,
                X=3,
                pad_h=1,
                pad_w=1,
                direction="forward",
            ),
        ),
        (
            "large_56x56",
            GroupedConvProblem(
                N=1,
                C=64,
                K=128,
                Hi=56,
                Wi=56,
                Y=3,
                X=3,
                pad_h=1,
                pad_w=1,
                direction="forward",
            ),
        ),
    ]
    print(f" {'Problem':<16} {'Spatial':>8} {'Selected Kernel':<50}")
    print(f" {'-' * 74}")
    for label, prob in problems:
        selected = reg.select(prob, heuristic=conv_heuristic)
        spatial = prob.Ho * prob.Wo
        sel_name = selected.name if selected else "none"
        print(f" {label:<16} {spatial:>8} {sel_name:<50}")

    # Step 3: JSON round-trip
    print("\n--- Step 3: JSON Round-Trip ---")
    json_str = reg.to_json()
    print(f" Exported: {len(json_str)} bytes, {len(reg)} kernels")
    imported = GroupedConvRegistry.from_json(json_str)
    print(f" Imported: {len(imported)} kernels")
    orig = reg.kernels[0]
    imp = imported.kernels[0]
    rt_ok = (
        orig.vector_size_a == imp.vector_size_a
        and orig.block_per_cu == imp.block_per_cu
        and orig.tile_n == imp.tile_n
    )
    print(f" Full fields round-trip: {'OK' if rt_ok else 'FAIL'}")

    # Step 4: JIT build + GPU execution
    print("\n--- Step 4: JIT Build + GPU Execution ---")
    workers = args.workers if args.workers > 0 else None
    jit_reg = GroupedConvRegistry("jit_conv")
    jit_reg.add(
        GroupedConvKernelConfig(
            variant="forward",
            ndim_spatial=2,
            arch=arch,
            dtype=args.dtype,
            tile_m=64,  # = wave_m(2) * warp_tile_m(32); LDS 24 KiB <= compv4 32 KiB
            tile_n=128,
            tile_k=64,
            wave_m=2,
            wave_n=2,
            wave_k=1,
            warp_tile_m=32,
            warp_tile_n=32,
            warp_tile_k=16,
            double_smem_buffer=True,  # required by compv4 pipeline
            pipeline="compv4",
            scheduler="intrawave",
            epilogue="cshuffle",
            vector_size_a=4,
            vector_size_b=8,
            vector_size_c=8,
        )
    )
    t0 = time.perf_counter()
    runners = jit_reg.build(verbose=False, max_workers=workers)
    jit_s = time.perf_counter() - t0

    if ("forward", 2) not in runners:
        print(" JIT build failed")
        return 1
    runner = runners[("forward", 2)]
    print(f" JIT build: {jit_s:.3f} s")
    print(f" Library: {runner.library_path}")

    prob = GroupedConvProblem(
        N=1, C=128, K=128, Hi=16, Wi=16, Y=3, X=3, pad_h=1, pad_w=1, direction="forward"
    )
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
    inp = np.random.uniform(-0.3, 0.3, prob.input_shape()).astype(np_dtype)
    wei = np.random.uniform(-0.3, 0.3, prob.weight_shape()).astype(np_dtype)
    res = runner.run(inp, wei, prob)
    runner.cleanup()

    if res.success:
        print(f" Time: {res.time_ms:.4f} ms")
        print(f" TFLOPS: {res.tflops:.2f}")
        print(f" NonZero: {np.count_nonzero(res.output)}/{res.output.size}")

    gpu_ok = res.success
    print("\n" + "=" * 70)
    print(f" Registry: {len(reg)} kernels (3 tile configs)")
    print(" Heuristic: spatial-based selection demonstrated")
    print(f" JSON: round-trip {'OK' if rt_ok else 'FAIL'}")
    print(f" GPU: {'OK' if gpu_ok else 'FAIL'}")
    print(f" Status: {'PASS' if gpu_ok and rt_ok else 'FAIL'}")
    print("=" * 70)
    return 0 if gpu_ok and rt_ok else 1


if __name__ == "__main__":
    sys.exit(main())