mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
[CK_TILE] Restructure Tile Engine's benchmarking and profiling (#4769) ## Motivation This PR introduces a restructure for the benchmarking and profiling aspects of CK Tile's Tile Engine, expanding on the groundwork from this previous https://github.com/ROCm/composable_kernel/pull/3434 and outlined in this [design document](https://amdcloud-my.sharepoint.com/:w:/r/personal/astharai_amd_com/Documents/Restructuring%20Tile%20Engine.docx?d=w14ea28a30718416988ed5ebb759bd3b2&csf=1&web=1&e=l3VBuX). In PR 3434, to reduce repeated code we implemented: - Base class that centralizes common functionality and provides a default implementation (Universal GEMM) - Child classes for GEMM variants override virtual functions to handle variant-specific behavior This refactoring in this PR follows the same process and should greatly reduce the duplicated code present in Tile Engine and make it simpler to add in new operations, increasing scalability. ## Technical Details The files have been refactored around new base structs for benchmarks, profiling and problem descriptions. The new base structs are: - GemmProblem - GemmBenchmark - GemmProfiler Universal GEMM, Preshuffle GEMM, and Multi-D GEMM all have child classes that will inherit from these base structs overriding only what differs per variant. All common functions across the benchmarking and profiling files have been moved into newly added common utility files under the commons/ directory. The new utility files are: - utils.hpp: common functions for the benchmarking and profiling process - benchmark_utils.py: common utility functions for the benchmark generation ## Test Plan I tested using the existing tests for Tile Engine. ## Test Result All tests passed. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
331 lines
12 KiB
Python
331 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
import os
|
|
import importlib.util
|
|
from pathlib import Path
|
|
from typing import List, Dict, Tuple
|
|
|
|
|
|
# TODO: explore modularizing tile engine to avoid accessing imports like this
|
|
def _import_benchmark_utils():
|
|
"""Import benchmark utilities from commons directory."""
|
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
parent_dir = os.path.dirname(current_dir)
|
|
|
|
# Load the module dynamically
|
|
spec = importlib.util.spec_from_file_location(
|
|
"benchmark_utils",
|
|
os.path.join(parent_dir, "common", "benchmark_utils.py"),
|
|
)
|
|
benchmark_utils = importlib.util.module_from_spec(spec)
|
|
spec.loader.exec_module(benchmark_utils)
|
|
|
|
return benchmark_utils
|
|
|
|
|
|
benchmark_utils = _import_benchmark_utils()
|
|
|
|
|
|
class GemmBenchmark:
|
|
def __init__(
|
|
self, build_dir: str, verbose: bool = False, name: str = "benchmark_gemm_"
|
|
):
|
|
self.build_dir = Path(build_dir)
|
|
self.verbose = verbose
|
|
self.results = []
|
|
self.name = name
|
|
|
|
def discover_kernels(self) -> List[Path]:
|
|
"""Find all benchmark_gemm_* executables in the build directory"""
|
|
bin_dir = self.build_dir / "bin"
|
|
if not bin_dir.exists():
|
|
print(f"Error: Binary directory {bin_dir} does not exist")
|
|
return []
|
|
|
|
glob_name = f"{self.name}*"
|
|
kernels = list(bin_dir.glob(glob_name))
|
|
if self.verbose:
|
|
print(f"Found {len(kernels)} kernel executables")
|
|
for k in kernels:
|
|
print(f" - {k.name}")
|
|
return kernels
|
|
|
|
def extract_kernel_info(self, kernel_path: Path) -> Dict[str, str]:
|
|
"""Extract comprehensive kernel information from filename"""
|
|
name = kernel_path.stem
|
|
if name.startswith(self.name):
|
|
args = name[len(self.name) :]
|
|
else:
|
|
args = name
|
|
|
|
# Initialize with basic info
|
|
info = {
|
|
"executable": str(kernel_path),
|
|
"name": name,
|
|
"data_type": "unknown",
|
|
"layout": "unknown",
|
|
"pipeline": "unknown",
|
|
"scheduler": "unknown",
|
|
"epilogue": "unknown",
|
|
}
|
|
|
|
# Parse the kernel name pattern:
|
|
# benchmark_gemm_fp16_rcr_mem_default_intrawave_False_False_False_False_False_256x256x32_2x2x1_4x64x16
|
|
parts = args.split("_")
|
|
|
|
if len(parts) >= 5:
|
|
info["data_type"] = parts[0]
|
|
info["layout"] = parts[1]
|
|
info["pipeline"] = parts[2]
|
|
info["epilogue"] = parts[3]
|
|
info["scheduler"] = parts[4]
|
|
|
|
# Extract detailed configuration from the end of the name
|
|
config_info = self.parse_detailed_config(name)
|
|
info.update(config_info)
|
|
|
|
# Generate config ID
|
|
info["config_id"] = self.generate_config_id(info)
|
|
|
|
return info
|
|
|
|
def parse_detailed_config(self, kernel_name: str) -> Dict:
|
|
"""Parse detailed configuration from kernel name"""
|
|
config = {
|
|
"tile_sizes": {"tile_m": 0, "tile_n": 0, "tile_k": 0},
|
|
"warp_config": {"warp_m": 0, "warp_n": 0, "warp_k": 0},
|
|
"warp_tile": {"warp_tile_m": 0, "warp_tile_n": 0, "warp_tile_k": 0},
|
|
"optimization_flags": {
|
|
"pad_m": False,
|
|
"pad_n": False,
|
|
"pad_k": False,
|
|
"persistent": False,
|
|
},
|
|
}
|
|
|
|
# Split by underscore and look for patterns
|
|
parts = kernel_name.split("_")
|
|
|
|
# Look for boolean flags (sequence of True/False values)
|
|
bool_sequence = []
|
|
for i, part in enumerate(parts):
|
|
if part in ["True", "False"]:
|
|
bool_sequence.append(part == "True")
|
|
# Continue collecting consecutive boolean values
|
|
j = i + 1
|
|
while j < len(parts) and parts[j] in ["True", "False"]:
|
|
bool_sequence.append(parts[j] == "True")
|
|
j += 1
|
|
break
|
|
|
|
# Assign boolean flags if we found them
|
|
# Order: pad_m, pad_n, pad_k, persistent (4 flags total)
|
|
if len(bool_sequence) >= 4:
|
|
config["optimization_flags"]["pad_m"] = bool_sequence[0]
|
|
config["optimization_flags"]["pad_n"] = bool_sequence[1]
|
|
config["optimization_flags"]["pad_k"] = bool_sequence[2]
|
|
config["optimization_flags"]["persistent"] = bool_sequence[3]
|
|
|
|
# Look for tile size patterns (e.g., 256x256x32_2x2x1_4x64x16)
|
|
# The pattern is: tile_sizes_warp_config_warp_tile
|
|
dimension_groups = []
|
|
for part in parts:
|
|
if "x" in part and len(part.split("x")) == 3:
|
|
try:
|
|
dims = [int(x) for x in part.split("x")]
|
|
if all(d > 0 for d in dims):
|
|
dimension_groups.append(dims)
|
|
except ValueError:
|
|
continue
|
|
|
|
# Assign dimensions based on order and magnitude
|
|
if len(dimension_groups) >= 3:
|
|
# Sort by magnitude to identify: largest=tile_sizes, smallest=warp_config, middle=warp_tile
|
|
sorted_groups = sorted(dimension_groups, key=lambda x: max(x), reverse=True)
|
|
|
|
# Largest dimensions = tile sizes
|
|
config["tile_sizes"]["tile_m"] = sorted_groups[0][0]
|
|
config["tile_sizes"]["tile_n"] = sorted_groups[0][1]
|
|
config["tile_sizes"]["tile_k"] = sorted_groups[0][2]
|
|
|
|
# Smallest dimensions = warp config
|
|
config["warp_config"]["warp_m"] = sorted_groups[2][0]
|
|
config["warp_config"]["warp_n"] = sorted_groups[2][1]
|
|
config["warp_config"]["warp_k"] = sorted_groups[2][2]
|
|
|
|
# Middle dimensions = warp tile
|
|
config["warp_tile"]["warp_tile_m"] = sorted_groups[1][0]
|
|
config["warp_tile"]["warp_tile_n"] = sorted_groups[1][1]
|
|
config["warp_tile"]["warp_tile_k"] = sorted_groups[1][2]
|
|
elif len(dimension_groups) == 2:
|
|
# If only 2 groups, assign based on magnitude
|
|
sorted_groups = sorted(dimension_groups, key=lambda x: max(x), reverse=True)
|
|
|
|
# Larger = tile sizes
|
|
config["tile_sizes"]["tile_m"] = sorted_groups[0][0]
|
|
config["tile_sizes"]["tile_n"] = sorted_groups[0][1]
|
|
config["tile_sizes"]["tile_k"] = sorted_groups[0][2]
|
|
|
|
# Smaller = warp config
|
|
config["warp_config"]["warp_m"] = sorted_groups[1][0]
|
|
config["warp_config"]["warp_n"] = sorted_groups[1][1]
|
|
config["warp_config"]["warp_k"] = sorted_groups[1][2]
|
|
elif len(dimension_groups) == 1:
|
|
# Only one group - assume it's tile sizes
|
|
config["tile_sizes"]["tile_m"] = dimension_groups[0][0]
|
|
config["tile_sizes"]["tile_n"] = dimension_groups[0][1]
|
|
config["tile_sizes"]["tile_k"] = dimension_groups[0][2]
|
|
|
|
return config
|
|
|
|
def generate_config_id(self, info: Dict) -> str:
|
|
"""Generate a compact config ID from kernel info"""
|
|
# Create a compact identifier
|
|
parts = [
|
|
info.get("data_type", "unk"),
|
|
info.get("layout", "unk"),
|
|
info.get("pipeline", "unk"),
|
|
info.get("scheduler", "unk"),
|
|
]
|
|
|
|
# Add tile configuration if available
|
|
tile_sizes = info.get("tile_sizes", {})
|
|
if tile_sizes.get("tile_m", 0) > 0:
|
|
tile_str = (
|
|
f"{tile_sizes['tile_m']}x{tile_sizes['tile_n']}x{tile_sizes['tile_k']}"
|
|
)
|
|
parts.append(tile_str)
|
|
|
|
# Add warp config if available
|
|
warp_config = info.get("warp_config", {})
|
|
if warp_config.get("warp_m", 0) > 0:
|
|
warp_str = f"w{warp_config['warp_m']}x{warp_config['warp_n']}x{warp_config['warp_k']}"
|
|
parts.append(warp_str)
|
|
|
|
# Add warp tile if available
|
|
warp_tile = info.get("warp_tile", {})
|
|
if warp_tile.get("warp_tile_m", 0) > 0:
|
|
warp_tile_str = f"wt{warp_tile['warp_tile_m']}x{warp_tile['warp_tile_n']}x{warp_tile['warp_tile_k']}"
|
|
parts.append(warp_tile_str)
|
|
|
|
return "_".join(parts)
|
|
|
|
def benchmark_problem_size(
|
|
self,
|
|
kernels: List[Path],
|
|
m: int,
|
|
n: int,
|
|
k: int,
|
|
split_k: int = 1,
|
|
verify: int = 0,
|
|
warmup: int = 50,
|
|
repeat: int = 100,
|
|
flush_cache: bool = True,
|
|
rotating_count: int = 1000,
|
|
) -> List[Dict]:
|
|
"""Benchmark all kernels for a specific problem size"""
|
|
results = []
|
|
|
|
params = {
|
|
"m": m,
|
|
"n": n,
|
|
"k": k,
|
|
"split_k": split_k,
|
|
"verify": verify,
|
|
"warmup": warmup,
|
|
"repeat": repeat,
|
|
"flush_cache": str(flush_cache).lower(),
|
|
"rotating_count": rotating_count,
|
|
}
|
|
|
|
print(f"\nBenchmarking M={m}, N={n}, K={k}, split_k={split_k}")
|
|
|
|
for kernel_path in kernels:
|
|
kernel_info = self.extract_kernel_info(kernel_path)
|
|
result = benchmark_utils.run_kernel(
|
|
self.build_dir, kernel_path, params, verbose=self.verbose
|
|
)
|
|
if result:
|
|
# Create new structured result format
|
|
structured_result = {
|
|
"name": kernel_info["name"], # Add name field for compatibility
|
|
"config_id": kernel_info["config_id"],
|
|
"problem": result.get("problem", {}),
|
|
"perf_result": result.get("perf_result", {}),
|
|
"config": {
|
|
"data_type": kernel_info["data_type"],
|
|
"layout": kernel_info["layout"],
|
|
"pipeline": kernel_info["pipeline"],
|
|
"scheduler": kernel_info["scheduler"],
|
|
"epilogue": kernel_info["epilogue"],
|
|
"tile_sizes": kernel_info.get("tile_sizes", {}),
|
|
"warp_config": kernel_info.get("warp_config", {}),
|
|
"warp_tile": kernel_info.get("warp_tile", {}),
|
|
"optimization_flags": kernel_info.get("optimization_flags", {}),
|
|
},
|
|
"executable": kernel_info["executable"],
|
|
# Keep backward compatibility fields
|
|
"time_ms": result.get("time_ms", 0),
|
|
"tflops": result.get("tflops", 0),
|
|
"bandwidth_gb_s": result.get("bandwidth_gb_s", 0),
|
|
}
|
|
|
|
results.append(structured_result)
|
|
|
|
if self.verbose:
|
|
print(
|
|
f" {kernel_info['config_id']}: {structured_result['tflops']:.2f} TFLOPS, {structured_result['bandwidth_gb_s']:.2f} GB/s, {structured_result['time_ms']:.2f}ms"
|
|
)
|
|
|
|
return results
|
|
|
|
def benchmark_sweep(
|
|
self,
|
|
problem_sizes: List[Tuple[int, int, int]],
|
|
split_k_values: List[int] = [1],
|
|
verify: bool = False,
|
|
warmup: int = 50,
|
|
repeat: int = 100,
|
|
flush_cache: bool = True,
|
|
rotating_count: int = 1000,
|
|
) -> Dict:
|
|
"""Run comprehensive benchmark sweep"""
|
|
kernels = self.discover_kernels()
|
|
if not kernels:
|
|
print("No kernels found!")
|
|
return {}
|
|
|
|
all_results = []
|
|
best_kernels = {}
|
|
|
|
for m, n, k in problem_sizes:
|
|
for split_k in split_k_values:
|
|
results = self.benchmark_problem_size(
|
|
kernels,
|
|
m,
|
|
n,
|
|
k,
|
|
split_k,
|
|
verify=2 if verify else 0,
|
|
warmup=warmup,
|
|
repeat=repeat,
|
|
flush_cache=flush_cache,
|
|
rotating_count=rotating_count,
|
|
)
|
|
|
|
all_results.extend(results)
|
|
|
|
# Find best kernel for this configuration
|
|
best = benchmark_utils.find_best_kernel(results)
|
|
if best:
|
|
key = f"m{m}_n{n}_k{k}_splitk{split_k}"
|
|
best_kernels[key] = best
|
|
print(
|
|
f"Best for {key}: {best['name']} ({best['tflops']:.2f} TFLOPS, {best['bandwidth_gb_s']:.2f} GB/s, {best['time_ms']:.2f}ms)"
|
|
)
|
|
|
|
self.results = all_results
|
|
return best_kernels
|