# GEMM Python Examples
CK Tile Dispatcher Python examples for GEMM (General Matrix Multiplication) operations.
> **Main Documentation**: [Dispatcher README](../../../README.md) | [Examples Overview](../../README.md)
## Quick Start
### Build Library
```bash
cd /path/to/composable_kernel/dispatcher
mkdir -p build && cd build
cmake .. \
  -DCMAKE_PREFIX_PATH=/opt/rocm \
  -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
  -DBUILD_DISPATCHER_EXAMPLES=ON
# Build Python library (kernels generated automatically)
make dispatcher_gemm_lib -j$(nproc)
```
### Run Examples
```bash
cd /path/to/composable_kernel/dispatcher
python3 examples/gemm/python/01_basic_gemm.py
python3 examples/gemm/python/04_validation.py
python3 examples/gemm/python/07_stress_test.py
python3 examples/gemm/python/08_heuristics.py
```
## Examples
| Example | Description |
|---------|-------------|
| [01_basic_gemm.py](01_basic_gemm.py) | Basic GEMM with multi-kernel support |
| [02_batch_gemm.py](02_batch_gemm.py) | Batched GEMM operations |
| [03_benchmark.py](03_benchmark.py) | Performance benchmarking |
| [04_validation.py](04_validation.py) | CPU reference validation |
| [05_numpy_integration.py](05_numpy_integration.py) | NumPy array integration |
| [06_json_export.py](06_json_export.py) | Registry JSON export |
| [07_stress_test.py](07_stress_test.py) | Multi-kernel stress testing |
| [08_heuristics.py](08_heuristics.py) | Heuristic-based kernel selection |
| [09_multi_registry.py](09_multi_registry.py) | Multiple registries |
| [10_advanced_benchmark.py](10_advanced_benchmark.py) | Advanced benchmark with full control |
| [11_json_import.py](11_json_import.py) | Import kernels from JSON |
## Example Details
### 01_basic_gemm.py - Basic GEMM
Demonstrates the Python API with multi-kernel support:
```python
from ctypes_utils import KernelConfig, setup_gemm_dispatcher, print_kernel_config_table
# Define multiple kernel configurations
kernels = [
    KernelConfig(
        tile_m=128, tile_n=128, tile_k=32,
        wave_m=2, wave_n=2, wave_k=1,
        warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
        pipeline="compv3", scheduler="intrawave"
    ),
    KernelConfig(
        tile_m=256, tile_n=256, tile_k=32,
        wave_m=2, wave_n=2, wave_k=1,
        warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
        pipeline="compv4", scheduler="intrawave"
    ),
]
# Display configurations
print_kernel_config_table(kernels)
# Set up dispatcher with all kernels
lib, dispatcher, registry = setup_gemm_dispatcher(kernels)
# Run GEMM
elapsed_ms = run_gemm(lib, M, N, K, ...)
```
### 02_batch_gemm.py - Batch GEMM
Batched matrix multiplication:
- Multiple independent GEMM operations
- Batch dimension handling
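
As a reference for what a batched GEMM computes, the NumPy sketch below multiplies each pair of matrices along a leading batch dimension. It runs on the CPU only and does not call the dispatcher; the `batched_gemm_reference` name and the shapes are illustrative assumptions, not part of `02_batch_gemm.py`.

```python
import numpy as np

def batched_gemm_reference(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """CPU reference: C[i] = A[i] @ B[i] for every batch index i."""
    assert a.ndim == 3 and b.ndim == 3, "expected (batch, rows, cols) inputs"
    assert a.shape[0] == b.shape[0], "batch sizes must match"
    assert a.shape[2] == b.shape[1], "inner (K) dimensions must match"
    # np.matmul treats the leading axis as the batch dimension
    return np.matmul(a, b)

rng = np.random.default_rng(0)
batch, M, N, K = 4, 8, 16, 32
A = rng.standard_normal((batch, M, K)).astype(np.float32)
B = rng.standard_normal((batch, K, N)).astype(np.float32)
C = batched_gemm_reference(A, B)
print(C.shape)
```

Each of the `batch` products is independent, which is why batched GEMM maps naturally onto a single kernel launch with an extra batch index.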
### 03_benchmark.py - Benchmarking
Performance measurement:
- GPU timing
- TFLOPS calculation
- Multiple iterations
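
The TFLOPS figure for a GEMM is usually derived from the conventional count of `2 * M * N * K` floating-point operations (one multiply and one add per inner-product term). The helper below is a minimal sketch of that arithmetic; `gemm_tflops` is a hypothetical name, and the exact formula used in `03_benchmark.py` is assumed rather than confirmed.

```python
def gemm_tflops(M: int, N: int, K: int, elapsed_ms: float) -> float:
    """Convert a GEMM problem size and elapsed time into TFLOPS."""
    if elapsed_ms <= 0:
        raise ValueError("elapsed_ms must be positive")
    flops = 2.0 * M * N * K          # multiply + add per inner-product term
    seconds = elapsed_ms * 1e-3
    return flops / seconds / 1e12

# Example: a 4096^3 GEMM that takes 1 ms
tflops = gemm_tflops(4096, 4096, 4096, elapsed_ms=1.0)
print(f"{tflops:.2f} TFLOPS")
```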
### 04_validation.py - Validation
Correctness verification:
- NumPy reference implementation
- Tolerance-based validation
- Error reporting
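
Tolerance-based validation can be sketched as follows. The GPU result is replaced here by an fp16-rounded copy of the reference so the snippet runs without a device; the `validate` helper and the tolerance values are assumptions for illustration, not the ones used in `04_validation.py`.

```python
import numpy as np

def validate(result: np.ndarray, reference: np.ndarray,
             rtol: float = 1e-2, atol: float = 1e-2):
    """Return (passed, max_abs_err) using a combined relative/absolute tolerance."""
    result32 = result.astype(np.float32)
    ref32 = reference.astype(np.float32)
    abs_err = np.abs(result32 - ref32)
    # Same criterion as np.allclose: |result - ref| <= atol + rtol * |ref|
    passed = bool(np.all(abs_err <= atol + rtol * np.abs(ref32)))
    return passed, float(abs_err.max())

rng = np.random.default_rng(42)
A = rng.standard_normal((64, 32)).astype(np.float16)
B = rng.standard_normal((32, 48)).astype(np.float16)
# NumPy reference computed in fp32 for accuracy
reference = A.astype(np.float32) @ B.astype(np.float32)
# Stand-in for the GPU output: the reference rounded to fp16
gpu_result = reference.astype(np.float16)

passed, max_err = validate(gpu_result, reference)
print(f"max_err={max_err:.6e} {'PASS' if passed else 'FAIL'}")
```

Computing the reference in fp32 keeps the comparison from being dominated by the reference's own rounding error.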
### 05_numpy_integration.py - NumPy Integration
Seamless NumPy integration:
- NumPy arrays to GPU buffers
- Results back to NumPy
- Automatic type conversion
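
Before a NumPy array is copied to a GPU buffer, it generally needs the element type the kernel expects and a contiguous memory layout. The sketch below shows that preparation step with plain NumPy; `prepare_for_gpu` is a hypothetical helper, and whether `05_numpy_integration.py` performs the conversion this way is an assumption.

```python
import numpy as np

def prepare_for_gpu(array: np.ndarray, dtype=np.float16) -> np.ndarray:
    """Return a C-contiguous array in the requested dtype (copies only if needed)."""
    return np.ascontiguousarray(array, dtype=dtype)

# A transposed view is not C-contiguous and has the wrong dtype for an fp16 kernel
a64 = np.arange(12, dtype=np.float64).reshape(3, 4).T
a16 = prepare_for_gpu(a64)

print(a64.flags["C_CONTIGUOUS"], a16.flags["C_CONTIGUOUS"])
print(a16.dtype, a16.nbytes)  # nbytes is the size to pass to the device copy
```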
### 06_json_export.py - JSON Export
Registry serialization for tool integration:
- Export kernel configurations
- Machine-readable format
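
To illustrate what a machine-readable export might look like, the sketch below serializes a list of configurations with the standard `json` module. `KernelConfigSketch` is a hypothetical stand-in for `KernelConfig`, and the `{"kernels": [...]}` layout is an assumed schema, not the registry's actual export format.

```python
import json
from dataclasses import asdict, dataclass

@dataclass
class KernelConfigSketch:
    """Hypothetical stand-in for KernelConfig with a subset of its fields."""
    tile_m: int
    tile_n: int
    tile_k: int
    pipeline: str
    scheduler: str

configs = [
    KernelConfigSketch(128, 128, 32, "compv3", "intrawave"),
    KernelConfigSketch(256, 256, 32, "compv4", "intrawave"),
]

# Assumed schema: a top-level "kernels" list with one object per configuration
exported = json.dumps({"kernels": [asdict(c) for c in configs]}, indent=2)
print(exported)
```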
### 07_stress_test.py - Stress Testing
Comprehensive multi-kernel stress testing:
```python
from ctypes_utils import KernelConfig, setup_gemm_dispatcher, print_kernel_config_table
# Define 48 unique kernel configurations
kernels = [
    KernelConfig(tile_m=128, tile_n=128, tile_k=32, pipeline="compv3", ...),
    KernelConfig(tile_m=256, tile_n=256, tile_k=32, pipeline="compv4", ...),
    KernelConfig(tile_m=128, tile_n=256, tile_k=64, pipeline="compv3", ...),
    # ... many more configurations
]
# Test each kernel
for i, kernel in enumerate(kernels):
    lib, dispatcher, registry = setup_gemm_dispatcher([kernel])
    result = run_and_validate(lib, M, N, K, seed=42 + i)  # Different seed per kernel
    print(f"Kernel {i}: {result.max_err:.6e} {'PASS' if result.passed else 'FAIL'}")
```
**Features:**
- 48 unique kernel configurations
- Various tile sizes, pipelines, and schedulers
- Per-kernel validation with unique random seeds
- Performance reporting
### 08_heuristics.py - Heuristic Selection
Custom kernel selection based on problem characteristics:
```python
# Define kernel pools for different strategies
SMALL_KERNELS = [KernelConfig(tile_m=64, tile_n=64, ...), ...]
LARGE_KERNELS = [KernelConfig(tile_m=256, tile_n=256, ...), ...]
COMPUTE_KERNELS = [KernelConfig(pipeline="compv4", ...), ...]
MEMORY_KERNELS = [KernelConfig(pipeline="compv3", ...), ...]
# Size-based heuristic
def size_based_heuristic(M, N, K):
    if M * N < 512 * 512:
        return SMALL_KERNELS
    else:
        return LARGE_KERNELS
# Strategy-based selection (same signature, so every strategy can be called the same way)
def compute_strategy(M, N, K):
    return COMPUTE_KERNELS  # Optimized for compute-bound problems
def memory_strategy(M, N, K):
    return MEMORY_KERNELS  # Optimized for memory-bound problems
# Test different strategies
for strategy in [size_based_heuristic, compute_strategy, memory_strategy]:
    kernels = strategy(M, N, K)
    lib, dispatcher, registry = setup_gemm_dispatcher(kernels)
    elapsed_ms = run_gemm(lib, M, N, K, ...)
```
**Features:**
- 24 kernel configurations across 6 categories
- Size-based heuristic (small vs large)
- Optimization strategies (compute, memory, latency)
- Performance comparison across strategies
### 09_multi_registry.py - Multiple Registries
Separate registries for different workloads:
- Compute-optimized registry
- Latency-optimized registry
- Dynamic registry selection
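
Dynamic registry selection can be pictured as a lookup keyed by workload type. In the sketch below each "registry" is just a list of kernel names, `select_registry` is a hypothetical function, and the `512 * 512` threshold is borrowed from the size-based heuristic in `08_heuristics.py`; none of this reflects the actual `Registry` API.

```python
from typing import Dict, List

# Hypothetical stand-ins: each registry is a plain list of kernel names
REGISTRIES: Dict[str, List[str]] = {
    "compute": ["256x256x32_compv4_intrawave"],
    "latency": ["64x64x32_compv3_intrawave"],
}

def select_registry(M: int, N: int, K: int,
                    small_threshold: int = 512 * 512) -> str:
    """Pick the latency registry for small outputs, the compute registry otherwise."""
    return "latency" if M * N < small_threshold else "compute"

for shape in [(128, 128, 256), (4096, 4096, 4096)]:
    name = select_registry(*shape)
    print(shape, "->", name, REGISTRIES[name])
```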
### 10_advanced_benchmark.py - Advanced Benchmark
Full control over benchmark parameters:
- Warmup iterations
- Benchmark iterations
- Statistical analysis
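
The warmup, iteration, and statistics steps above can be sketched with the standard library alone. The `benchmark` helper is a hypothetical illustration that times an arbitrary callable on the host; it is not the implementation in `10_advanced_benchmark.py`, and timing real GPU kernels additionally requires synchronizing with the device (or using GPU events) before reading the clock.

```python
import statistics
import time
from typing import Callable, Dict

def benchmark(fn: Callable[[], object], warmup: int = 5,
              iterations: int = 20) -> Dict[str, float]:
    """Run `fn` with warmup, then collect per-iteration wall-clock times in ms."""
    for _ in range(warmup):          # warmup runs are executed but not recorded
        fn()
    samples_ms = []
    for _ in range(iterations):
        start = time.perf_counter()
        fn()
        samples_ms.append((time.perf_counter() - start) * 1e3)
    return {
        "mean_ms": statistics.mean(samples_ms),
        "median_ms": statistics.median(samples_ms),
        "stdev_ms": statistics.stdev(samples_ms) if len(samples_ms) > 1 else 0.0,
        "min_ms": min(samples_ms),
    }

# CPU-only stand-in workload so the sketch runs without a GPU
stats = benchmark(lambda: sum(range(10_000)))
print({key: round(value, 4) for key, value in stats.items()})
```

Reporting the median and minimum alongside the mean makes outliers from clock ramp-up or background activity easier to spot.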
### 11_json_import.py - JSON Import
Import kernel configurations from JSON:
- External configuration files
- Dynamic kernel loading
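
Loading configurations from an external file amounts to parsing JSON and constructing one configuration object per entry. The sketch below reads from an inline string so it is self-contained; `KernelConfigSketch`, `load_configs`, and the `{"kernels": [...]}` schema are all assumptions for illustration, not the format consumed by `11_json_import.py`.

```python
import json
from dataclasses import dataclass
from typing import List

@dataclass
class KernelConfigSketch:
    """Hypothetical stand-in for KernelConfig with a subset of its fields."""
    tile_m: int
    tile_n: int
    tile_k: int
    pipeline: str
    scheduler: str

# Assumed schema; in practice this text would come from a file on disk
JSON_TEXT = """
{
  "kernels": [
    {"tile_m": 128, "tile_n": 128, "tile_k": 32,
     "pipeline": "compv3", "scheduler": "intrawave"},
    {"tile_m": 256, "tile_n": 256, "tile_k": 32,
     "pipeline": "compv4", "scheduler": "intrawave"}
  ]
}
"""

def load_configs(text: str) -> List[KernelConfigSketch]:
    """Parse JSON text and build one config object per entry."""
    data = json.loads(text)
    # Unknown or missing keys raise TypeError, which surfaces schema mistakes early
    return [KernelConfigSketch(**entry) for entry in data["kernels"]]

loaded = load_configs(JSON_TEXT)
print(len(loaded), loaded[0])
```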
## Utility Module: ctypes_utils.py
```python
from ctypes_utils import (
    KernelConfig,               # Single kernel configuration
    setup_gemm_dispatcher,      # Set up dispatcher with kernels
    print_kernel_config_table,  # Display kernel configurations
    Dispatcher,                 # High-level dispatcher
    Registry,                   # Kernel registry
    Validator,                  # Validation utilities
)
```
### KernelConfig
```python
config = KernelConfig(
    # Tile sizes
    tile_m=256, tile_n=256, tile_k=32,
    # Wave configuration
    wave_m=2, wave_n=2, wave_k=1,
    # Warp tile sizes
    warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
    # Pipeline and scheduler
    pipeline="compv4",      # "compv3" or "compv4"
    scheduler="intrawave",  # "intrawave" or "interwave"
    # Optional
    epilogue="default",
    padding=True,
    double_buffer=True,
)
```
### setup_gemm_dispatcher
```python
# Single kernel
lib, dispatcher, registry = setup_gemm_dispatcher(config)
# Multiple kernels
lib, dispatcher, registry = setup_gemm_dispatcher([config1, config2, ...])
# With auto-rebuild
lib, dispatcher, registry = setup_gemm_dispatcher(config, auto_rebuild=True)
```
### print_kernel_config_table
```python
kernels = [config1, config2, config3]
print_kernel_config_table(kernels)
# Output:
# +----+------------+-------+----------+--------+-----------+
# | #  | Tile       | Wave  | Warp     | Pipe   | Scheduler |
# +----+------------+-------+----------+--------+-----------+
# | 1  | 128x128x32 | 2x2x1 | 32x32x16 | compv3 | intrawave |
# | 2  | 256x256x32 | 2x2x1 | 32x32x16 | compv4 | intrawave |
# | 3  | 128x256x64 | 2x2x1 | 32x32x16 | compv3 | interwave |
# +----+------------+-------+----------+--------+-----------+
```
### GPU Memory Management
```python
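import ctypes
import numpy as np
# Load HIP runtime library
hip = ctypes.CDLL("libamdhip64.so")
# Host buffer to transfer (shape and dtype are examples)
host_array = np.zeros((1024, 1024), dtype=np.float16)
size_in_bytes = ctypes.c_size_t(host_array.nbytes)
# Wrap the host address in c_void_p; a bare Python int would be passed as a C int
host_ptr = host_array.ctypes.data_as(ctypes.c_void_p)
# Allocate GPU memory
gpu_ptr = ctypes.c_void_p()
hip.hipMalloc(ctypes.byref(gpu_ptr), size_in_bytes)
# Copy to GPU (1 = hipMemcpyHostToDevice)
hip.hipMemcpy(gpu_ptr, host_ptr, size_in_bytes, 1)
# Copy back (2 = hipMemcpyDeviceToHost)
hip.hipMemcpy(host_ptr, gpu_ptr, size_in_bytes, 2)
# Free
hip.hipFree(gpu_ptr)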
```
## Performance Testing
Test compilation performance with different kernel counts:
```bash
# Test with 10 kernels (~15s compile time)
python3 01_basic_gemm.py --num-kernels 10
# Test with 20 kernels (~25s compile time)
python3 01_basic_gemm.py --num-kernels 20
# Test with 48 kernels (~50s compile time)
python3 01_basic_gemm.py --num-kernels 48
```
Compilation time scales roughly linearly with kernel count.
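
The "roughly linear" claim can be checked with a least-squares fit through the three approximate timings quoted above. This is only a back-of-the-envelope sketch in pure Python; the timings are the rounded figures from this README, and actual compile times depend on the machine.

```python
# (kernel count, approximate compile time in seconds) from the commands above
samples = [(10, 15.0), (20, 25.0), (48, 50.0)]

n = len(samples)
mean_x = sum(x for x, _ in samples) / n
mean_y = sum(y for _, y in samples) / n

# Ordinary least squares for y = slope * x + intercept
slope = (sum((x - mean_x) * (y - mean_y) for x, y in samples)
         / sum((x - mean_x) ** 2 for x, _ in samples))
intercept = mean_y - slope * mean_x

print(f"~{slope:.2f} s per kernel, ~{intercept:.1f} s fixed overhead")
```

With these figures the fit gives roughly 0.9 s per additional kernel on top of about 6 s of fixed overhead.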
## Related Documentation
- [C++ GEMM Examples](../cpp/README.md)
- [Python Conv Examples](../../conv/python/README.md)
- [Main Dispatcher README](../../../README.md)