[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327)

## Motivation

The ML heuristic in the dispatcher does not support the grouped-conv operator yet. This PR adds support for fwd, bwd-data, and bwd-weight grouped-conv kernels. A tile_engine utility has also been added to compile and run any selected kernel configuration through the dispatcher infrastructure.

## Technical Details

1. A tile_engine utility is added to benchmark each shape with all possible kernel + tile_size combinations: https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py
2. New LGBM regressor models for grouped conv are added to the models directory, with three separate models for fwd, bwd-data, and bwd-weight: https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models
3. Implemented lazy GPU initialization (dispatcher/python); a minimal sketch of the pattern follows the checklist below.
   - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults.
   - **Solution**: Mirror the FMHA pattern and defer GPU initialization until the first run().
   - **Changes**:
     - setup_multiple_grouped_conv_dispatchers() returns List[Path], not loaded libs.
     - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL.
     - Added an _ensure_initialized() method for lazy GPU loading.
     - The GPU context is created only on the first run() call.
   - **Benefit**: Parallel compilation now works without GPU conflicts.
4. Addressed a few miscellaneous issues:
   - Fixed a BF16->FP16 naming bug in the dispatcher wrapper.
   - Added new tile sizes and the comp_v5 pipeline to the arch spec to expand kernel selection.
   - Added automatic padding support for unsupported shapes in the dispatcher runner.
   - Created a single source of truth between tile_engine and the dispatcher for architecture and tile_size details.
   - Built a validation script to compare oracle_best vs. ml_heuristic results.

## Test Plan

1. Validated fwd, bwd-data, and bwd-weight kernels on both known and unseen data sets of up to 300 problems.
2. Ensured that test cases are added in both the dispatcher and tile_engine to validate the heuristic.

## Test Result

Results on unseen shapes, validated on gfx950.

#### Forward Pass Model

- **Training Data**: 48,845 measurements across 1,372 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.05%**
  - Median Efficiency: **96.8%**
  - P10 Efficiency: **79.9%**

#### Backward Data Gradient (bwd_data) Model

- **Training Data**: 18,773 measurements across 891 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.8%**
  - Median Efficiency: **96.5%**
  - P10 Efficiency: **82.9%**

#### Backward Weight Gradient (bwd_weight) Model

- **Training Data**: 34,900 measurements across 1,508 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **96.1%**
  - Median Efficiency: **99.2%**
  - P10 Efficiency: **89.4%**

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
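Below is a minimal sketch of the lazy-initialization pattern described in item 3. The class internals and the `run()` signature here are illustrative stand-ins, not the actual dispatcher code; the only point being shown is moving `ctypes.CDLL` out of `__init__()` and into the first `run()`:

```python
import ctypes
from pathlib import Path


class GpuGroupedConvRunner:
    """Sketch: hold a library path, load it lazily on first run()."""

    def __init__(self, lib_path: Path):
        # Store only the path. Loading via ctypes.CDLL here would create a
        # GPU context in the parent process, which ProcessPoolExecutor's
        # fork()ed workers would inherit (the memory access faults above).
        self._lib_path = lib_path
        self._lib = None

    def _ensure_initialized(self):
        # Deferred load: the GPU context is created only in the process
        # that actually launches the kernel, after all forking is done.
        if self._lib is None:
            self._lib = ctypes.CDLL(str(self._lib_path))

    def run(self, *args):
        self._ensure_initialized()
        # The real runner would launch the compiled grouped-conv kernel
        # through self._lib here; omitted in this sketch.
        return self._lib
```

Because the compilation workers now receive plain `Path` objects instead of loaded libraries, kernels can be built in parallel without any worker ever touching the GPU.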
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
Example 05: Multi-Problem GPU Benchmark

Declares kernels with explicit tile/wave/warp/pipeline parameters for
all directions, builds registries, JIT compiles, and benchmarks across
ResNet-like problem sizes with configurable warmup/repeat.

Usage:
    python3 05_benchmark.py
    python3 05_benchmark.py --warmup 3 --repeat 10
    python3 05_benchmark.py --workers 4
"""

import sys
import argparse
import time
import numpy as np
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))

from grouped_conv_utils import (
    GroupedConvKernelConfig,
    GroupedConvProblem,
    GroupedConvRegistry,
    detect_gpu_arch,
)


def compute_bytes(prob, dtype_bytes=2):
    in_elems = 1
    for d in prob.input_shape():
        in_elems *= d
    wei_elems = 1
    for d in prob.weight_shape():
        wei_elems *= d
    out_elems = 1
    for d in prob.output_shape():
        out_elems *= d
    return (in_elems + wei_elems + out_elems) * dtype_bytes
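
# Illustrative worked instance of the traffic model above (added note, not
# in the original script): for the ResNet-stage3 shape used later
# (N=1, C=K=128, Hi=Wi=28, 3x3 filter, stride 1, pad 1 -> Ho=Wo=28):
#     in  = 1 * 128 * 28 * 28 = 100,352 elements
#     wei = 128 * 128 * 3 * 3 = 147,456 elements
#     out = 1 * 128 * 28 * 28 = 100,352 elements
#     (100,352 + 147,456 + 100,352) * 2 B (fp16) ~= 0.66 MiB
# Each tensor is counted once, so the GB/s column derived from this is a
# lower bound on the traffic the kernel actually moves.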


def main():
    parser = argparse.ArgumentParser(description="Multi-Problem GPU Benchmark")
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"])
    parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations")
    parser.add_argument("--repeat", type=int, default=5, help="Benchmark iterations")
    parser.add_argument(
        "--workers", type=int, default=0, help="Max JIT workers (0=auto)"
    )
    args = parser.parse_args()

    print("=" * 70)
    print("Example 05: Multi-Problem GPU Benchmark")
    print("=" * 70)
    print(f"\n Arch: {args.arch}, Dtype: {args.dtype}")
    print(f" Warmup: {args.warmup}, Repeat: {args.repeat}")

    # =========================================================================
    # Step 1: Declare all kernels with explicit parameters
    # =========================================================================
    print("\n--- Step 1: Declare Kernels ---")
    reg = GroupedConvRegistry("benchmark")

    # All tiles satisfy: tile_m == wave_m * warp_tile_m (TileGemmShape)
    # Small problem-M handled by kPadM=True (default).
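    # Worked check of that constraint for the configs below (added note):
    #   forward 2D / bwd_data / bwd_weight: tile_m 64 == wave_m(2) * warp_tile_m(32)
    #   forward 3D:                         tile_m 16 == wave_m(1) * warp_tile_m(16)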

    # Forward 2D: compv4, 64x128x64 tile (LDS 24 KiB <= 32 KiB compv4 limit)
    reg.add(
        GroupedConvKernelConfig(
            variant="forward",
            ndim_spatial=2,
            arch=args.arch,
            dtype=args.dtype,
            tile_m=64,  # = wave_m(2) * warp_tile_m(32)
            tile_n=128,
            tile_k=64,
            wave_m=2,
            wave_n=2,
            wave_k=1,
            warp_tile_m=32,
            warp_tile_n=32,
            warp_tile_k=16,
            pipeline="compv4",
            scheduler="intrawave",
            epilogue="cshuffle",
            vector_size_a=4,
            vector_size_b=8,
            vector_size_c=8,
            block_per_cu=1,
            double_smem_buffer=True,  # required by compv4 pipeline
        )
    )
    # Forward 3D: compv3, 16x64x128 tile
    reg.add(
        GroupedConvKernelConfig(
            variant="forward",
            ndim_spatial=3,
            arch=args.arch,
            dtype=args.dtype,
            tile_m=16,  # = wave_m(1) * warp_tile_m(16)
            tile_n=64,
            tile_k=128,
            wave_m=1,
            wave_n=4,
            wave_k=1,
            warp_tile_m=16,
            warp_tile_n=16,
            warp_tile_k=32,
            pipeline="compv3",
            scheduler="intrawave",
            epilogue="cshuffle",
            vector_size_a=4,
            vector_size_b=8,
            vector_size_c=8,
            block_per_cu=1,
        )
    )
    # BwdData 2D: compv3, 64x128x64 tile
    reg.add(
        GroupedConvKernelConfig(
            variant="bwd_data",
            ndim_spatial=2,
            arch=args.arch,
            dtype=args.dtype,
            tile_m=64,  # = wave_m(2) * warp_tile_m(32)
            tile_n=128,
            tile_k=64,
            wave_m=2,
            wave_n=2,
            wave_k=1,
            warp_tile_m=32,
            warp_tile_n=32,
            warp_tile_k=16,
            pipeline="compv3",
            scheduler="intrawave",
            epilogue="cshuffle",
            vector_size_a=4,
            vector_size_b=8,
            vector_size_c=8,
            block_per_cu=1,
        )
    )
    # BwdWeight 2D: compv3, 64x128x64 tile
    reg.add(
        GroupedConvKernelConfig(
            variant="bwd_weight",
            ndim_spatial=2,
            arch=args.arch,
            dtype=args.dtype,
            tile_m=64,  # = wave_m(2) * warp_tile_m(32)
            tile_n=128,
            tile_k=64,
            wave_m=2,
            wave_n=2,
            wave_k=1,
            warp_tile_m=32,
            warp_tile_n=32,
            warp_tile_k=16,
            pipeline="compv3",
            scheduler="intrawave",
            epilogue="cshuffle",
            vector_size_a=4,
            vector_size_b=8,
            vector_size_c=8,
            block_per_cu=1,
        )
    )
    reg.print_registry()

    # =========================================================================
    # Step 2: JIT build
    # =========================================================================
    print("\n--- Step 2: JIT Build ---")
    workers = args.workers if args.workers > 0 else None
    t0 = time.perf_counter()
    runner_by_key = reg.build(verbose=False, max_workers=workers)
    jit_s = time.perf_counter() - t0

    for key in [("forward", 2), ("forward", 3), ("bwd_data", 2), ("bwd_weight", 2)]:
        tag = "OK" if key in runner_by_key else "FAILED"
        print(f" {key[0]:12s} {key[1]}D: {tag}")
    print(f" JIT build time: {jit_s:.3f} s")

    missing = [
        k
        for k in [("forward", 2), ("forward", 3), ("bwd_data", 2), ("bwd_weight", 2)]
        if k not in runner_by_key
    ]
    if missing:
        print(f"\n ERROR: missing {missing}")
        return 1

    # NumPy has no native bf16 dtype; float16 serves as a same-width
    # container for bf16 data here.
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32

    def bench_run(runner, inp, wei, prob):
        """Warm up, then time repeat runs; returns (min_ms, avg_ms) over the
        successful runs, or (0.0, 0.0) if every run failed."""
        for _ in range(args.warmup):
            runner.run(inp, wei, prob)
        times = []
        for _ in range(args.repeat):
            r = runner.run(inp, wei, prob)
            if r.success:
                times.append(r.time_ms)
        if not times:
            return 0.0, 0.0
        return min(times), sum(times) / len(times)
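
    # Unit check for the metric columns below (added note): with time in ms,
    #   TFLOPS = flops / (avg_ms * 1e-3) / 1e12 = flops / (avg_ms * 1e9)
    #   GB/s   = bytes / (avg_ms * 1e-3) / 1e9  = bytes / (avg_ms * 1e6)
    # Assuming the usual 2*MACs FLOP count, ResNet-stage3 forward is about
    # 2*1*128*128*3*3*28*28 ~= 0.23 GFLOP, so avg_ms = 0.10 would print
    # roughly 2.3 TFLOPS.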

    # =========================================================================
    # Step 3: 2D Forward benchmark
    # =========================================================================
    print("\n--- Step 3: Forward 2D Benchmark ---")
    print(
        f"{'Problem':<18} {'N':>3} {'C':>4} {'K':>4} {'H':>3} {'W':>3} "
        f"{'F':>3} {'Min(ms)':>9} {'Avg(ms)':>9} {'TFLOPS':>8} {'GB/s':>8}"
    )
    print("-" * 85)

    all_ok = True
    for label, n, c, k, h, w, y, x, s, p in [
        ("ResNet-stage2", 1, 64, 64, 56, 56, 3, 3, 1, 1),
        ("ResNet-stage3", 1, 128, 128, 28, 28, 3, 3, 1, 1),
        ("ResNet-stage4", 1, 256, 256, 14, 14, 3, 3, 1, 1),
        ("ResNet-stage5", 1, 512, 512, 7, 7, 3, 3, 1, 1),
        ("Pointwise-1x1", 1, 256, 256, 56, 56, 1, 1, 1, 0),
        ("Batch-8", 8, 64, 128, 56, 56, 3, 3, 1, 1),
        ("Batch-32", 32, 64, 128, 56, 56, 3, 3, 1, 1),
    ]:
        prob = GroupedConvProblem(
            N=n,
            C=c,
            K=k,
            Hi=h,
            Wi=w,
            Y=y,
            X=x,
            stride_h=s,
            stride_w=s,
            pad_h=p,
            pad_w=p,
            direction="forward",
        )
        inp = np.random.uniform(-0.3, 0.3, prob.input_shape()).astype(np_dtype)
        wei = np.random.uniform(-0.3, 0.3, prob.weight_shape()).astype(np_dtype)
        min_ms, avg_ms = bench_run(runner_by_key[("forward", 2)], inp, wei, prob)
        if avg_ms > 0:
            tflops = prob.flops / (avg_ms * 1e9)
            bw = compute_bytes(prob) / (avg_ms * 1e6)
            print(
                f"{label:<18} {n:>3} {c:>4} {k:>4} {h:>3} {w:>3} "
                f"{y}x{x} {min_ms:>9.4f} {avg_ms:>9.4f} {tflops:>8.2f} {bw:>8.1f}"
            )
        else:
            all_ok = False

    # =========================================================================
    # Step 4: 3D Forward
    # =========================================================================
    print("\n--- Step 4: Forward 3D ---")
    for label, n, c, k, d, h, w, z, y, x in [
        ("3D-small", 1, 64, 64, 8, 16, 16, 3, 3, 3),
        ("3D-medium", 1, 64, 128, 16, 32, 32, 3, 3, 3),
    ]:
        prob = GroupedConvProblem(
            N=n, C=c, K=k, Di=d, Hi=h, Wi=w, Z=z, Y=y, X=x, direction="forward"
        )
        inp = np.random.uniform(-0.3, 0.3, prob.input_shape()).astype(np_dtype)
        wei = np.random.uniform(-0.3, 0.3, prob.weight_shape()).astype(np_dtype)
        min_ms, avg_ms = bench_run(runner_by_key[("forward", 3)], inp, wei, prob)
        if avg_ms > 0:
            tflops = prob.flops / (avg_ms * 1e9)
            print(f" {label:<14} {min_ms:.4f} / {avg_ms:.4f} ms {tflops:.2f} TFLOPS")

    # =========================================================================
    # Step 5: Backward directions
    # =========================================================================
    print("\n--- Step 5: Backward Directions ---")
    for label, direction in [
        ("bwd_data ResNet-s3", "bwd_data"),
        ("bwd_weight ResNet-s3", "bwd_weight"),
    ]:
        prob = GroupedConvProblem(
            N=1,
            C=128,
            K=128,
            Hi=28,
            Wi=28,
            Y=3,
            X=3,
            stride_h=1,
            stride_w=1,
            pad_h=1,
            pad_w=1,
            direction=direction,
        )
        inp = np.random.uniform(-0.3, 0.3, prob.input_shape()).astype(np_dtype)
        wei = np.random.uniform(-0.3, 0.3, prob.weight_shape()).astype(np_dtype)
        min_ms, avg_ms = bench_run(runner_by_key[(direction, 2)], inp, wei, prob)
        if avg_ms > 0:
            tflops = prob.flops / (avg_ms * 1e9)
            print(
                f" {label:<14} {direction:>12} {min_ms:.4f} / {avg_ms:.4f} ms {tflops:.2f} TFLOPS"
            )

    for runner in runner_by_key.values():
        runner.cleanup()

    print("\n" + "=" * 70)
    print(f" JIT build: {jit_s:.3f} s")
    print(f" Warmup: {args.warmup}, Repeat: {args.repeat}")
    print(f" Status: {'PASS' if all_ok else 'FAIL'}")
    print("=" * 70)
    return 0 if all_ok else 1


if __name__ == "__main__":
    sys.exit(main())