mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 13:29:20 +00:00
[CK] [CK_Tile] Add FMHA scaffolding to CK kernel dispatcher (#5260)

## Motivation

The CK Tile dispatcher currently supports GEMM and Grouped Convolution but has no support for Fused Multi-Head Attention (FMHA). The example/ck_tile/01_fmha folder contains a comprehensive FMHA implementation with forward, backward, split-KV, paged-KV, append-KV, and batch-prefill kernels across multiple GPU architectures — but there is no unified dispatch layer for it. This PR ports the FMHA stack into the dispatcher, following the same architectural patterns established by GEMM and Grouped Convolution, enabling runtime kernel selection, JIT compilation from Python, and a declarative C++ example flow. Autotuning heuristics to follow.

## Technical Details

This PR adds FMHA scaffolding to the CK dispatcher framework, mirroring GEMM's layered architecture. Seven new C++ runtime headers provide type definitions (coexisting with upstream headers via __has_include, requiring zero modifications to example/ck_tile/01_fmha/), a problem builder with 18+ setters, Signature + Algorithm kernel key matching, a virtual kernel instance, a DECL_FMHA_KERNEL_SET macro with wildcard support and named tile/wave/warp setters, an arch-aware registry with JSON export, and a dispatcher with seqtune-aware selection, configurable timing, and multi-stage execution plans for split-KV (two-stage) and backward (three-stage). The codegen pipeline is driven by a fmha_arch_specs.json capturing per-arch tile tables and pipeline constraints for five architectures (gfx90a/942/950/1100/1201), migrated from hardcoded logic in 01_fmha/codegen/, with supporting modules for C++ symbol mappings, validation rules, and named receipt profiles (ck_default, flash, pytorch, aiter, fp32, fp8).
Python integration (fmha_utils.py) mirrors the C++ layer with JIT compilation, parallel multi-kernel builds, HIP memory management via ctypes, tolerance-based validation, and a NumPy CPU reference with GQA support. Twenty-seven C++ and thirty-two Python examples cover the full feature surface — forward, split-KV, masks, bias, dropout, GQA, backward, append-KV, batch prefill, fp8, logits soft cap, sink tokens, and parameter sweeps — all JIT-compiled on the fly.

## Test Plan

Seven test files cover the runtime types, codegen, and end-to-end correctness. C++ unit tests validate the problem builder, dispatcher planning (single-stage for forward/paged-KV/append-KV; multi-stage for split-KV and backward), registry operations, and the kernel-set declaration macro. Python unit tests verify codegen emission, profile filtering, and 15 validation rules for masks, hdim constraints, and pipeline requirements. GPU execution validation in 01_basic_fmha --validate reports zero errors across 65,536 elements with a max absolute error of 7.29e-05. A gold-standard parity suite (test_fmha_parity.py) runs 14 configurations through both the upstream tile_example_fmha_fwd and the dispatcher, comparing exit codes to confirm behavioral parity — all 14 match.

## Test Result

The C++ smoke test builds and passes all 9 compiled examples, and a Python JIT sweep (29_sweep_seqlen.py) passes 7/7 configurations, reaching up to 375 TFLOPS at seqlen 2048.

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
717 lines · 20 KiB · Python
#!/usr/bin/env python3
|
|
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""
|
|
Example 08: Custom Heuristics
|
|
|
|
Demonstrates custom kernel selection heuristics based on problem characteristics.
|
|
|
|
This example shows how to:
|
|
1. Define multiple kernel configurations for different workloads
|
|
2. Implement custom heuristics to select the best kernel
|
|
3. Test heuristic selection across different problem sizes
|
|
|
|
Heuristic strategies:
|
|
- Size-based: Small tiles for small problems, large tiles for large problems
|
|
- Compute-bound: Maximize compute utilization for large matrices
|
|
- Memory-bound: Optimize memory access for bandwidth-limited cases
|
|
- Latency-focused: Minimize kernel launch overhead for small problems
|
|
|
|
|
|
Usage:
|
|
python3 08_heuristics.py
|
|
python3 08_heuristics.py --help
|
|
python3 08_heuristics.py --strategy compute
|
|
python3 08_heuristics.py --dtype bf16
|
|
"""
|
|
|
|
import sys
|
|
import argparse
|
|
from pathlib import Path
|
|
from dataclasses import dataclass
|
|
from typing import List
|
|
from enum import Enum
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
|
import numpy as np
|
|
|
|
from ctypes_utils import (
|
|
KernelConfig,
|
|
setup_gemm_dispatcher,
|
|
cleanup_gemm,
|
|
reset_for_example,
|
|
detect_gpu_arch,
|
|
)
|
|
|
|
|
|
# =============================================================================
|
|
# Kernel Specifications
|
|
# =============================================================================
|
|
|
|
|
|
@dataclass
class KernelSpec:
    """Kernel specification with metadata for heuristic selection.

    The first six fields describe the kernel itself (tile shape, pipeline,
    scheduler). The remaining fields are metadata that the selection
    heuristics read when choosing a kernel for a given problem.
    """

    name: str
    tile_m: int
    tile_n: int
    tile_k: int
    pipeline: str = "compv3"
    scheduler: str = "intrawave"
    # Metadata for heuristics
    category: str = "balanced"  # small, balanced, large, compute, memory
    # Lower bound on M * N used when filtering candidate kernels.
    min_problem_size: int = 0
    # Upper bound on M * N. Annotated as float (which also accepts ints)
    # because the default, "no upper limit", is float("inf").
    max_problem_size: float = float("inf")
|
|
|
|
|
|
# Pool of candidate kernels (24 specs in six groups) consulted by every
# heuristic below. Entry order matters: several heuristics break ties by
# returning the first matching spec.
KERNEL_POOL = [
    # --- Small tiles: low latency, intended for small problems -------------
    KernelSpec(name="small_64x64_k32", tile_m=64, tile_n=64, tile_k=32,
               pipeline="compv3", scheduler="intrawave",
               category="small", max_problem_size=256 * 256),
    KernelSpec(name="small_64x64_k64", tile_m=64, tile_n=64, tile_k=64,
               pipeline="compv3", scheduler="intrawave",
               category="small", max_problem_size=256 * 256),
    KernelSpec(name="small_64x64_v4", tile_m=64, tile_n=64, tile_k=32,
               pipeline="compv4", scheduler="intrawave",
               category="small", max_problem_size=256 * 256),
    # --- Medium square tiles: balanced performance --------------------------
    KernelSpec(name="medium_128x128_k32", tile_m=128, tile_n=128, tile_k=32,
               pipeline="compv3", scheduler="intrawave",
               category="balanced", min_problem_size=128 * 128,
               max_problem_size=2048 * 2048),
    KernelSpec(name="medium_128x128_k64", tile_m=128, tile_n=128, tile_k=64,
               pipeline="compv3", scheduler="intrawave",
               category="balanced", min_problem_size=256 * 256),
    KernelSpec(name="medium_128x128_k128", tile_m=128, tile_n=128, tile_k=128,
               pipeline="compv3", scheduler="intrawave",
               category="balanced", min_problem_size=256 * 256),
    KernelSpec(name="medium_128x128_v4_k32", tile_m=128, tile_n=128, tile_k=32,
               pipeline="compv4", scheduler="intrawave",
               category="balanced", min_problem_size=256 * 256),
    KernelSpec(name="medium_128x128_v4_k64", tile_m=128, tile_n=128, tile_k=64,
               pipeline="compv4", scheduler="intrawave",
               category="balanced", min_problem_size=256 * 256),
    # --- Medium rectangular tiles -------------------------------------------
    KernelSpec(name="rect_64x128_k32", tile_m=64, tile_n=128, tile_k=32,
               pipeline="compv3", scheduler="intrawave",
               category="balanced", min_problem_size=128 * 128),
    KernelSpec(name="rect_128x64_k32", tile_m=128, tile_n=64, tile_k=32,
               pipeline="compv3", scheduler="intrawave",
               category="balanced", min_problem_size=128 * 128),
    KernelSpec(name="rect_64x128_k64", tile_m=64, tile_n=128, tile_k=64,
               pipeline="compv3", scheduler="intrawave",
               category="balanced", min_problem_size=256 * 256),
    KernelSpec(name="rect_128x64_k64", tile_m=128, tile_n=64, tile_k=64,
               pipeline="compv3", scheduler="intrawave",
               category="balanced", min_problem_size=256 * 256),
    # --- Large tiles: high throughput for large problems --------------------
    KernelSpec(name="large_256x128_k32", tile_m=256, tile_n=128, tile_k=32,
               pipeline="compv3", scheduler="intrawave",
               category="large", min_problem_size=512 * 512),
    KernelSpec(name="large_256x128_k64", tile_m=256, tile_n=128, tile_k=64,
               pipeline="compv3", scheduler="intrawave",
               category="large", min_problem_size=512 * 512),
    KernelSpec(name="large_128x256_k32", tile_m=128, tile_n=256, tile_k=32,
               pipeline="compv3", scheduler="intrawave",
               category="large", min_problem_size=512 * 512),
    KernelSpec(name="large_128x256_k64", tile_m=128, tile_n=256, tile_k=64,
               pipeline="compv3", scheduler="intrawave",
               category="large", min_problem_size=512 * 512),
    KernelSpec(name="large_256x256_k32", tile_m=256, tile_n=256, tile_k=32,
               pipeline="compv3", scheduler="intrawave",
               category="large", min_problem_size=1024 * 1024),
    KernelSpec(name="large_256x256_k64", tile_m=256, tile_n=256, tile_k=64,
               pipeline="compv3", scheduler="intrawave",
               category="large", min_problem_size=1024 * 1024),
    # --- Compute-optimized: compv4 pipeline for compute-bound workloads -----
    KernelSpec(name="compute_128x128_v4_k32", tile_m=128, tile_n=128, tile_k=32,
               pipeline="compv4", scheduler="intrawave",
               category="compute", min_problem_size=256 * 256),
    KernelSpec(name="compute_128x128_v4_k64", tile_m=128, tile_n=128, tile_k=64,
               pipeline="compv4", scheduler="intrawave",
               category="compute", min_problem_size=256 * 256),
    KernelSpec(name="compute_256x128_v4", tile_m=256, tile_n=128, tile_k=64,
               pipeline="compv4", scheduler="intrawave",
               category="compute", min_problem_size=512 * 512),
    KernelSpec(name="compute_256x256_v4", tile_m=256, tile_n=256, tile_k=64,
               pipeline="compv4", scheduler="intrawave",
               category="compute", min_problem_size=1024 * 1024),
    # --- Memory-optimized: small tile_k for memory-bound workloads ----------
    KernelSpec(name="memory_128x128_k16", tile_m=128, tile_n=128, tile_k=16,
               pipeline="compv3", scheduler="intrawave",
               category="memory", min_problem_size=256 * 256),
    KernelSpec(name="memory_64x128_k16", tile_m=64, tile_n=128, tile_k=16,
               pipeline="compv3", scheduler="intrawave",
               category="memory", min_problem_size=128 * 128),
]
|
|
|
|
|
|
def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig:
    """Translate a KernelSpec into a dispatcher KernelConfig.

    A, B and C all use ``dtype``, and accumulation is fp32. Layouts are fixed
    to row-major A, column-major B and row-major C. The wave grid is always
    2x2x1. Each warp-tile extent is 16 when the matching tile extent is at
    most 64, and 32 otherwise. warp_k is always 16.
    """

    def _warp_extent(tile_extent: int) -> int:
        return 32 if tile_extent > 64 else 16

    tile_shape = {
        "tile_m": spec.tile_m,
        "tile_n": spec.tile_n,
        "tile_k": spec.tile_k,
    }
    wave_shape = {"wave_m": 2, "wave_n": 2, "wave_k": 1}
    warp_shape = {
        "warp_m": _warp_extent(spec.tile_m),
        "warp_n": _warp_extent(spec.tile_n),
        "warp_k": 16,
    }

    return KernelConfig(
        dtype_a=dtype,
        dtype_b=dtype,
        dtype_c=dtype,
        dtype_acc="fp32",
        layout_a="row",
        layout_b="col",
        layout_c="row",
        **tile_shape,
        **wave_shape,
        **warp_shape,
        pipeline=spec.pipeline,
        scheduler=spec.scheduler,
        epilogue="cshuffle",
        gfx_arch=arch,
    )
|
|
|
|
|
|
# =============================================================================
|
|
# Heuristic Strategies
|
|
# =============================================================================
|
|
|
|
|
|
class HeuristicStrategy(Enum):
    """Available kernel-selection strategies.

    Each member's value is the token accepted by the ``--strategy`` option.
    """

    # Choose tile size from M * N, and tile_k from K.
    SIZE_BASED = "size"
    # Favour compute-category / compv4 kernels.
    COMPUTE_BOUND = "compute"
    # Favour memory-category kernels or a small tile_k.
    MEMORY_BOUND = "memory"
    # Favour the smallest tiles.
    LATENCY_FOCUSED = "latency"
|
|
|
|
|
|
def size_based_heuristic(
    M: int, N: int, K: int, kernels: List[KernelSpec]
) -> KernelSpec:
    """
    Select kernel based on problem size.

    - Small problems: Use small tiles for low latency
    - Medium problems: Use balanced tiles
    - Large problems: Use large tiles for high throughput

    Also considers K dimension for tile_k selection. Kernels whose tile_k
    leaves the smallest remainder of K (ideally zero) win. Ties go to the
    largest tile_m * tile_n, then to the earliest kernel in the pool.

    Args:
        M, N, K: GEMM problem dimensions.
        kernels: Candidate pool. It is never modified.

    Returns:
        The selected KernelSpec.

    Raises:
        ValueError: If ``kernels`` is empty.
    """
    if not kernels:
        raise ValueError("size_based_heuristic: kernel pool is empty")

    total_elements = M * N

    # Filter by problem size constraints
    candidates = [
        k for k in kernels if k.min_problem_size <= total_elements <= k.max_problem_size
    ]

    if not candidates:
        # Fall back to all kernels. Take a copy so the caller's list (e.g. the
        # global KERNEL_POOL) is never aliased.
        candidates = list(kernels)

    # Determine target category based on problem size
    if total_elements < 256 * 256:
        target_category = "small"
    elif total_elements < 1024 * 1024:
        target_category = "balanced"
    else:
        target_category = "large"

    # Filter by category if possible
    category_candidates = [k for k in candidates if k.category == target_category]
    if category_candidates:
        candidates = category_candidates

    # Smallest K remainder first (0 == tile_k divides K), then largest tile
    # area. min() returns the first minimal element, which matches the
    # stable-sort-then-take-first behaviour without reordering any list.
    return min(candidates, key=lambda k: (K % k.tile_k, -k.tile_m * k.tile_n))
|
|
|
|
|
|
def compute_bound_heuristic(
    M: int, N: int, K: int, kernels: List[KernelSpec]
) -> KernelSpec:
    """
    Pick a kernel aimed at compute-bound workloads.

    Candidate tiers, in order of preference:
      1. kernels in the "compute" category,
      2. otherwise any compv4-pipeline kernel,
      3. otherwise the whole pool.
    Within the chosen tier, kernels whose min_problem_size exceeds M * N are
    dropped if that leaves anything. Problems of at least 1024 * 1024
    elements get the largest tile volume. Smaller ones get the tile closest
    to 128x128.
    """
    area = M * N

    pool = (
        [spec for spec in kernels if spec.category == "compute"]
        or [spec for spec in kernels if spec.pipeline == "compv4"]
        or kernels
    )

    fitting = [spec for spec in pool if spec.min_problem_size <= area]
    if fitting:
        pool = fitting

    if area < 1024 * 1024:
        # Medium-sized problems: Manhattan distance to a 128x128 tile.
        return min(
            pool, key=lambda spec: abs(spec.tile_m - 128) + abs(spec.tile_n - 128)
        )

    # Large problems: maximize tile volume.
    return max(pool, key=lambda spec: spec.tile_m * spec.tile_n * spec.tile_k)
|
|
|
|
|
|
def memory_bound_heuristic(
    M: int, N: int, K: int, kernels: List[KernelSpec]
) -> KernelSpec:
    """
    Pick a kernel aimed at memory-bound workloads.

    Memory-category kernels win when present: the smallest tile area for
    problems below 512 * 512 elements, the largest otherwise. Without them,
    the balanced kernel with the smallest tile_k is used. As a last resort,
    the kernel with the smallest tile_k is used, with ties broken by
    closeness to a 128x128 tile.
    """

    def _tile_area(spec):
        return spec.tile_m * spec.tile_n

    memory_pool = [spec for spec in kernels if spec.category == "memory"]
    if memory_pool:
        chooser = min if M * N < 512 * 512 else max
        return chooser(memory_pool, key=_tile_area)

    balanced_pool = [spec for spec in kernels if spec.category == "balanced"]
    if balanced_pool:
        # Smaller K-steps mean smaller per-iteration loads.
        return min(balanced_pool, key=lambda spec: spec.tile_k)

    return min(
        kernels,
        key=lambda spec: (spec.tile_k, abs(spec.tile_m - 128) + abs(spec.tile_n - 128)),
    )
|
|
|
|
|
|
def latency_focused_heuristic(
    M: int, N: int, K: int, kernels: List[KernelSpec]
) -> KernelSpec:
    """
    Pick a kernel aimed at minimal latency.

    The problem dimensions are ignored. Small-category kernels win: the first
    compv4 one in pool order, otherwise the first small kernel. Without small
    kernels, the smallest-area compv4 kernel is used, or failing that the
    smallest-area kernel overall.
    """

    def _tile_area(spec):
        return spec.tile_m * spec.tile_n

    small_pool = [spec for spec in kernels if spec.category == "small"]
    if small_pool:
        return next(
            (spec for spec in small_pool if spec.pipeline == "compv4"), small_pool[0]
        )

    v4_pool = [spec for spec in kernels if spec.pipeline == "compv4"]
    return min(v4_pool or kernels, key=_tile_area)
|
|
|
|
|
|
# Dispatch table: strategy enum member -> selection function. Every selection
# function takes (M, N, K, kernels) and returns one KernelSpec.
HEURISTICS = {
    HeuristicStrategy.SIZE_BASED: size_based_heuristic,  # "size"
    HeuristicStrategy.COMPUTE_BOUND: compute_bound_heuristic,  # "compute"
    HeuristicStrategy.MEMORY_BOUND: memory_bound_heuristic,  # "memory"
    HeuristicStrategy.LATENCY_FOCUSED: latency_focused_heuristic,  # "latency"
}
|
|
|
|
|
|
# =============================================================================
|
|
# Main
|
|
# =============================================================================
|
|
|
|
|
|
def print_kernel_pool(kernels: List[KernelSpec]):
    """Write the kernel pool to stdout as a numbered table.

    Each row lists the kernel name, its MxNxK tile shape, pipeline and
    category.
    """
    banner = "=" * 75
    rule = " " + "-" * 73

    print("\n" + banner)
    print(" KERNEL POOL")
    print(banner)
    print(f"\n {'#':<3} {'Name':<22} {'Tile':<14} {'Pipeline':<10} {'Category':<12}")
    print(rule)

    for idx, spec in enumerate(kernels, start=1):
        shape = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}"
        print(f" {idx:<3} {spec.name:<22} {shape:<14} {spec.pipeline:<10} {spec.category:<12}")

    print(rule)
|
|
|
|
|
|
def _parse_args():
    """Parse the command line for this example.

    ``--arch`` defaults to ``None`` and is resolved with ``detect_gpu_arch()``
    only after parsing. The GPU probe is therefore skipped when the user gives
    ``--arch`` explicitly or only asks for ``--help``.

    Returns:
        argparse.Namespace with ``dtype``, ``strategy`` and a resolved ``arch``.
    """
    parser = argparse.ArgumentParser(
        description="Custom Heuristics Example - intelligent kernel selection",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 08_heuristics.py                     # Default size-based heuristic
  python3 08_heuristics.py --strategy compute  # Compute-bound heuristic
  python3 08_heuristics.py --strategy memory   # Memory-bound heuristic
  python3 08_heuristics.py --strategy latency  # Latency-focused heuristic
  python3 08_heuristics.py --dtype bf16        # BF16 mode
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--strategy",
        default="size",
        # The enum values are exactly the accepted CLI tokens.
        choices=[s.value for s in HeuristicStrategy],
        help="Heuristic strategy (default: size)",
    )
    parser.add_argument(
        "--arch",
        default=None,
        help="Target architecture (auto-detected from rocminfo, override with --arch gfxNNN)",
    )
    args = parser.parse_args()
    if args.arch is None:
        args.arch = detect_gpu_arch()
    return args


def _report_failure(size_str, kernel_name, status):
    """Print an N/A result row and return the matching not-passed result tuple."""
    print(
        f" {size_str:<20} {kernel_name:<25} {'N/A':>10} {'N/A':>10} {status:<8}"
    )
    return (size_str, kernel_name, False, 0, 0)


def _run_case(M, N, K, spec, dtype, arch, np_dtype):
    """Build, run and validate one GEMM using the pre-selected kernel ``spec``.

    Prints one result row and returns
    ``(size_str, kernel_name, passed, time_ms, tflops)``.
    ``cleanup_gemm()`` is always called once setup has returned, even if the
    run or validation raises.
    """
    size_str = f"{M}x{N}x{K}"
    config = create_kernel_config(spec, dtype, arch)

    setup = setup_gemm_dispatcher(
        config=config,
        registry_name=f"heuristic_{spec.name}",
        verbose=False,
        auto_rebuild=True,
    )

    try:
        if not setup.success:
            return _report_failure(size_str, spec.name, "FAIL")

        dispatcher = setup.dispatcher
        if not dispatcher.is_supported(M, N, K):
            return _report_failure(size_str, spec.name, "SKIP")

        # Run GEMM on deterministic random inputs.
        np.random.seed(42)
        A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
        B = (np.random.randn(K, N) * 0.1).astype(np_dtype)

        result = dispatcher.run(A, B, M, N, K)
        if not result.success:
            return _report_failure(size_str, spec.name, "FAIL")

        # Validate against an fp32 reference rounded to the output dtype. The
        # difference is taken in float32 so the error itself is not subject
        # to fp16 rounding.
        C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype)
        diff = np.asarray(result.output, dtype=np.float32) - C_ref.astype(np.float32)
        max_err = np.max(np.abs(diff))
        passed = max_err < 1e-2

        status = "PASS" if passed else "FAIL"
        print(
            f" {size_str:<20} {spec.name:<25} {result.time_ms:>10.4f} {result.tflops:>10.2f} {status:<8}"
        )
        return (size_str, spec.name, passed, result.time_ms, result.tflops)
    finally:
        cleanup_gemm()


def _print_summary(strategy, results):
    """Print the summary section and return the number of non-passing cases."""
    from collections import Counter

    print("\n" + "=" * 75)
    print(" SUMMARY")
    print("=" * 75)

    passed = sum(1 for r in results if r[2])
    failed = len(results) - passed

    print(f"\n Strategy: {strategy.value}")
    print(f" Results: {passed}/{len(results)} tests passed")

    # Most-used kernels first; ties keep first-seen order.
    print("\n Kernel Selection Distribution:")
    for kernel, count in Counter(r[1] for r in results).most_common():
        print(f" {kernel}: {count} times")

    valid_results = [r for r in results if r[2]]
    if valid_results:
        avg_tflops = sum(r[4] for r in valid_results) / len(valid_results)
        print(f"\n Average TFLOPS: {avg_tflops:.2f}")

    if failed == 0:
        print("\n *** ALL TESTS PASSED ***")
    else:
        print(f"\n *** {failed} TESTS FAILED ***")

    print("=" * 75)
    return failed


def main():
    """Run the heuristic-selection demo and return a process exit code.

    For each problem size in ``test_sizes``, the chosen heuristic picks a
    kernel from ``KERNEL_POOL``. That kernel is JIT-built, run and checked
    against a NumPy reference. Returns 0 only if every size passed. SKIP and
    FAIL rows both count as not passed.
    """
    args = _parse_args()

    print("=" * 75)
    print("Example 08: Custom Heuristics")
    print("=" * 75)

    # argparse restricts --strategy to the enum values, so this cannot fail.
    strategy = HeuristicStrategy(args.strategy)
    heuristic_fn = HEURISTICS[strategy]

    print(f"\n Strategy: {strategy.value}")
    print(f" Data type: {args.dtype}")

    print_kernel_pool(KERNEL_POOL)

    # =========================================================================
    # Test heuristic selection across different problem sizes
    # =========================================================================
    print("\n" + "=" * 75)
    print(" HEURISTIC SELECTION TEST")
    print("=" * 75)

    # NOTE(review): NumPy has no native bfloat16, so bf16 runs use float16 host
    # buffers here. Confirm that the dispatcher converts them; otherwise bf16
    # results and validation are not meaningful.
    np_dtype = np.float16 if args.dtype in ("fp16", "bf16") else np.float32

    test_sizes = [
        (128, 128, 64),  # Small
        (256, 256, 128),  # Small-medium
        (512, 512, 256),  # Medium
        (1024, 1024, 512),  # Medium-large
        (2048, 2048, 1024),  # Large
    ]

    print(
        f"\n {'Size':<20} {'Selected Kernel':<25} {'Time (ms)':>10} {'TFLOPS':>10} {'Status':<8}"
    )
    print(" " + "-" * 78)

    results = []
    for M, N, K in test_sizes:
        selected_spec = heuristic_fn(M, N, K, KERNEL_POOL)
        results.append(
            _run_case(M, N, K, selected_spec, args.dtype, args.arch, np_dtype)
        )

    failed = _print_summary(strategy, results)
    return 0 if failed == 0 else 1
|
|
|
|
|
|
if __name__ == "__main__":
    # main() returns the process exit status (0 = all cases passed).
    raise SystemExit(main())
|