Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-12 01:10:17 +00:00)
[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327)

## Motivation

The ML heuristic in the dispatcher does not yet support the grouped-conv operator. This PR adds support for fwd, bwd-data, and bwd-weight grouped-conv kernels. A tile_engine utility has also been added to compile and run any selected kernel configuration through the dispatcher infrastructure.

## Technical Details

1. A tile_engine utility is added to benchmark each shape with all possible kernel + tile_size combinations: https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py
2. New LGBM regressor models for grouped conv are added to the models directory. There are three separate models for fwd, bwd-data, and bwd-weight: https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models
3. Implemented lazy GPU initialization (dispatcher/python):
   - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults
   - **Solution**: Mirror the FMHA pattern and defer GPU initialization until the first run()
   - **Changes**:
     - setup_multiple_grouped_conv_dispatchers() returns List[Path] instead of loaded libraries
     - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL
     - Added an _ensure_initialized() method for lazy GPU loading
     - GPU context is created only on the first run() call
   - **Benefit**: Parallel compilation now works without GPU conflicts (a sketch of the pattern appears at the end of this description)
4. Addressed a few miscellaneous issues:
   - Fixed a BF16->FP16 naming bug in the dispatcher wrapper
   - Added new tile sizes and the comp_v5 pipeline to the arch spec to expand kernel selection
   - Added automatic padding support for unsupported shapes in the dispatcher runner
   - Created a single source of truth between tile_engine and the dispatcher for architecture and tile_size details
   - Built a validation script to compare oracle_best vs. ml_heuristic results

## Test Plan

1. Validated fwd, bwd-data, and bwd-weight kernels with both known and unseen data sets of up to 300 problems.
2. Ensured that test cases are added in both the dispatcher and tile_engine to validate the heuristic.

## Test Results

Results on unseen shapes, validated on gfx950.

#### Forward Pass Model

- **Training Data**: 48,845 measurements across 1,372 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.05%**
  - Median Efficiency: **96.8%**
  - P10 Efficiency: **79.9%**

#### Backward Data Gradient (bwd_data) Model

- **Training Data**: 18,773 measurements across 891 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.8%**
  - Median Efficiency: **96.5%**
  - P10 Efficiency: **82.9%**

#### Backward Weight Gradient (bwd_weight) Model

- **Training Data**: 34,900 measurements across 1,508 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **96.1%**
  - Median Efficiency: **99.2%**
  - P10 Efficiency: **89.4%**

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests
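To make item 3 of the Technical Details concrete, below is a minimal sketch of the lazy-initialization pattern, assuming a runner whose compiled library exposes a `run()` entry point. Names and signatures are illustrative only; the real `GpuGroupedConvRunner` in dispatcher/python differs in its ctypes details.

```python
import ctypes
from pathlib import Path


class GpuGroupedConvRunner:
    """Sketch of the lazy GPU initialization pattern described above."""

    def __init__(self, lib_path: Path):
        # Store only the library path; do NOT dlopen or touch the GPU here,
        # so ProcessPoolExecutor can fork() worker processes safely.
        self._lib_path = lib_path
        self._lib = None

    def _ensure_initialized(self):
        # Loading the library (and thereby creating the GPU context) is
        # deferred to the first run() call, i.e. after any fork() happened.
        if self._lib is None:
            self._lib = ctypes.CDLL(str(self._lib_path))

    def run(self, *args):
        self._ensure_initialized()
        return self._lib.run(*args)  # hypothetical entry point, for illustration
```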
495 lines
17 KiB
Python
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
Example 09: ML-Based Kernel Selection for Grouped Convolution

Uses a trained LightGBM model to select the optimal kernel for each convolution
problem. The model predicts TFLOPS for every candidate in the kernel pool and
picks the highest-scoring one, which is then invoked via the dispatcher.

This replaces hand-crafted heuristics with a data-driven approach that reaches
roughly 93-96% mean (up to 99% median) of oracle-best TFLOPS on validation sets.

Supports forward, bwd_data, and bwd_weight variants.

Complexity: *****

Prerequisites:
    - Trained models in dispatcher/heuristics/models/grouped_conv_*_bf16_gfx950/
    - lightgbm, pandas, numpy, pyarrow installed
    - grouped_conv dispatcher built

Usage:
    python3 09_ml_heuristic.py --variant forward
    python3 09_ml_heuristic.py --variant bwd_data
    python3 09_ml_heuristic.py --variant bwd_weight
    python3 09_ml_heuristic.py --variant forward --dtype bf16 --arch gfx950
"""

import sys
import os
import argparse
import json
import subprocess
from pathlib import Path
from dataclasses import dataclass
from typing import List, Optional

# Make the dispatcher's python/ and heuristics/ packages importable.
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "heuristics"))

from predict import Predictor
from feature_engine_grouped_conv import GroupedConvFeatureEngine
from grouped_conv_utils import (
    GroupedConvKernelConfig,
    setup_multiple_grouped_conv_dispatchers,
)


@dataclass
class KernelSpec:
    """Grouped convolution kernel specification"""

    name: str
    block_size: int
    gemm_m_per_block: int
    gemm_n_per_block: int
    pipeline: str = "compv3"

    def to_kernel_config(
        self, dtype: str = "bf16", arch: str = "gfx950", variant: str = "forward"
    ) -> GroupedConvKernelConfig:
        """Convert to GroupedConvKernelConfig for building."""
        return GroupedConvKernelConfig(
            variant=variant,
            dtype=dtype,
            ndim_spatial=2,
            layout="NHWGC_KYXGC_NHWGK",
            arch=arch,
            tile_m=self.block_size,
            tile_n=self.gemm_m_per_block,
            tile_k=self.gemm_n_per_block,
            wave_m=2,
            wave_n=2,
            wave_k=1,
            warp_tile_m=32,
            warp_tile_n=32,
            warp_tile_k=8,
            pipeline=self.pipeline,
            scheduler="default",
            epilogue="default",
            pad_m=True,
            pad_n=True,
            pad_k=True,
        )


# Kernel pools for different variants

# Forward pool: compv3, compv4, compv5 (30 kernels)
FORWARD_KERNEL_POOL = [
    # Block size 16
    KernelSpec("k16_64x64_v3", 16, 64, 64, "compv3"),
    KernelSpec("k16_64x64_v4", 16, 64, 64, "compv4"),
    KernelSpec("k16_64x64_v5", 16, 64, 64, "compv5"),
    KernelSpec("k16_64x128_v3", 16, 64, 128, "compv3"),
    KernelSpec("k16_64x128_v4", 16, 64, 128, "compv4"),
    KernelSpec("k16_64x128_v5", 16, 64, 128, "compv5"),
    # Block size 32
    KernelSpec("k32_64x64_v3", 32, 64, 64, "compv3"),
    KernelSpec("k32_64x64_v4", 32, 64, 64, "compv4"),
    KernelSpec("k32_64x64_v5", 32, 64, 64, "compv5"),
    KernelSpec("k32_64x128_v3", 32, 64, 128, "compv3"),
    KernelSpec("k32_64x128_v4", 32, 64, 128, "compv4"),
    KernelSpec("k32_64x128_v5", 32, 64, 128, "compv5"),
    KernelSpec("k32_128x64_v3", 32, 128, 64, "compv3"),
    KernelSpec("k32_128x64_v4", 32, 128, 64, "compv4"),
    KernelSpec("k32_128x64_v5", 32, 128, 64, "compv5"),
    # Block size 64
    KernelSpec("k64_64x64_v3", 64, 64, 64, "compv3"),
    KernelSpec("k64_64x64_v4", 64, 64, 64, "compv4"),
    KernelSpec("k64_64x64_v5", 64, 64, 64, "compv5"),
    KernelSpec("k64_64x128_v3", 64, 64, 128, "compv3"),
    KernelSpec("k64_64x128_v4", 64, 64, 128, "compv4"),
    KernelSpec("k64_64x128_v5", 64, 64, 128, "compv5"),
    KernelSpec("k64_128x64_v3", 64, 128, 64, "compv3"),
    KernelSpec("k64_128x64_v4", 64, 128, 64, "compv4"),
    KernelSpec("k64_128x64_v5", 64, 128, 64, "compv5"),
    # Block size 128
    KernelSpec("k128_64x128_v3", 128, 64, 128, "compv3"),
    KernelSpec("k128_64x128_v4", 128, 64, 128, "compv4"),
    KernelSpec("k128_64x128_v5", 128, 64, 128, "compv5"),
    KernelSpec("k128_128x64_v3", 128, 128, 64, "compv3"),
    KernelSpec("k128_128x64_v4", 128, 128, 64, "compv4"),
    KernelSpec("k128_128x64_v5", 128, 128, 64, "compv5"),
]

# Backward pool: compv3, mem (20 kernels)
BACKWARD_KERNEL_POOL = [
    # Block size 16
    KernelSpec("k16_64x64_v3", 16, 64, 64, "compv3"),
    KernelSpec("k16_64x64_mem", 16, 64, 64, "mem"),
    KernelSpec("k16_64x128_v3", 16, 64, 128, "compv3"),
    KernelSpec("k16_64x128_mem", 16, 64, 128, "mem"),
    # Block size 32
    KernelSpec("k32_64x64_v3", 32, 64, 64, "compv3"),
    KernelSpec("k32_64x64_mem", 32, 64, 64, "mem"),
    KernelSpec("k32_64x128_v3", 32, 64, 128, "compv3"),
    KernelSpec("k32_64x128_mem", 32, 64, 128, "mem"),
    KernelSpec("k32_128x64_v3", 32, 128, 64, "compv3"),
    KernelSpec("k32_128x64_mem", 32, 128, 64, "mem"),
    # Block size 64
    KernelSpec("k64_64x64_v3", 64, 64, 64, "compv3"),
    KernelSpec("k64_64x64_mem", 64, 64, 64, "mem"),
    KernelSpec("k64_64x128_v3", 64, 64, 128, "compv3"),
    KernelSpec("k64_64x128_mem", 64, 64, 128, "mem"),
    KernelSpec("k64_128x64_v3", 64, 128, 64, "compv3"),
    KernelSpec("k64_128x64_mem", 64, 128, 64, "mem"),
    # Block size 128
    KernelSpec("k128_64x128_v3", 128, 64, 128, "compv3"),
    KernelSpec("k128_64x128_mem", 128, 64, 128, "mem"),
    KernelSpec("k128_128x64_v3", 128, 128, 64, "compv3"),
    KernelSpec("k128_128x64_mem", 128, 128, 64, "mem"),
]

# Legacy name for backward compatibility
KERNEL_POOL = FORWARD_KERNEL_POOL


def spec_to_feature_dict(spec: KernelSpec, dtype: str) -> dict:
    """Convert a KernelSpec to the dict format the feature engine expects."""
    return {
        "kernel_name": spec.name,
        "block_size": spec.block_size,
        "gemm_m_per_block": spec.gemm_m_per_block,
        "gemm_n_per_block": spec.gemm_n_per_block,
        "pipeline": spec.pipeline,
        "dtype": dtype,
    }


def build_kernel(
    spec: KernelSpec, dtype: str, arch: str, variant: str = "forward", verbose: bool = False
) -> Optional[Path]:
    """Build a kernel on-demand using the dispatcher's JIT compilation.

    Uses the same workflow as the tile_engine benchmark:
      1. Convert KernelSpec to GroupedConvKernelConfig
      2. Call setup_multiple_grouped_conv_dispatchers to build
      3. Return the path to the .so file

    Returns:
        Path to the compiled .so file, or None if the build failed.
    """
    kernel_config = spec.to_kernel_config(dtype=dtype, arch=arch, variant=variant)

    if verbose:
        print(f"  Building kernel: {spec.name}")
        print(f"    Config: variant={variant}, tile={kernel_config.tile_str}, pipeline={kernel_config.pipeline}")

    # Build the kernel (returns a list of library paths, one per config)
    lib_paths = setup_multiple_grouped_conv_dispatchers(
        [kernel_config], verbose=verbose, max_workers=1
    )

    if not lib_paths or lib_paths[0] is None:
        return None

    return lib_paths[0]


def run_kernel_via_subprocess(so_path: Path, problem: dict, kernel_name: str) -> dict:
    """Run a kernel via the isolated subprocess runner.

    This uses the same pattern as the tile_engine benchmark to avoid GPU
    context issues.
    """
    script_path = (
        Path(__file__).parent.parent.parent.parent.parent
        / "tile_engine" / "ops" / "grouped_conv" / "run_one_grouped_conv_kernel.py"
    )

    # Prepare the input JSON handed to the runner on stdin
    input_data = {
        "so_path": str(so_path),
        "problem": problem,
        "kernel_name": kernel_name,
    }

    # Make the dispatcher's python/ package visible to the runner
    env = {
        "GCONV_PYPATH": str(Path(__file__).parent.parent.parent.parent / "python")
    }

    # Run the subprocess
    proc = subprocess.Popen(
        [sys.executable, str(script_path)],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env={**os.environ, **env},
    )

    stdout, stderr = proc.communicate(input=json.dumps(input_data).encode())

    # Parse the JSON result printed by the runner
    try:
        return json.loads(stdout.decode().strip())
    except (json.JSONDecodeError, UnicodeDecodeError):
        return {
            "ok": False,
            "error": f"Failed to parse output: {stdout.decode(errors='replace')}; "
                     f"stderr: {stderr.decode(errors='replace')}",
        }


def ml_select_and_run(
    predictor: Predictor,
    pool: List[KernelSpec],
    N: int, C: int, K: int, G: int,
    Hi: int, Wi: int, Y: int, X: int,
    stride_h: int, stride_w: int,
    pad_h: int = 0, pad_w: int = 0,
    dtype: str = "bf16",
    arch: str = "gfx950",
    variant: str = "forward",
    run_on_hw: bool = True,
) -> dict:
    """
    Step 1: Call predictor to get best kernel
    Step 2: Invoke dispatcher using tile_engine pattern

    Returns dict with prediction and (optional) hardware results.
    """
    # Step 1: Predict best kernel
    problem = {
        "N": N, "C": C, "K": K, "G": G,
        "Hi": Hi, "Wi": Wi, "Y": Y, "X": X,
        "stride_h": stride_h, "stride_w": stride_w,
        "pad_h": pad_h, "pad_w": pad_w,
        "dtype": dtype,
    }

    kernel_dicts = [spec_to_feature_dict(s, dtype) for s in pool]
    ranked = predictor.rank_kernels(problem, kernel_dicts)

    if not ranked:
        return {"success": False, "error": "No valid kernel predictions"}

    best_name, pred_tflops = ranked[0]
    best_spec = next((s for s in pool if s.name == best_name), pool[0])

    result = {
        "success": True,
        "kernel_name": best_spec.name,
        "kernel_spec": best_spec,
        "predicted_tflops": pred_tflops,
    }

    if not run_on_hw:
        return result

    # Step 2: Build on-demand (JIT) and run on hardware via the dispatcher
    so_path = build_kernel(best_spec, dtype, arch, variant=variant, verbose=False)

    if not so_path:
        result["hw_success"] = False
        result["hw_error"] = f"Failed to build kernel: {best_spec.name}"
        return result

    # Prepare the problem dict for the dispatcher
    problem_with_direction = {**problem, "direction": variant}

    # Derive the kernel name from the .so path
    # (e.g. libgrouped_conv_forward_bf16_2d_16x64x128_compv3.so -> grouped_conv_...)
    kernel_name = so_path.stem[3:] if so_path.stem.startswith("lib") else so_path.stem

    # Run in an isolated subprocess
    hw_result = run_kernel_via_subprocess(so_path, problem_with_direction, kernel_name)

    if hw_result.get("ok"):
        result["hw_success"] = True
        result["hw_time_ms"] = hw_result["ms"]
        result["hw_tflops"] = hw_result["tflops"]
    else:
        result["hw_success"] = False
        result["hw_error"] = hw_result.get("error", "Unknown error")

    return result


def main():
    parser = argparse.ArgumentParser(
        description="ML-based kernel selection for grouped convolution"
    )
    parser.add_argument("--dtype", default="bf16", choices=["fp16", "bf16"])
    parser.add_argument("--arch", default="gfx950")
    parser.add_argument(
        "--variant",
        default="forward",
        choices=["forward", "bwd_data", "bwd_weight"],
        help="Convolution variant (default: forward)",
    )
    parser.add_argument(
        "--model_dir",
        default=None,
        help="Model directory (default: auto-detect from variant)",
    )
    parser.add_argument(
        "--no_run", action="store_true", help="Only predict, don't run on hardware"
    )
    args = parser.parse_args()

    # Auto-detect the model directory from the variant if not specified.
    # The bundled models are bf16, so the directory name hardcodes bf16.
    if args.model_dir is None:
        model_name = f"grouped_conv_{args.variant}_bf16_{args.arch}"
        args.model_dir = str(
            Path(__file__).parent.parent.parent.parent
            / "heuristics"
            / "models"
            / model_name
        )

    # Select the kernel pool based on the variant
    if args.variant == "forward":
        kernel_pool = FORWARD_KERNEL_POOL
    else:
        kernel_pool = BACKWARD_KERNEL_POOL

    print("=" * 80)
    print(f" Example 09: ML-Based Kernel Selection for Grouped Convolution ({args.variant.upper()})")
    print("=" * 80)
    print(f"\n Variant: {args.variant}")
    print(f" Model: {args.model_dir}")
    print(f" Dtype: {args.dtype}")
    print(f" Arch: {args.arch}")
    print(f" Pool: {len(kernel_pool)} kernels")

    # Load the ML model with the grouped-conv feature engine
    feature_engine = GroupedConvFeatureEngine()
    predictor = Predictor(args.model_dir, feature_engine=feature_engine)
    print(" Model loaded successfully")

    # Test problems: diverse convolution shapes from MIOpen
    # (N, C, K, G, Hi, Wi, Y, X, stride_h, stride_w, pad_h, pad_w)
    if args.variant == "forward":
        test_problems = [
            # ResNet-50 layers
            (1, 256, 512, 1, 56, 56, 1, 1, 2, 2, 0, 0),   # stride-2 1x1 conv
            (1, 128, 256, 1, 32, 32, 2, 2, 2, 2, 0, 0),   # stride-2 2x2 conv
            (1, 512, 256, 1, 28, 28, 1, 1, 1, 1, 0, 0),   # 1x1 bottleneck
            # 3x3 convolutions
            (1, 128, 256, 1, 64, 64, 3, 3, 1, 1, 1, 1),   # standard 3x3
            (1, 64, 128, 1, 128, 128, 3, 3, 1, 1, 1, 1),  # larger spatial
            # Small spatial
            (1, 832, 128, 1, 7, 7, 1, 1, 1, 1, 0, 0),     # 7x7 input
            # Large channels
            (1, 1024, 512, 1, 14, 14, 1, 1, 1, 1, 0, 0),  # large C/K
        ]
    elif args.variant == "bwd_data":
        test_problems = [
            # Typical backward-data problems (with padding for 3x3)
            (32, 128, 256, 1, 28, 28, 3, 3, 1, 1, 1, 1),  # 3x3 standard
            (16, 256, 512, 1, 14, 14, 3, 3, 1, 1, 1, 1),  # 3x3 larger channels
            (64, 64, 128, 1, 56, 56, 1, 1, 1, 1, 0, 0),   # 1x1 conv
            (32, 512, 256, 1, 7, 7, 3, 3, 1, 1, 1, 1),    # small spatial
        ]
    else:  # bwd_weight
        test_problems = [
            # Typical backward-weight problems (with padding for 3x3)
            (64, 256, 512, 1, 14, 14, 3, 3, 1, 1, 1, 1),  # 3x3 standard
            (32, 128, 256, 1, 28, 28, 3, 3, 1, 1, 1, 1),  # 3x3 medium
            (128, 64, 128, 1, 56, 56, 1, 1, 1, 1, 0, 0),  # 1x1 conv
            (64, 512, 1024, 1, 7, 7, 3, 3, 1, 1, 1, 1),   # large channels
        ]

    run_on_hw = not args.no_run

    if run_on_hw:
        header = f"{'Problem':<35} {'Selected':<22} {'Pred TFLOPS':>12} {'HW ms':>10} {'HW TFLOPS':>10} {'Status':<8}"
    else:
        header = f"{'Problem':<35} {'Selected':<22} {'Pred TFLOPS':>12}"

    print(f"\n {header}")
    print(" " + "-" * len(header))

    results = []

    for N, C, K, G, Hi, Wi, Y, X, sh, sw, ph, pw in test_problems:
        result = ml_select_and_run(
            predictor, kernel_pool, N, C, K, G, Hi, Wi, Y, X, sh, sw, ph, pw,
            dtype=args.dtype, arch=args.arch, variant=args.variant, run_on_hw=run_on_hw
        )

        # Compute output size
        Ho = (Hi + 2*ph - Y) // sh + 1
        Wo = (Wi + 2*pw - X) // sw + 1
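        # (Standard dilation-1 convolution: out = (in + 2*pad - filter) // stride + 1.)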

        prob_str = f"C{C:4d}→K{K:4d} {Hi:3d}x{Wi:3d}→{Ho:3d}x{Wo:3d} f{Y}x{X}"

        if not result["success"]:
            line = f" {prob_str:<35} {'ERROR':<22} {'N/A':>12}"
            print(line)
            continue

        line = f" {prob_str:<35} {result['kernel_name']:<22} {result['predicted_tflops']:>12.2f}"

        if run_on_hw:
            if result.get("hw_success"):
                hw_time = result["hw_time_ms"]
                hw_tflops = result["hw_tflops"]
                status = "PASS"
                line += f" {hw_time:>10.4f} {hw_tflops:>10.2f} {status:<8}"
                results.append((prob_str, result['kernel_name'], True, hw_time, hw_tflops, result['predicted_tflops']))
            else:
                error = result.get("hw_error", "Unknown")
                line += f" {'N/A':>10} {'N/A':>10} {'FAIL':<8}"
                print(line)
                print(f"   Error: {error}")
                results.append((prob_str, result['kernel_name'], False, 0, 0, result['predicted_tflops']))
                continue
        else:
            results.append((prob_str, result['kernel_name'], True, 0, 0, result['predicted_tflops']))

        print(line)

    # Summary
    print("\n" + "=" * 80)
    print(" SUMMARY")
    print("=" * 80)

    if run_on_hw:
        passed = sum(1 for r in results if r[2])
        print(f"\n Results: {passed}/{len(results)} tests passed")
        valid = [r for r in results if r[2] and r[4] > 0]
        if valid:
            avg_hw = sum(r[4] for r in valid) / len(valid)
            avg_pred = sum(r[5] for r in valid) / len(valid)
            print(f" Average HW TFLOPS: {avg_hw:.2f}")
            print(f" Average Predicted TFLOPS: {avg_pred:.2f}")
            print(f" Prediction Accuracy: {(avg_hw/avg_pred)*100:.1f}%")
        if passed == len(results):
            print("\n *** ALL TESTS PASSED ***")
    else:
        print(f"\n Results: {len(results)} predictions completed")
        avg_pred = sum(r[5] for r in results) / len(results)
        print(f" Average Predicted TFLOPS: {avg_pred:.2f}")
        print("\n Note: Hardware execution disabled (--no_run)")

    print("=" * 80)
    return 0 if (not run_on_hw or sum(1 for r in results if r[2]) == len(results)) else 1


if __name__ == "__main__":
    sys.exit(main())
|