composable_kernel/dispatcher/examples/grouped_conv/python/06_registry_json.py
Yaswanth Raparti 6989cf800c [rocm-libraries] ROCm/rocm-libraries#6327 (commit 1e7a12e)
[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327)

## Motivation
The ML heuristic in the dispatcher does not yet support the grouped-conv
operator. This PR adds support for fwd, bwd-data, and bwd-weight
grouped-conv kernels. A tile_engine utility has also been added to compile
and run any selected kernel configuration through the dispatcher
infrastructure.

## Technical Details

1. A tile_engine utility has been added to benchmark each shape with all
possible kernel + tile_size combinations:
https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py
2. New LGBM regressor models for grouped conv have been added to the
models directory, with three separate models for fwd, bwd-data, and
bwd-weight:
https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models
3. Implemented lazy GPU initialization (dispatcher/python); a minimal
sketch of the pattern follows this list.
  - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory
    access faults
  - **Solution**: Mirror the FMHA pattern - defer GPU initialization until
    the first run()
  - **Changes**:
    - setup_multiple_grouped_conv_dispatchers() returns List[Path], not
      loaded libs
    - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL
    - Added an _ensure_initialized() method for lazy GPU loading
    - GPU context created only on the first run() call
  - **Benefit**: Parallel compilation now works without GPU conflicts
4. Addressed a few miscellaneous issues:
  - Fixed a BF16->FP16 naming bug in the dispatcher wrapper
  - Added new tile sizes and the comp_v5 pipeline to the arch spec to
    expand kernel selection
  - Added automatic padding support for unsupported shapes in the
    dispatcher runner
  - Created a single source of truth between tile_engine and the
    dispatcher for architecture and tile_size details
  - Built a validation script to compare oracle_best vs. ml_heuristic
    selections
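A minimal sketch of the lazy-initialization pattern from item 3, assuming
the runner is constructed from a compiled library path; the method bodies
are illustrative, not the actual dispatcher code:

```python
import ctypes
from pathlib import Path


class GpuGroupedConvRunner:
    """Sketch: defer dlopen (and thus GPU context creation) to the first
    run(), so ProcessPoolExecutor's fork()ed workers never inherit a
    live GPU context."""

    def __init__(self, library_path: Path):
        # Fork-safe: record the path only, no ctypes.CDLL here.
        self.library_path = library_path
        self._lib = None

    def _ensure_initialized(self):
        # First call loads the compiled kernel library; the GPU context
        # is created in the process that actually launches kernels.
        if self._lib is None:
            self._lib = ctypes.CDLL(str(self.library_path))

    def run(self, *args, **kwargs):
        self._ensure_initialized()
        # ... dispatch into self._lib and launch the kernel ...
```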

## Test Plan

1. Validated fwd, bwd-data, and bwd-weight kernels on both known and
unseen data sets of up to 300 problems.
2. Added test cases in both the dispatcher and tile_engine to validate
the heuristic.

## Test Results
Results on unseen shapes, validated on gfx950:
#### Forward Pass Model
- **Training Data**: 48,845 measurements across 1,372 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.05%**
  - Median Efficiency: **96.8%**
  - P10 Efficiency: **79.9%**

#### Backward Data Gradient (bwd_data) Model
- **Training Data**: 18,773 measurements across 891 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.8%**
  - Median Efficiency: **96.5%**
  - P10 Efficiency: **82.9%**

#### Backward Weight Gradient (bwd_weight) Model
- **Training Data**: 34,900 measurements across 1,508 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **96.1%**
  - Median Efficiency: **99.2%**
  - P10 Efficiency: **89.4%**
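
For context on the numbers above, a minimal sketch of how per-problem
efficiency and its mean/median/P10 summaries can be computed; `oracle_ms`
and `ml_ms` are hypothetical timing arrays, not data from this PR:

```python
import numpy as np

# Hypothetical per-problem kernel times (ms): oracle best vs. ML-selected.
oracle_ms = np.array([0.42, 1.10, 0.87])
ml_ms = np.array([0.45, 1.12, 1.05])

eff = oracle_ms / ml_ms  # 1.0 means the heuristic matched the oracle

print(f"Mean Efficiency:   {100 * eff.mean():.2f}%")
print(f"Median Efficiency: {100 * np.median(eff):.1f}%")
print(f"P10 Efficiency:    {100 * np.percentile(eff, 10):.1f}%")
```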

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 06: Registry, Heuristic Selection & JSON Export

Declares multiple kernel configurations with different tile sizes,
builds a registry, demonstrates heuristic runtime kernel selection,
JSON round-trip, and GPU execution.

Usage:
    python3 06_registry_json.py
    python3 06_registry_json.py --workers 4
"""
import sys
import time
import argparse
import numpy as np
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))

from grouped_conv_utils import (
    GroupedConvKernelConfig,
    GroupedConvProblem,
    GroupedConvRegistry,
    detect_gpu_arch,
)
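

# Toy heuristic for Step 2: return tile-size preferences in priority order,
# largest-first when the output spatial footprint (Ho * Wo) is large.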
def conv_heuristic(problem):
    spatial = problem.Ho * problem.Wo
    if spatial > 400:
        return ["256", "128", "64"]
    return ["64", "128", "256"]


def main():
    parser = argparse.ArgumentParser(description="Registry, Heuristic & JSON")
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"])
    parser.add_argument("--workers", type=int, default=0)
    args = parser.parse_args()
    arch = args.arch

    print("=" * 70)
    print("Example 06: Registry, Heuristic Selection & JSON Export")
    print("=" * 70)
    print(f"\n Arch: {arch}, Dtype: {args.dtype}")

    # Step 1: Declare kernels with full explicit parameters
    print("\n--- Step 1: Declare Kernels + Build Registry ---")
    reg = GroupedConvRegistry("conv_tiles")

    # All tiles satisfy: tile_m == wave_m * warp_tile_m (TileGemmShape)
    # Small problem-M handled by kPadM=True (default).

    # Large tile: 128x128x64, wave 4x4x1, warp 32x32x16, compv3
    reg.add(
        GroupedConvKernelConfig(
            variant="forward",
            ndim_spatial=2,
            arch=arch,
            dtype=args.dtype,
            tile_m=128,  # = wave_m(4) * warp_tile_m(32)
            tile_n=128,
            tile_k=64,
            wave_m=4,
            wave_n=4,
            wave_k=1,
            warp_tile_m=32,
            warp_tile_n=32,
            warp_tile_k=16,
            pipeline="compv3",
            scheduler="intrawave",
            epilogue="cshuffle",
            vector_size_a=4,
            vector_size_b=8,
            vector_size_c=8,
            block_per_cu=1,
            num_wave_groups=1,
            num_groups_to_merge=1,
        )
    )

    # Medium tile: 64x128x64, wave 2x2x1, warp 32x32x16, compv4 (LDS 24 KiB <= 32 KiB)
    reg.add(
        GroupedConvKernelConfig(
            variant="forward",
            ndim_spatial=2,
            arch=arch,
            dtype=args.dtype,
            tile_m=64,  # = wave_m(2) * warp_tile_m(32)
            tile_n=128,
            tile_k=64,
            wave_m=2,
            wave_n=2,
            wave_k=1,
            warp_tile_m=32,
            warp_tile_n=32,
            warp_tile_k=16,
            pipeline="compv4",
            scheduler="intrawave",
            epilogue="cshuffle",
            vector_size_a=4,
            vector_size_b=8,
            vector_size_c=8,
            block_per_cu=1,
            num_wave_groups=1,
            num_groups_to_merge=1,
            double_smem_buffer=True,  # required by compv4 pipeline
        )
    )

    # Small tile: 16x64x128, wave 1x4x1, warp 16x16x32, compv3
    reg.add(
        GroupedConvKernelConfig(
            variant="forward",
            ndim_spatial=2,
            arch=arch,
            dtype=args.dtype,
            tile_m=16,  # = wave_m(1) * warp_tile_m(16)
            tile_n=64,
            tile_k=128,
            wave_m=1,
            wave_n=4,
            wave_k=1,
            warp_tile_m=16,
            warp_tile_n=16,
            warp_tile_k=32,
            pipeline="compv3",
            scheduler="intrawave",
            epilogue="cshuffle",
            vector_size_a=4,
            vector_size_b=8,
            vector_size_c=8,
            block_per_cu=1,
            num_wave_groups=1,
            num_groups_to_merge=1,
        )
    )
    reg.print_registry()

    # Step 2: Heuristic kernel selection
    print("\n--- Step 2: Heuristic Kernel Selection ---")
    problems = [
        (
            "small_7x7",
            GroupedConvProblem(
                N=1,
                C=512,
                K=512,
                Hi=7,
                Wi=7,
                Y=3,
                X=3,
                pad_h=1,
                pad_w=1,
                direction="forward",
            ),
        ),
        (
            "medium_14x14",
            GroupedConvProblem(
                N=1,
                C=256,
                K=256,
                Hi=14,
                Wi=14,
                Y=3,
                X=3,
                pad_h=1,
                pad_w=1,
                direction="forward",
            ),
        ),
        (
            "large_56x56",
            GroupedConvProblem(
                N=1,
                C=64,
                K=128,
                Hi=56,
                Wi=56,
                Y=3,
                X=3,
                pad_h=1,
                pad_w=1,
                direction="forward",
            ),
        ),
    ]
    print(f" {'Problem':<16} {'Spatial':>8} {'Selected Kernel':<50}")
    print(f" {'-' * 74}")
    for label, prob in problems:
        selected = reg.select(prob, heuristic=conv_heuristic)
        spatial = prob.Ho * prob.Wo
        sel_name = selected.name if selected else "none"
        print(f" {label:<16} {spatial:>8} {sel_name:<50}")

    # Step 3: JSON round-trip
    print("\n--- Step 3: JSON Round-Trip ---")
    json_str = reg.to_json()
    print(f" Exported: {len(json_str)} bytes, {len(reg)} kernels")
    imported = GroupedConvRegistry.from_json(json_str)
    print(f" Imported: {len(imported)} kernels")
    orig = reg.kernels[0]
    imp = imported.kernels[0]
    rt_ok = (
        orig.vector_size_a == imp.vector_size_a
        and orig.block_per_cu == imp.block_per_cu
        and orig.tile_n == imp.tile_n
    )
    print(f" Full fields round-trip: {'OK' if rt_ok else 'FAIL'}")

    # Step 4: JIT build + GPU execution
    print("\n--- Step 4: JIT Build + GPU Execution ---")
    workers = args.workers if args.workers > 0 else None
    jit_reg = GroupedConvRegistry("jit_conv")
    jit_reg.add(
        GroupedConvKernelConfig(
            variant="forward",
            ndim_spatial=2,
            arch=arch,
            dtype=args.dtype,
            tile_m=64,  # = wave_m(2) * warp_tile_m(32); LDS 24 KiB <= compv4 32 KiB
            tile_n=128,
            tile_k=64,
            wave_m=2,
            wave_n=2,
            wave_k=1,
            warp_tile_m=32,
            warp_tile_n=32,
            warp_tile_k=16,
            double_smem_buffer=True,  # required by compv4 pipeline
            pipeline="compv4",
            scheduler="intrawave",
            epilogue="cshuffle",
            vector_size_a=4,
            vector_size_b=8,
            vector_size_c=8,
        )
    )
    t0 = time.perf_counter()
    runners = jit_reg.build(verbose=False, max_workers=workers)
    jit_s = time.perf_counter() - t0
    if ("forward", 2) not in runners:
        print(" JIT build failed")
        return 1
    runner = runners[("forward", 2)]
    print(f" JIT build: {jit_s:.3f} s")
    print(f" Library: {runner.library_path}")

    prob = GroupedConvProblem(
        N=1, C=128, K=128, Hi=16, Wi=16, Y=3, X=3, pad_h=1, pad_w=1, direction="forward"
    )
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
    inp = np.random.uniform(-0.3, 0.3, prob.input_shape()).astype(np_dtype)
    wei = np.random.uniform(-0.3, 0.3, prob.weight_shape()).astype(np_dtype)
    res = runner.run(inp, wei, prob)
    runner.cleanup()
    if res.success:
        print(f" Time: {res.time_ms:.4f} ms")
        print(f" TFLOPS: {res.tflops:.2f}")
        print(f" NonZero: {np.count_nonzero(res.output)}/{res.output.size}")
    gpu_ok = res.success

    print("\n" + "=" * 70)
    print(f" Registry: {len(reg)} kernels (3 tile configs)")
    print(" Heuristic: spatial-based selection demonstrated")
    print(f" JSON: round-trip {'OK' if rt_ok else 'FAIL'}")
    print(f" GPU: {'OK' if gpu_ok else 'FAIL'}")
    print(f" Status: {'PASS' if gpu_ok and rt_ok else 'FAIL'}")
    print("=" * 70)
    return 0 if gpu_ok and rt_ok else 1


if __name__ == "__main__":
    sys.exit(main())