mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327) ## Motivation The ML heuristic in dispatcher does not support grouped-conv operator yet. In this PR, the support for fwd, bwd-data, and bwd-weight grouped-conv kernels has been added. A tile_engine utility has also been added to compile and run any selected kernel configuration through dispatcher infrastructure. ## Technical Details 1. Tile engine utility is added to benchmark each shape with all the possible kernel+tile_size combinations here - [https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py](url) 2. New LGBM regressor models for grouped conv are added to models directory. We have 3 separate models for fwd, bwd-data, and bwd-weights [https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models](url) 3. Implemented lazy GPU initialization (dispatcher/python) - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults - **Solution**: Mirror FMHA pattern - defer GPU initialization until first run() - **Changes**: - setup_multiple_grouped_conv_dispatchers() returns List[Path], not loaded libs - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL - Added _ensure_initialized() method for lazy GPU loading - GPU context created only on first run() call - **Benefit**: Parallel compilation now works without GPU conflicts 4. 
Addressed a few miscellaneous issues such as: - Fixed BF16->FP16 naming bug in the dispatcher wrapper - Added new tile sizes, and comp_v5 pipeline to the arch spec to expand the kernel selection - Added automatic padding support for unsupported shapes in dispatcher runner - Created a single source of truth between tile_engine and dispatcher about the architecture and tile_size details - Built a validation script to compare oracle_best vs ml_heuristic results ## Test Plan 1. Validated fwd, bwd-data, and bwd-weight kernels with both known and unseen data sets with up to 300 problems. 2. Ensured that test cases are added in both dispatcher and tile_engine to validate the heuristic. ## Test Result Results on Unseen shapes validated on gfx950 #### Forward Pass Model - **Training Data**: 48,845 measurements across 1,372 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **93.05%** - Median Efficiency: **96.8%** - P10 Efficiency: **79.9%** #### Backward Data Gradient (bwd_data) Model - **Training Data**: 18,773 measurements across 891 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **93.8%** - Median Efficiency: **96.5%** - P10 Efficiency: **82.9%** #### Backward Weight Gradient (bwd_weight) Model - **Training Data**: 34,900 measurements across 1,508 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **96.1%** - Median Efficiency: **99.2%** - P10 Efficiency: **89.4%** ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
402 lines
13 KiB
Python
402 lines
13 KiB
Python
#!/usr/bin/env python3
|
|
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""
|
|
Example 11: Test All Pipeline + Scheduler Combinations
|
|
|
|
Tests all 8 pipelines with both intrawave and interwave schedulers
|
|
for all convolution variants to determine which combinations work.
|
|
|
|
Usage:
|
|
python3 11_test_schedulers.py
|
|
python3 11_test_schedulers.py --arch gfx942
|
|
python3 11_test_schedulers.py --variant forward
|
|
"""
|
|
|
|
import sys
|
|
import argparse
|
|
import time
|
|
import numpy as np
|
|
from pathlib import Path
|
|
import json
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
|
|
|
from grouped_conv_utils import (
|
|
GroupedConvKernelConfig,
|
|
GroupedConvProblem,
|
|
GroupedConvRegistry,
|
|
detect_gpu_arch,
|
|
)
|
|
|
|
# All pipelines from unified_grouped_conv_codegen.py.
# main() crosses every entry here with every scheduler in ALL_SCHEDULERS.
ALL_PIPELINES = [
    "basic_v1",
    "mem",
    "compv3",
    "compv4",
    "compv5",
    "compv6",
    "comp_async",
    "basic_async_v1",
]

# Both schedulers
ALL_SCHEDULERS = ["intrawave", "interwave"]

# Pipelines that require DoubleSmemBuffer=true (enforced by static_assert in
# the pipeline headers). Building these with dsb=false is a loud compile error.
# test_pipeline_scheduler() sets double_smem_buffer from membership in this set.
PIPELINES_REQUIRING_DSB = {"compv4", "comp_async"}
|
|
|
|
|
|
def _make_problem(variant, ndim):
    """Return the small, fixed GroupedConvProblem used as the smoke-test shape.

    2D: N=1, C=K=64, 8x8 input, 3x3 filter, pad 1.
    3D: same, plus Di=4, Z=3, pad_d=1.
    """
    if ndim == 2:
        return GroupedConvProblem(
            N=1,
            C=64,
            K=64,
            Hi=8,
            Wi=8,
            Y=3,
            X=3,
            pad_h=1,
            pad_w=1,
            direction=variant,
        )
    # 3D
    return GroupedConvProblem(
        N=1,
        C=64,
        K=64,
        Di=4,
        Hi=8,
        Wi=8,
        Z=3,
        Y=3,
        X=3,
        pad_d=1,
        pad_h=1,
        pad_w=1,
        direction=variant,
    )


def _make_run_args(variant, prob, np_dtype):
    """Return the (first, second) tensor pair that runner.run() expects.

    forward:    (X, W)
    bwd_data:   (dY, W)   -- runner contract: input_np=dY, weight_np=W
    bwd_weight: (X, dY)

    Raises:
        ValueError: for an unrecognized variant (previously this surfaced
            later as an unbound-variable NameError).
    """

    def _rand(shape):
        return np.random.uniform(-0.5, 0.5, shape).astype(np_dtype)

    if variant == "forward":
        return _rand(prob.input_shape()), _rand(prob.weight_shape())
    if variant == "bwd_data":
        return _rand(prob.output_shape()), _rand(prob.weight_shape())
    if variant == "bwd_weight":
        return _rand(prob.input_shape()), _rand(prob.output_shape())
    raise ValueError(f"Unknown variant: {variant}")


def test_pipeline_scheduler(pipeline, scheduler, variant, arch, dtype, ndim=2):
    """
    Test if a pipeline+scheduler+variant combination builds and runs successfully.

    Args:
        pipeline: Pipeline name (e.g., "compv3", "mem")
        scheduler: Scheduler type ("intrawave" or "interwave")
        variant: Convolution variant (forward, bwd_data, bwd_weight)
        arch: GPU architecture (e.g., "gfx950")
        dtype: Data type (fp16, bf16)
        ndim: Spatial dimensions (2 or 3)

    Returns:
        dict with keys: pipeline, scheduler, variant, ndim, build_success, run_success, error_msg
    """
    result = {
        "pipeline": pipeline,
        "scheduler": scheduler,
        "variant": variant,
        "ndim": ndim,
        "arch": arch,
        "dtype": dtype,
        "build_success": False,
        "run_success": False,
        "error_msg": None,
        "time_ms": None,
        "tflops": None,
    }

    # --- Setup: registry with a single kernel config ----------------------
    try:
        reg = GroupedConvRegistry(f"{variant}_{pipeline}_{scheduler}_{ndim}d")

        # Use a simple, safe tile config: 16x64x64
        # wave 1x4x1, warp 16x16x16
        config = GroupedConvKernelConfig(
            variant=variant,
            ndim_spatial=ndim,
            arch=arch,
            dtype=dtype,
            tile_m=16,
            tile_n=64,
            tile_k=64,
            wave_m=1,
            wave_n=4,
            wave_k=1,
            warp_tile_m=16,
            warp_tile_n=16,
            warp_tile_k=16,
            pipeline=pipeline,
            scheduler=scheduler,  # Test scheduler here
            epilogue="cshuffle" if pipeline not in ["mem"] else "default",
            vector_size_a=4,
            vector_size_b=8,
            vector_size_c=8,
            block_per_cu=1,
            # compv4/comp_async require DoubleSmemBuffer=true (loud
            # static_assert otherwise); other pipelines do not.
            double_smem_buffer=(pipeline in PIPELINES_REQUIRING_DSB),
        )
        reg.add(config)
    except Exception as e:
        result["error_msg"] = f"Setup exception: {str(e)}"
        return result

    # --- Build -------------------------------------------------------------
    try:
        runners = reg.build(verbose=False, max_workers=1)
    except Exception as e:
        result["error_msg"] = f"Build exception: {str(e)}"
        return result

    key = (variant, ndim)
    if key not in runners:
        result["error_msg"] = "Kernel not in runners (build failed)"
        return result
    result["build_success"] = True

    # --- Run ---------------------------------------------------------------
    try:
        # numpy has no native bf16 dtype, so bf16 problems use fp16 host
        # buffers — presumably converted by the runner (TODO confirm).
        np_dtype = np.float16 if dtype in ["fp16", "bf16"] else np.float32
        prob = _make_problem(variant, ndim)
        first, second = _make_run_args(variant, prob, np_dtype)
        res = runners[key].run(first, second, prob)

        if res.success and np.count_nonzero(res.output) > 0:
            result["run_success"] = True
            result["time_ms"] = res.time_ms
            result["tflops"] = res.tflops
        else:
            result["error_msg"] = "Kernel ran but produced zero output"
    except Exception as e:
        # Fix: run-time failures were previously reported as
        # "Build exception" because one try block covered both phases.
        result["error_msg"] = f"Run exception: {str(e)}"
    finally:
        # Fix: cleanup used to be skipped when run() raised, leaking the
        # loaded kernel runner.
        runners[key].cleanup()

    return result
|
|
|
|
|
|
def main():
    """CLI entry point: test every pipeline x scheduler combination.

    For each selected variant, builds and runs one small kernel for every
    pipeline/scheduler pair, prints a per-combination table plus per-scheduler
    and overall summaries, and writes the raw results as JSON next to this
    script.

    Returns:
        0 if every tested variant has at least one pipeline that runs
        successfully under at least one scheduler, 1 otherwise.
    """
    parser = argparse.ArgumentParser(
        description="Test All Pipeline + Scheduler Combinations"
    )
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--dtype", default="bf16", choices=["fp16", "bf16"])
    parser.add_argument(
        "--variant",
        default="all",
        choices=["all", "forward", "bwd_data", "bwd_weight"],
        help="Variant to test (default: all)",
    )
    parser.add_argument(
        "--ndim",
        type=int,
        default=2,
        choices=[2, 3],
        help="Spatial dimensions to test (default: 2)",
    )
    parser.add_argument(
        "--scheduler",
        default="all",
        choices=["all", "intrawave", "interwave"],
        help="Scheduler to test (default: all)",
    )
    parser.add_argument(
        "--output",
        default="scheduler_test_results.json",
        help="Output JSON file (default: scheduler_test_results.json)",
    )
    args = parser.parse_args()

    arch = args.arch
    print("=" * 80)
    print("Test All Pipeline + Scheduler Combinations")
    print("=" * 80)
    print(f"Arch: {arch}, Dtype: {args.dtype}, NDim: {args.ndim}D")
    print()

    # Determine variants to test
    if args.variant == "all":
        variants = ["forward", "bwd_data", "bwd_weight"]
    else:
        variants = [args.variant]

    # Determine schedulers to test
    if args.scheduler == "all":
        schedulers = ALL_SCHEDULERS
    else:
        schedulers = [args.scheduler]

    # Run tests
    all_results = []

    for variant in variants:
        print(f"\n{'=' * 80}")
        print(f"Testing {variant.upper()} ({args.ndim}D)")
        print(f"{'=' * 80}")
        print()

        print(
            f"{'Pipeline':<20} {'Scheduler':<12} {'Build':<8} {'Run':<8} {'Time (ms)':<12} {'TFLOPS':<10}"
        )
        print("-" * 80)

        for pipeline in ALL_PIPELINES:
            for scheduler in schedulers:
                result = test_pipeline_scheduler(
                    pipeline, scheduler, variant, arch, args.dtype, args.ndim
                )
                all_results.append(result)

                build_status = "✓" if result["build_success"] else "✗"
                run_status = "✓" if result["run_success"] else "✗"
                time_str = (
                    f"{result['time_ms']:.4f}"
                    if result["time_ms"] is not None
                    else "-"
                )
                tflops_str = (
                    f"{result['tflops']:.2f}" if result["tflops"] is not None else "-"
                )

                print(
                    f"{pipeline:<20} {scheduler:<12} {build_status:<8} {run_status:<8} {time_str:<12} {tflops_str:<10}"
                )

                if result["error_msg"] and not result["run_success"]:
                    print(f"  → {result['error_msg']}")

        print()

    # Summarize results by scheduler
    print("=" * 80)
    print("SUMMARY BY SCHEDULER")
    print("=" * 80)
    print()

    # Fix: denominators below were hard-coded to 8; use the actual pipeline
    # count so the summary stays correct if ALL_PIPELINES changes.
    n_pipelines = len(ALL_PIPELINES)

    for scheduler in schedulers:
        print(f"\n{scheduler.upper()} Scheduler:")
        print("-" * 80)

        for variant in variants:
            variant_results = [
                r
                for r in all_results
                if r["variant"] == variant and r["scheduler"] == scheduler
            ]
            successful_build = [
                r["pipeline"] for r in variant_results if r["build_success"]
            ]
            successful_run = [r["pipeline"] for r in variant_results if r["run_success"]]

            print(f"\n{variant} ({args.ndim}D):")
            print(
                f"  Build success ({len(successful_build)}/{n_pipelines}): "
                f"{successful_build}"
            )
            print(
                f"  Run success ({len(successful_run)}/{n_pipelines}): "
                f"{successful_run}"
            )

    # Overall summary
    print("\n" + "=" * 80)
    print("OVERALL SUMMARY")
    print("=" * 80)
    print()

    # Per-pipeline support: a pipeline is "supported" if at least one
    # scheduler runs successfully. Not every pipeline supports both
    # intrawave and interwave (loud static_assert / unsupported trait
    # in some pipeline headers), so we only require one to work.
    per_variant_supported: dict[str, list[str]] = {}
    for variant in variants:
        print(f"{variant.upper()}:")

        # Group by pipeline; mark as supported if any scheduler succeeded
        supported_pipelines = []
        per_pipeline_status = []
        for pipeline in ALL_PIPELINES:
            schedulers_ok = [
                r["scheduler"]
                for r in all_results
                if r["variant"] == variant
                and r["pipeline"] == pipeline
                and r["run_success"]
            ]
            if schedulers_ok:
                supported_pipelines.append(pipeline)
                per_pipeline_status.append((pipeline, "✓", schedulers_ok))
            else:
                per_pipeline_status.append((pipeline, "✗", []))

        # Per-pipeline detail (any-scheduler-counts)
        for pipeline, status, sched_list in per_pipeline_status:
            sched_str = ",".join(sched_list) if sched_list else "none"
            print(f"  {pipeline:<18}: {status} via [{sched_str}]")

        # Per-scheduler raw breakdown (for completeness)
        for scheduler in schedulers:
            variant_results = [
                r
                for r in all_results
                if r["variant"] == variant and r["scheduler"] == scheduler
            ]
            success_count = len([r for r in variant_results if r["run_success"]])
            total = len(variant_results)
            pct = (success_count / total * 100) if total > 0 else 0
            print(
                f"  raw {scheduler:<10}: {success_count}/{total} ({pct:.0f}%) pipelines work"
            )

        # Any-scheduler aggregate
        n_sup = len(supported_pipelines)
        n_total = len(ALL_PIPELINES)
        agg_pct = (n_sup / n_total * 100) if n_total > 0 else 0
        agg_status = "✓" if n_sup > 0 else "✗"
        print(
            f"  ANY scheduler : {agg_status} {n_sup}/{n_total} ({agg_pct:.0f}%) pipelines supported"
        )
        per_variant_supported[variant] = supported_pipelines
        print()

    # Save results
    output_file = Path(__file__).parent / args.output
    with open(output_file, "w") as f:
        json.dump(all_results, f, indent=2)

    print(f"Detailed results saved to: {output_file}")
    print()

    # Success criterion (relaxed): for each variant, at least one pipeline
    # must be supported by at least one scheduler. Pipelines that fail under
    # *both* schedulers are reported but don't fail the run, since some
    # pipelines genuinely don't support both schedulers.
    success = all(per_variant_supported.get(v) for v in variants)
    return 0 if success else 1
|
|
|
|
|
|
# Script entry point: main() returns the process exit code
# (0 = every variant had at least one working pipeline, 1 otherwise).
if __name__ == "__main__":
    sys.exit(main())
|