GEMM Python Examples
CK Tile Dispatcher Python examples for GEMM (General Matrix Multiplication) operations.
Main Documentation: Dispatcher README | Examples Overview
Quick Start
Build Library
```bash
cd /path/to/composable_kernel/dispatcher
mkdir -p build && cd build
cmake .. \
  -DCMAKE_PREFIX_PATH=/opt/rocm \
  -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
  -DBUILD_DISPATCHER_EXAMPLES=ON
# Build Python library (kernels generated automatically)
make dispatcher_gemm_lib -j$(nproc)
```
Run Examples
```bash
cd /path/to/composable_kernel/dispatcher
python3 examples/gemm/python/01_basic_gemm.py
python3 examples/gemm/python/04_validation.py
python3 examples/gemm/python/07_stress_test.py
python3 examples/gemm/python/08_heuristics.py
```
Examples
| Example | Description |
|---|---|
| 01_basic_gemm.py | Basic GEMM with multi-kernel support |
| 02_batch_gemm.py | Batched GEMM operations |
| 03_benchmark.py | Performance benchmarking |
| 04_validation.py | CPU reference validation |
| 05_numpy_integration.py | NumPy array integration |
| 06_json_export.py | Registry JSON export |
| 07_stress_test.py | Multi-kernel stress testing |
| 08_heuristics.py | Heuristic-based kernel selection |
| 09_multi_registry.py | Multiple registries |
| 10_advanced_benchmark.py | Advanced benchmark with full control |
| 11_json_import.py | Import kernels from JSON |
Example Details
01_basic_gemm.py - Basic GEMM
Demonstrates the Python API with multi-kernel support:
```python
from ctypes_utils import KernelConfig, setup_gemm_dispatcher, print_kernel_config_table

# Define multiple kernel configurations
kernels = [
    KernelConfig(
        tile_m=128, tile_n=128, tile_k=32,
        wave_m=2, wave_n=2, wave_k=1,
        warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
        pipeline="compv3", scheduler="intrawave",
    ),
    KernelConfig(
        tile_m=256, tile_n=256, tile_k=32,
        wave_m=2, wave_n=2, wave_k=1,
        warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
        pipeline="compv4", scheduler="intrawave",
    ),
]

# Display configurations
print_kernel_config_table(kernels)

# Set up dispatcher with all kernels
lib, dispatcher, registry = setup_gemm_dispatcher(kernels)

# Run GEMM on an M x N x K problem (remaining arguments elided)
elapsed_ms = run_gemm(lib, M, N, K, ...)
```
02_batch_gemm.py - Batch GEMM
Batched matrix multiplication:
- Multiple independent GEMM operations
- Batch dimension handling
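The semantics of a batched GEMM can be written as a small NumPy reference: each batch entry is an independent `M×K` by `K×N` product. This is a CPU-only sketch for checking results, not the dispatcher's batch API, and the batch-major `(batch, M, K)` layout is an assumption.

```python
import numpy as np

def batched_gemm_reference(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Reference batched GEMM: C[i] = A[i] @ B[i] for every batch index i.

    Assumes batch-major layouts a: (batch, M, K) and b: (batch, K, N).
    """
    if a.ndim != 3 or b.ndim != 3:
        raise ValueError("expected 3-D arrays (batch, rows, cols)")
    if a.shape[0] != b.shape[0] or a.shape[2] != b.shape[1]:
        raise ValueError(f"incompatible shapes {a.shape} and {b.shape}")
    # np.matmul treats the leading dimension as a batch dimension
    return np.matmul(a, b)

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 64, 32)).astype(np.float32)
b = rng.standard_normal((4, 32, 16)).astype(np.float32)
c = batched_gemm_reference(a, b)
print(c.shape)  # (4, 64, 16)
```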
03_benchmark.py - Benchmarking
Performance measurement:
- GPU timing
- TFLOPS calculation
- Multiple iterations
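TFLOPS for a GEMM is conventionally computed from `2·M·N·K` floating-point operations (one multiply and one add per inner-product term) divided by the elapsed time. A minimal sketch of that arithmetic, assuming the time is given in milliseconds as `elapsed_ms` is in the examples; the function name is illustrative.

```python
def gemm_tflops(m: int, n: int, k: int, elapsed_ms: float, batch: int = 1) -> float:
    """TFLOPS for a (batched) GEMM: 2*M*N*K FLOPs per batch entry."""
    if elapsed_ms <= 0:
        raise ValueError("elapsed_ms must be positive")
    flops = 2.0 * m * n * k * batch
    seconds = elapsed_ms * 1e-3
    return flops / seconds / 1e12

# A 4096^3 GEMM finishing in 1 ms:
print(f"{gemm_tflops(4096, 4096, 4096, 1.0):.1f} TFLOPS")  # 137.4 TFLOPS
```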
04_validation.py - Validation
Correctness verification:
- NumPy reference implementation
- Tolerance-based validation
- Error reporting
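Tolerance-based validation typically compares the GPU result with a higher-precision NumPy reference and checks each element against combined absolute and relative tolerances. This sketch assumes fp16 inputs and an fp32 reference; the `rtol`/`atol` values are illustrative, not the example's actual thresholds, and the result fields simply mirror the `result.max_err` / `result.passed` usage in the `07_stress_test.py` snippet.

```python
from dataclasses import dataclass
import numpy as np

@dataclass
class ValidationResult:
    max_err: float
    passed: bool

def validate_gemm(a, b, c_gpu, rtol=1e-2, atol=1e-2) -> ValidationResult:
    """Compare a GEMM result against an fp32 NumPy reference."""
    # Compute the reference in fp32 to avoid fp16 accumulation error
    c_ref = a.astype(np.float32) @ b.astype(np.float32)
    c = c_gpu.astype(np.float32)
    max_err = float(np.max(np.abs(c - c_ref)))
    # Elementwise check: |c - ref| <= atol + rtol * |ref|
    passed = bool(np.allclose(c, c_ref, rtol=rtol, atol=atol))
    return ValidationResult(max_err=max_err, passed=passed)

rng = np.random.default_rng(42)
a = rng.uniform(-1, 1, (128, 64)).astype(np.float16)
b = rng.uniform(-1, 1, (64, 96)).astype(np.float16)
# Stand-in for a GPU result: the reference rounded to fp16
c_gpu = (a.astype(np.float32) @ b.astype(np.float32)).astype(np.float16)
result = validate_gemm(a, b, c_gpu)
print(f"max_err={result.max_err:.3e} {'PASS' if result.passed else 'FAIL'}")
```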
05_numpy_integration.py - NumPy Integration
Seamless NumPy integration:
- NumPy arrays to GPU buffers
- Results back to NumPy
- Automatic type conversion
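Before a NumPy array's buffer is handed to a GPU copy, it generally needs the expected dtype and a C-contiguous layout, because the raw pointer from `arr.ctypes.data` is read as one flat buffer. A minimal sketch of that conversion step; the helper name and the fp16 default are assumptions, not necessarily what the example uses.

```python
import numpy as np

def prepare_host_buffer(arr, dtype=np.float16) -> np.ndarray:
    """Return a C-contiguous array of `arr` with the requested dtype.

    np.ascontiguousarray only copies when the dtype or layout must change.
    """
    return np.ascontiguousarray(arr, dtype=dtype)

x = np.arange(12, dtype=np.float64).reshape(3, 4).T  # transposed: not C-contiguous
buf = prepare_host_buffer(x)
print(buf.dtype, buf.flags["C_CONTIGUOUS"], buf.nbytes)  # float16 True 24
```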
06_json_export.py - JSON Export
Registry serialization for tool integration:
- Export kernel configurations
- Machine-readable format
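Registry export comes down to serializing each kernel's configuration fields to JSON. The sketch below uses a hypothetical stand-in dataclass whose fields mirror the `KernelConfig` arguments shown in this README; the default values and the top-level `{"kernels": [...]}` schema are assumptions and may differ from what `06_json_export.py` writes.

```python
import json
from dataclasses import dataclass, asdict

@dataclass
class KernelConfigSketch:  # hypothetical stand-in for ctypes_utils.KernelConfig
    tile_m: int
    tile_n: int
    tile_k: int
    wave_m: int = 2  # defaults here are illustrative assumptions
    wave_n: int = 2
    wave_k: int = 1
    warp_tile_m: int = 32
    warp_tile_n: int = 32
    warp_tile_k: int = 16
    pipeline: str = "compv3"
    scheduler: str = "intrawave"

def export_registry(configs) -> str:
    # Hypothetical schema: {"kernels": [{field: value, ...}, ...]}
    return json.dumps({"kernels": [asdict(c) for c in configs]}, indent=2)

text = export_registry([
    KernelConfigSketch(128, 128, 32),
    KernelConfigSketch(256, 256, 32, pipeline="compv4"),
])
print(text)
```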
07_stress_test.py - Stress Testing
Comprehensive multi-kernel stress testing:
```python
from ctypes_utils import KernelConfig, setup_gemm_dispatcher, print_kernel_config_table

# Define 48 unique kernel configurations
kernels = [
    KernelConfig(tile_m=128, tile_n=128, tile_k=32, pipeline="compv3", ...),
    KernelConfig(tile_m=256, tile_n=256, tile_k=32, pipeline="compv4", ...),
    KernelConfig(tile_m=128, tile_n=256, tile_k=64, pipeline="compv3", ...),
    # ... many more configurations
]

# Test each kernel
for i, kernel in enumerate(kernels):
    lib, dispatcher, registry = setup_gemm_dispatcher([kernel])
    result = run_and_validate(lib, M, N, K, seed=42 + i)  # Different seed per kernel
    print(f"Kernel {i}: {result.max_err:.6e} {'PASS' if result.passed else 'FAIL'}")
```
Features:
- 48 unique kernel configurations
- Various tile sizes, pipelines, and schedulers
- Per-kernel validation with unique random seeds
- Performance reporting
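One way to build a large set of unique configurations is a Cartesian product over the tunable fields, combined with a distinct seed per kernel so each kernel sees different but reproducible inputs. This is a hypothetical sketch: the tile shapes are illustrative, the 48 configurations in `07_stress_test.py` are not necessarily generated this way, and not every combination is necessarily accepted by the code generator.

```python
import itertools
import numpy as np

# Illustrative value sets (hypothetical; not the example's exact list)
TILES = [(128, 128, 32), (128, 256, 32), (256, 128, 32),
         (256, 256, 32), (128, 128, 64), (256, 256, 64)]
PIPELINES = ["compv3", "compv4"]
SCHEDULERS = ["intrawave", "interwave"]
PADDING = [True, False]

configs = [
    dict(tile_m=m, tile_n=n, tile_k=k, pipeline=p, scheduler=s, padding=pad)
    for (m, n, k), p, s, pad in itertools.product(TILES, PIPELINES, SCHEDULERS, PADDING)
]
print(len(configs))  # 48 (6 * 2 * 2 * 2)

def make_inputs(M, N, K, seed):
    """Reproducible fp16 inputs; a distinct seed per kernel varies the data."""
    rng = np.random.default_rng(seed)
    a = rng.uniform(-1, 1, (M, K)).astype(np.float16)
    b = rng.uniform(-1, 1, (K, N)).astype(np.float16)
    return a, b

a0, _ = make_inputs(64, 64, 32, seed=42)
a1, _ = make_inputs(64, 64, 32, seed=43)
print(np.array_equal(a0, a1))  # False: different seeds give different data
```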
08_heuristics.py - Heuristic Selection
Custom kernel selection based on problem characteristics:
```python
# Define kernel pools for different strategies
SMALL_KERNELS = [KernelConfig(tile_m=64, tile_n=64, ...), ...]
LARGE_KERNELS = [KernelConfig(tile_m=256, tile_n=256, ...), ...]
COMPUTE_KERNELS = [KernelConfig(pipeline="compv4", ...), ...]
MEMORY_KERNELS = [KernelConfig(pipeline="compv3", ...), ...]

# Size-based heuristic
def size_based_heuristic(M, N, K):
    if M * N < 512 * 512:
        return SMALL_KERNELS
    return LARGE_KERNELS

# Strategy-based selection (same signature as size_based_heuristic so that
# every strategy can be called uniformly in the loop below)
def compute_strategy(M, N, K):
    return COMPUTE_KERNELS  # Optimized for compute-bound problems

def memory_strategy(M, N, K):
    return MEMORY_KERNELS  # Optimized for memory-bound problems

# Test different strategies
for strategy in [size_based_heuristic, compute_strategy, memory_strategy]:
    kernels = strategy(M, N, K)
    lib, dispatcher, registry = setup_gemm_dispatcher(kernels)
    elapsed_ms = run_gemm(lib, M, N, K, ...)
```
Features:
- 24 kernel configurations across 6 categories
- Size-based heuristic (small vs large)
- Optimization strategies (compute, memory, latency)
- Performance comparison across strategies
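Whether a problem is compute-bound or memory-bound can be estimated from its arithmetic intensity, i.e. FLOPs per byte moved. This sketch assumes fp16 operands (2 bytes) with A and B each read once and C written once; the 100 FLOP/byte threshold is a placeholder, since the real crossover is the GPU's peak FLOP rate divided by its memory bandwidth. The function names are hypothetical.

```python
def arithmetic_intensity(M, N, K, bytes_per_elem=2):
    """FLOPs per byte for C = A @ B, assuming A, B read once and C written once."""
    flops = 2 * M * N * K
    bytes_moved = bytes_per_elem * (M * K + K * N + M * N)
    return flops / bytes_moved

def pick_strategy(M, N, K, threshold=100.0):
    # Hypothetical threshold; tune to peak_flops / memory_bandwidth of the GPU
    return "compute" if arithmetic_intensity(M, N, K) >= threshold else "memory"

print(round(arithmetic_intensity(4096, 4096, 4096), 1))  # 1365.3
print(pick_strategy(4096, 4096, 4096))  # compute
print(pick_strategy(8192, 16, 8192))    # memory
```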
09_multi_registry.py - Multiple Registries
Separate registries for different workloads:
- Compute-optimized registry
- Latency-optimized registry
- Dynamic registry selection
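Dynamic registry selection can be pictured as a mapping from workload class to its own kernel list, chosen per call. A plain-Python sketch: the registry names, kernel identifiers, and the `M * N` threshold are hypothetical, and in the example each entry would instead be a `Registry` populated through the dispatcher API.

```python
from typing import Dict, List

# Hypothetical registries: workload name -> kernel identifiers
REGISTRIES: Dict[str, List[str]] = {
    "compute": ["256x256x32_compv4_intrawave", "256x128x32_compv4_intrawave"],
    "latency": ["64x64x32_compv3_intrawave", "128x64x32_compv3_intrawave"],
}

def select_registry(M: int, N: int, small_threshold: int = 512 * 512) -> str:
    """Route small problems to the latency registry, large ones to compute."""
    return "latency" if M * N < small_threshold else "compute"

for shape in [(256, 256), (4096, 4096)]:
    name = select_registry(*shape)
    print(shape, "->", name, REGISTRIES[name][0])
```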
10_advanced_benchmark.py - Advanced Benchmark
Full control over benchmark parameters:
- Warmup iterations
- Benchmark iterations
- Statistical analysis
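Warmup runs keep one-time costs (loading, cache warmup) out of the measurement, and repeated timed iterations allow summary statistics. This sketch times an arbitrary Python callable with `time.perf_counter`; for GPU work the callable would have to synchronize the device before returning (or GPU events would be used instead), which is an assumption about how a harness around `run_gemm` would be written.

```python
import statistics
import time

def benchmark(fn, warmup=5, iters=20):
    """Time fn() after warmup runs; returns summary statistics in milliseconds."""
    if iters < 2:
        raise ValueError("need at least 2 timed iterations for stdev")
    for _ in range(warmup):
        fn()  # untimed: absorbs one-time setup costs
    times_ms = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return {
        "mean_ms": statistics.mean(times_ms),
        "median_ms": statistics.median(times_ms),
        "stdev_ms": statistics.stdev(times_ms),
        "min_ms": min(times_ms),
    }

stats = benchmark(lambda: sum(range(100_000)), warmup=3, iters=10)
print({k: round(v, 3) for k, v in stats.items()})
```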
11_json_import.py - JSON Import
Import kernel configurations from JSON:
- External configuration files
- Dynamic kernel loading
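Importing reverses the export: parse the file, then check each entry before turning it into a kernel configuration. The sketch assumes a hypothetical `{"kernels": [...]}` schema with `KernelConfig`-style field names, and accepts only the pipeline and scheduler values documented for `KernelConfig` in this README. It returns plain dicts, which could then be passed as `KernelConfig(**entry)` if the field names match.

```python
import json

ALLOWED = {"pipeline": {"compv3", "compv4"}, "scheduler": {"intrawave", "interwave"}}
REQUIRED = ("tile_m", "tile_n", "tile_k")

def load_kernel_entries(text: str) -> list:
    """Parse a JSON kernel list (hypothetical schema) and validate each entry."""
    entries = json.loads(text)["kernels"]
    for i, entry in enumerate(entries):
        missing = [key for key in REQUIRED if key not in entry]
        if missing:
            raise ValueError(f"kernel {i}: missing {missing}")
        for field, allowed in ALLOWED.items():
            if field in entry and entry[field] not in allowed:
                raise ValueError(f"kernel {i}: bad {field} {entry[field]!r}")
    return entries

doc = '{"kernels": [{"tile_m": 128, "tile_n": 128, "tile_k": 32, "pipeline": "compv4"}]}'
entries = load_kernel_entries(doc)
print(entries[0]["pipeline"])  # compv4
```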
Utility Module: ctypes_utils.py
```python
from ctypes_utils import (
    KernelConfig,               # Single kernel configuration
    setup_gemm_dispatcher,      # Set up dispatcher with kernels
    print_kernel_config_table,  # Display kernel configurations
    Dispatcher,                 # High-level dispatcher
    Registry,                   # Kernel registry
    Validator,                  # Validation utilities
)
```
KernelConfig
```python
config = KernelConfig(
    # Tile sizes
    tile_m=256, tile_n=256, tile_k=32,
    # Wave configuration
    wave_m=2, wave_n=2, wave_k=1,
    # Warp tile sizes
    warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
    # Pipeline and scheduler
    pipeline="compv4",      # "compv3" or "compv4"
    scheduler="intrawave",  # "intrawave" or "interwave"
    # Optional
    epilogue="default",
    padding=True,
    double_buffer=True,
)
```
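The tile, wave, and warp-tile fields are not independent. Assuming that, as in CK Tile-style kernels, each block-tile dimension is covered by `wave × warp_tile` repeated a whole number of times, a quick divisibility check can catch inconsistent configurations before code generation. The rule and the helper name are assumptions for illustration, not a complete list of the generator's constraints.

```python
def check_tile_divisibility(tile, wave, warp_tile):
    """Return per-dimension repeat counts tile / (wave * warp_tile).

    Raises ValueError if a dimension does not divide evenly (assumed constraint).
    """
    repeats = []
    for dim, t, w, wt in zip("mnk", tile, wave, warp_tile):
        step = w * wt
        if t % step != 0:
            raise ValueError(f"tile_{dim}={t} not divisible by wave_{dim}*warp_tile_{dim}={step}")
        repeats.append(t // step)
    return tuple(repeats)

# Values from the KernelConfig example above
print(check_tile_divisibility((256, 256, 32), (2, 2, 1), (32, 32, 16)))  # (4, 4, 2)
```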
setup_gemm_dispatcher
```python
# Single kernel
lib, dispatcher, registry = setup_gemm_dispatcher(config)

# Multiple kernels
lib, dispatcher, registry = setup_gemm_dispatcher([config1, config2, ...])

# With auto-rebuild
lib, dispatcher, registry = setup_gemm_dispatcher(config, auto_rebuild=True)
```
print_kernel_config_table
```python
kernels = [config1, config2, config3]
print_kernel_config_table(kernels)
# Output:
# +----+------------+-------+----------+--------+-----------+
# | #  | Tile       | Wave  | Warp     | Pipe   | Scheduler |
# +----+------------+-------+----------+--------+-----------+
# | 1  | 128x128x32 | 2x2x1 | 32x32x16 | compv3 | intrawave |
# | 2  | 256x256x32 | 2x2x1 | 32x32x16 | compv4 | intrawave |
# | 3  | 128x256x64 | 2x2x1 | 32x32x16 | compv3 | interwave |
# +----+------------+-------+----------+--------+-----------+
```
GPU Memory Management
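```python
import ctypes
import numpy as np

# Load HIP library
hip = ctypes.CDLL("libamdhip64.so")

# Declare argument types so ctypes passes 64-bit pointers and sizes correctly
hip.hipMalloc.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]
hip.hipMemcpy.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int]
hip.hipFree.argtypes = [ctypes.c_void_p]

host_array = np.ones((1024, 1024), dtype=np.float16)  # example host buffer
size_in_bytes = host_array.nbytes

# Allocate GPU memory
gpu_ptr = ctypes.c_void_p()
hip.hipMalloc(ctypes.byref(gpu_ptr), size_in_bytes)

# Copy to GPU (1 = hipMemcpyHostToDevice)
hip.hipMemcpy(gpu_ptr, host_array.ctypes.data, size_in_bytes, 1)

# Copy back (2 = hipMemcpyDeviceToHost)
hip.hipMemcpy(host_array.ctypes.data, gpu_ptr, size_in_bytes, 2)

# Free
hip.hipFree(gpu_ptr)
```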
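Each of the `hipMalloc` / `hipMemcpy` / `hipFree` calls returns a `hipError_t` status code, where `0` is `hipSuccess`, and the memory-management snippet ignores it. A small wrapper that raises on non-zero codes makes failures visible. The sketch is GPU-free so it runs anywhere, with usage against the `hip` handle shown as a comment; `hipGetErrorString` is a real HIP runtime function, but calling it through ctypes this way is an illustrative assumption, and `hip_check` / `HipError` are hypothetical names.

```python
import ctypes

class HipError(RuntimeError):
    pass

def hip_check(status: int, hip_lib=None) -> None:
    """Raise HipError if a HIP runtime call returned a non-zero hipError_t."""
    if status == 0:  # hipSuccess
        return
    message = f"HIP error {status}"
    if hip_lib is not None:
        # C signature: const char* hipGetErrorString(hipError_t)
        hip_lib.hipGetErrorString.restype = ctypes.c_char_p
        message += f": {hip_lib.hipGetErrorString(status).decode()}"
    raise HipError(message)

# Usage with the `hip` handle from the memory-management snippet:
# hip_check(hip.hipMalloc(ctypes.byref(gpu_ptr), size_in_bytes), hip)
hip_check(0)  # success: no exception
```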
Performance Testing
Test compilation performance with different kernel counts:
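```bash
# Run from the dispatcher directory, as in Quick Start
# Test with 10 kernels (~15s compile time)
python3 examples/gemm/python/01_basic_gemm.py --num-kernels 10
# Test with 20 kernels (~25s compile time)
python3 examples/gemm/python/01_basic_gemm.py --num-kernels 20
# Test with 48 kernels (~50s compile time)
python3 examples/gemm/python/01_basic_gemm.py --num-kernels 48
```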
Compilation time scales roughly linearly with kernel count, at about 1 s per additional kernel on top of a fixed startup cost of a few seconds.
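The three approximate timings quoted for 10, 20, and 48 kernels can be fitted with a least-squares line to estimate the per-kernel cost and the fixed overhead. Since the inputs are rounded figures from this section, the fitted values are rough estimates only.

```python
import numpy as np

kernel_counts = np.array([10, 20, 48])
compile_s = np.array([15.0, 25.0, 50.0])  # approximate times quoted in this section

# np.polyfit returns coefficients from highest degree down: [slope, intercept]
slope, intercept = np.polyfit(kernel_counts, compile_s, deg=1)
print(f"~{slope:.2f} s per kernel + ~{intercept:.1f} s fixed overhead")
# ~0.91 s per kernel + ~6.2 s fixed overhead
```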