Files
composable_kernel/dispatcher/examples/grouped_conv/python/03_bwd_data.py
Yaswanth Raparti 6989cf800c [rocm-libraries] ROCm/rocm-libraries#6327 (commit 1e7a12e)
[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327)

## Motivation
The ML heuristic in the dispatcher does not yet support the grouped-conv
operator. This PR adds support for fwd, bwd-data, and bwd-weight grouped-conv
kernels. A tile_engine utility has also been added to compile and run any
selected kernel configuration through the dispatcher infrastructure.

## Technical Details

1. A tile_engine utility is added to benchmark each shape with all possible
kernel + tile_size combinations:
https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py
2. New LGBM regressor models for grouped conv are added to the models
directory, with three separate models for fwd, bwd-data, and bwd-weight:
https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models
3. Implemented lazy GPU initialization (dispatcher/python); a sketch of the pattern follows this list.
  - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults
  - **Solution**: Mirror the FMHA pattern by deferring GPU initialization until the first run()
  - **Changes**:
    - setup_multiple_grouped_conv_dispatchers() returns List[Path], not loaded libs
    - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL
    - Added an _ensure_initialized() method for lazy GPU loading
    - GPU context is created only on the first run() call
  - **Benefit**: Parallel compilation now works without GPU conflicts
4. Addressed a few miscellaneous issues:
  - Fixed a BF16->FP16 naming bug in the dispatcher wrapper
  - Added new tile sizes and the comp_v5 pipeline to the arch spec to expand kernel selection
  - Added automatic padding support for unsupported shapes in the dispatcher runner
  - Created a single source of truth between tile_engine and the dispatcher for architecture and tile_size details
  - Built validation scripts to compare oracle_best vs. ml_heuristic selections
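
A minimal sketch of the lazy-initialization pattern from item 3 (the class and
method names mirror the PR text, but the bodies are illustrative, not the
exact dispatcher implementation):

```python
import ctypes
from pathlib import Path
from typing import Optional

class GpuGroupedConvRunner:
    """Defers GPU/library loading so fork()-based parallel builds stay safe."""

    def __init__(self, lib_path: Path):
        # Only the path is stored at construction time; no ctypes.CDLL here,
        # so a forked ProcessPoolExecutor worker never inherits a GPU context.
        self._lib_path = lib_path
        self._lib: Optional[ctypes.CDLL] = None

    def _ensure_initialized(self) -> None:
        # The GPU context is created lazily, in the process that actually
        # launches the kernel.
        if self._lib is None:
            self._lib = ctypes.CDLL(str(self._lib_path))

    def run(self, *args):
        self._ensure_initialized()
        # ... launch the kernel through self._lib ...
```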

## Test Plan

1. Validated the fwd, bwd-data, and bwd-weight kernels on both known and
unseen data sets of up to 300 problems.
2. Ensured that test cases are added to both dispatcher and tile_engine to
validate the heuristic.

## Test Results
Results on unseen shapes, validated on gfx950:
#### Forward Pass Model
- **Training Data**: 48,845 measurements across 1,372 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.05%**
  - Median Efficiency: **96.8%**
  - P10 Efficiency: **79.9%**

#### Backward Data Gradient (bwd_data) Model
- **Training Data**: 18,773 measurements across 891 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.8%**
  - Median Efficiency: **96.5%**
  - P10 Efficiency: **82.9%**

#### Backward Weight Gradient (bwd_weight) Model
- **Training Data**: 34,900 measurements across 1,508 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **96.1%**
  - Median Efficiency: **99.2%**
  - P10 Efficiency: **89.4%**
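
For context, a minimal sketch of how such numbers can be computed (assuming
efficiency = oracle-best time / ML-selected kernel time; the actual validation
script may differ):

```python
import numpy as np

def efficiency_stats(oracle_ms: np.ndarray, selected_ms: np.ndarray) -> dict:
    """oracle_ms[i]: best measured time for problem i;
    selected_ms[i]: time of the kernel the ML heuristic picked.
    An efficiency of 1.0 means the heuristic matched the oracle."""
    eff = oracle_ms / selected_ms
    return {
        "mean": float(eff.mean()),
        "median": float(np.median(eff)),
        "p10": float(np.percentile(eff, 10)),  # tail of the distribution
    }
```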

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 03: Backward Data Convolution (2D + 3D)
dX = ConvBwdData(dY, W)
Declares backward-data kernels with explicit parameters,
builds a registry, JIT compiles, runs on GPU, and validates
against a CPU reference.
Usage:
python3 03_bwd_data.py
"""
import sys
import argparse
import time
import numpy as np
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))

from grouped_conv_utils import (
    GroupedConvKernelConfig,
    GroupedConvProblem,
    GroupedConvRegistry,
    detect_gpu_arch,
)


def cpu_conv2d_bwd_data(dy, wei, prob):
    """CPU ref: compute dX from dY and W."""
    N, Ho, Wo, G, Kpg = dy.shape
    _, _, Y, X, C = wei.shape
    Hi, Wi = prob.Hi, prob.Wi
    dx = np.zeros((N, Hi, Wi, G, C), dtype=np.float32)
    for n in range(N):
        for g in range(G):
            for hi in range(Hi):
                for wi in range(Wi):
                    for c in range(C):
                        s = 0.0
                        for y in range(Y):
                            for x in range(X):
                                ho = hi + prob.pad_h - y
                                wo = wi + prob.pad_w - x
                                if ho % prob.stride_h == 0 and wo % prob.stride_w == 0:
                                    ho //= prob.stride_h
                                    wo //= prob.stride_w
                                    if 0 <= ho < Ho and 0 <= wo < Wo:
                                        for k in range(Kpg):
                                            s += float(dy[n, ho, wo, g, k]) * float(
                                                wei[g, k, y, x, c]
                                            )
                        dx[n, hi, wi, g, c] = s
    return dx


def main():
    parser = argparse.ArgumentParser(description="Backward Data (2D + 3D)")
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"])
    parser.add_argument("--workers", type=int, default=0)
    args = parser.parse_args()
    arch = args.arch

    print("=" * 70)
    print("Example 03: Backward Data Convolution (2D + 3D)")
    print("=" * 70)
    print(f"\n Arch: {arch}, Dtype: {args.dtype}")
    print(" dX = ConvBwdData(dY, W)")

    # =========================================================================
    # Step 1: Declare bwd_data kernels
    # =========================================================================
    print("\n--- Step 1: Declare BwdData Kernels ---")
    reg = GroupedConvRegistry("bwd_data_conv")

    # BwdData 2D: compv3, 64x128x64 tile, wave 2x2x1, warp 32x32x16
    # Constraint: tile_m == wave_m * warp_tile_m (small M handled by kPadM=True)
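    # Illustrative arithmetic (not enforced here): M is covered exactly by the
    # waves (2 waves x 32-wide warp tiles = 64), while N takes two warp-tile
    # iterations per wave (2 waves x 32 x 2 = 128).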
    reg.add(
        GroupedConvKernelConfig(
            variant="bwd_data",
            ndim_spatial=2,
            arch=arch,
            dtype=args.dtype,
            tile_m=64,  # = wave_m(2) * warp_tile_m(32)
            tile_n=128,
            tile_k=64,
            wave_m=2,
            wave_n=2,
            wave_k=1,
            warp_tile_m=32,
            warp_tile_n=32,
            warp_tile_k=16,
            pipeline="compv3",
            scheduler="intrawave",
            epilogue="cshuffle",
            vector_size_a=4,
            vector_size_b=8,
            vector_size_c=8,
            block_per_cu=1,
        )
    )

    # BwdData 3D: compv3, 16x64x128 tile
    reg.add(
        GroupedConvKernelConfig(
            variant="bwd_data",
            ndim_spatial=3,
            arch=arch,
            dtype=args.dtype,
            tile_m=16,  # = wave_m(1) * warp_tile_m(16)
            tile_n=64,
            tile_k=128,
            wave_m=1,
            wave_n=4,
            wave_k=1,
            warp_tile_m=16,
            warp_tile_n=16,
            warp_tile_k=32,
            pipeline="compv3",
            scheduler="intrawave",
            epilogue="cshuffle",
            vector_size_a=4,
            vector_size_b=8,
            vector_size_c=8,
            block_per_cu=1,
        )
    )

    reg.print_registry()

    # =========================================================================
    # Step 2: JIT build
    # =========================================================================
    print("\n--- Step 2: JIT Build ---")
    workers = args.workers if args.workers > 0 else None
    t0 = time.perf_counter()
    runners = reg.build(verbose=False, max_workers=workers)
    jit_s = time.perf_counter() - t0
    print(f" Built {len(runners)} runners in {jit_s:.1f}s")

    if ("bwd_data", 2) not in runners:
        print(" ERROR: bwd_data 2D JIT failed")
        return 1
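
    # NumPy has no native bfloat16, so host buffers are float16 for both
    # dtypes; the runner is assumed to handle element conversion for bf16.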
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32

    # =========================================================================
    # Step 3: BwdData 2D -- GPU + CPU reference
    # =========================================================================
    print("\n--- Step 3: Backward Data 2D ---")
    prob = GroupedConvProblem(
        N=1, C=32, K=32, Hi=8, Wi=8, Y=3, X=3, pad_h=1, pad_w=1, direction="bwd_data"
    )
    prob.print_problem()
    dy = np.random.uniform(-0.5, 0.5, prob.output_shape()).astype(np_dtype)
    w = np.random.uniform(-0.5, 0.5, prob.weight_shape()).astype(np_dtype)
    res = runners[("bwd_data", 2)].run(dy, w, prob)
    print(f" Time: {res.time_ms:.4f} ms")
    print(f" TFLOPS: {res.tflops:.2f}")
    print(f" NonZero: {np.count_nonzero(res.output)}/{res.output.size}")
    ref = cpu_conv2d_bwd_data(dy, w, prob)
    diff = np.abs(res.output.astype(np.float32) - ref)
    match_2d = np.allclose(res.output.astype(np.float32), ref, atol=0.1)
    print(f" CPU ref: max_abs={diff.max():.6f}, match={match_2d}")

    # =========================================================================
    # Step 4: BwdData 3D -- GPU + non-zero check
    # =========================================================================
    ok_3d = True
    if ("bwd_data", 3) in runners:
        print("\n--- Step 4: Backward Data 3D ---")
        prob3 = GroupedConvProblem(
            N=1,
            C=32,
            K=32,
            Di=6,
            Hi=6,
            Wi=6,
            Z=3,
            Y=3,
            X=3,
            pad_d=1,
            pad_h=1,
            pad_w=1,
            direction="bwd_data",
        )
        dy3 = np.random.uniform(-0.5, 0.5, prob3.output_shape()).astype(np_dtype)
        w3 = np.random.uniform(-0.5, 0.5, prob3.weight_shape()).astype(np_dtype)
        res3 = runners[("bwd_data", 3)].run(dy3, w3, prob3)
        nz = np.count_nonzero(res3.output)
        ok_3d = res3.success and nz > 0
        print(f" Time: {res3.time_ms:.4f} ms, NonZero: {nz}/{res3.output.size}")

    for r in runners.values():
        r.cleanup()

    passed = res.success and match_2d and ok_3d
    print("\n" + "=" * 70)
    print(f" BwdData 2D: {'PASS' if match_2d else 'FAIL'} (CPU validated)")
    print(f" BwdData 3D: {'PASS' if ok_3d else 'FAIL'}")
    print(f" Status: {'PASS' if passed else 'FAIL'}")
    print("=" * 70)
    return 0 if passed else 1


if __name__ == "__main__":
    sys.exit(main())