# GEMM Python Examples

CK Tile Dispatcher Python examples for GEMM (General Matrix Multiplication) operations.

> **Main Documentation**: [Dispatcher README](../../../README.md) | [Examples Overview](../../README.md)

## Quick Start

### Build Library

```bash
cd /path/to/composable_kernel/dispatcher
mkdir -p build && cd build

cmake .. \
    -DCMAKE_PREFIX_PATH=/opt/rocm \
    -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
    -DBUILD_DISPATCHER_EXAMPLES=ON

# Build Python library (kernels generated automatically)
make dispatcher_gemm_lib -j$(nproc)
```

### Run Examples

```bash
cd /path/to/composable_kernel/dispatcher

python3 examples/gemm/python/01_basic_gemm.py
python3 examples/gemm/python/04_validation.py
python3 examples/gemm/python/07_stress_test.py
python3 examples/gemm/python/08_heuristics.py
```

## Examples

| Example | Description |
|---------|-------------|
| [01_basic_gemm.py](01_basic_gemm.py) | Basic GEMM with multi-kernel support |
| [02_batch_gemm.py](02_batch_gemm.py) | Batched GEMM operations |
| [03_benchmark.py](03_benchmark.py) | Performance benchmarking |
| [04_validation.py](04_validation.py) | CPU reference validation |
| [05_numpy_integration.py](05_numpy_integration.py) | NumPy array integration |
| [06_json_export.py](06_json_export.py) | Registry JSON export |
| [07_stress_test.py](07_stress_test.py) | Multi-kernel stress testing |
| [08_heuristics.py](08_heuristics.py) | Heuristic-based kernel selection |
| [09_multi_registry.py](09_multi_registry.py) | Multiple registries |
| [10_advanced_benchmark.py](10_advanced_benchmark.py) | Advanced benchmark with full control |
| [11_json_import.py](11_json_import.py) | Import kernels from JSON |

## Example Details

### 01_basic_gemm.py - Basic GEMM

Demonstrates the Python API with multi-kernel support:

```python
from ctypes_utils import KernelConfig, setup_gemm_dispatcher, print_kernel_config_table

# Define multiple kernel configurations
kernels = [
    KernelConfig(
        tile_m=128, tile_n=128, tile_k=32,
        wave_m=2, wave_n=2, wave_k=1,
        warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
        pipeline="compv3", scheduler="intrawave"
    ),
    KernelConfig(
        tile_m=256, tile_n=256, tile_k=32,
        wave_m=2, wave_n=2, wave_k=1,
        warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
        pipeline="compv4", scheduler="intrawave"
    ),
]

# Display configurations
print_kernel_config_table(kernels)

# Set up dispatcher with all kernels
lib, dispatcher, registry = setup_gemm_dispatcher(kernels)

# Run GEMM
elapsed_ms = run_gemm(lib, M, N, K, ...)
```

### 02_batch_gemm.py - Batch GEMM

Batched matrix multiplication:

- Multiple independent GEMM operations
- Batch dimension handling
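
The semantics of a batched GEMM can be checked against a plain NumPy loop. A minimal sketch, assuming nothing about the example's actual API; the helper name and the batch/matrix sizes below are illustrative:

```python
import numpy as np

def batched_gemm_reference(A, B):
    """Reference batched GEMM: one independent MxK @ KxN product per batch entry."""
    batch, M, K = A.shape
    _, _, N = B.shape
    C = np.empty((batch, M, N), dtype=A.dtype)
    for b in range(batch):  # each batch entry is an independent GEMM
        C[b] = A[b] @ B[b]
    return C

# Illustrative sizes
A = np.random.rand(4, 64, 32).astype(np.float32)
B = np.random.rand(4, 32, 48).astype(np.float32)
C = batched_gemm_reference(A, B)

# np.matmul broadcasts over the leading batch dimension and must agree
assert C.shape == (4, 64, 48)
assert np.allclose(C, A @ B, atol=1e-5)
```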

### 03_benchmark.py - Benchmarking

Performance measurement:

- GPU timing
- TFLOPS calculation
- Multiple iterations
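
The TFLOPS figure follows the standard 2·M·N·K flop count for a GEMM. A small sketch of the arithmetic (the helper name is illustrative, not part of the example's API):

```python
def gemm_tflops(M, N, K, elapsed_ms):
    """A GEMM performs M*N*K multiply-add pairs, i.e. 2*M*N*K flops."""
    flops = 2.0 * M * N * K
    seconds = elapsed_ms / 1e3
    return flops / seconds / 1e12  # tera-flops per second

# A 4096^3 GEMM finishing in 1 ms sustains roughly 137 TFLOPS
print(f"{gemm_tflops(4096, 4096, 4096, 1.0):.1f} TFLOPS")
```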

### 04_validation.py - Validation

Correctness verification:

- NumPy reference implementation
- Tolerance-based validation
- Error reporting
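
Tolerance-based validation against a NumPy reference can be sketched as follows; the helper name and the tolerances are illustrative, and the example's actual thresholds may differ:

```python
import numpy as np

def validate_gemm(C_gpu, A, B, rtol=1e-3, atol=1e-3):
    """Compare a GPU result against a float64 NumPy reference."""
    C_ref = A.astype(np.float64) @ B.astype(np.float64)
    max_err = np.max(np.abs(C_gpu.astype(np.float64) - C_ref))
    passed = np.allclose(C_gpu, C_ref, rtol=rtol, atol=atol)
    return passed, max_err

A = np.random.rand(64, 32).astype(np.float32)
B = np.random.rand(32, 48).astype(np.float32)
C = A @ B  # stand-in for the GPU result
passed, max_err = validate_gemm(C, A, B)
print("PASS" if passed else "FAIL", f"max_err={max_err:.2e}")
```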

### 05_numpy_integration.py - NumPy Integration

Seamless NumPy integration:

- NumPy arrays to GPU buffers
- Results back to NumPy
- Automatic type conversion

### 06_json_export.py - JSON Export

Registry serialization for tool integration:

- Export kernel configurations
- Machine-readable format
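
The shape of such an export can be sketched with plain dicts whose keys mirror the `KernelConfig` arguments; the exact schema and the file name here are illustrative, not the example's actual format:

```python
import json

# Kernel configurations as plain dicts, mirroring KernelConfig fields
kernels = [
    {"tile_m": 128, "tile_n": 128, "tile_k": 32,
     "pipeline": "compv3", "scheduler": "intrawave"},
    {"tile_m": 256, "tile_n": 256, "tile_k": 32,
     "pipeline": "compv4", "scheduler": "intrawave"},
]

# Write a machine-readable export
with open("kernels.json", "w") as f:
    json.dump({"kernels": kernels}, f, indent=2)

# The export round-trips cleanly
with open("kernels.json") as f:
    assert json.load(f)["kernels"] == kernels
```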

### 07_stress_test.py - Stress Testing

Comprehensive multi-kernel stress testing:

```python
from ctypes_utils import KernelConfig, setup_gemm_dispatcher, print_kernel_config_table

# Define 48 unique kernel configurations
kernels = [
    KernelConfig(tile_m=128, tile_n=128, tile_k=32, pipeline="compv3", ...),
    KernelConfig(tile_m=256, tile_n=256, tile_k=32, pipeline="compv4", ...),
    KernelConfig(tile_m=128, tile_n=256, tile_k=64, pipeline="compv3", ...),
    # ... many more configurations
]

# Test each kernel
for i, kernel in enumerate(kernels):
    lib, dispatcher, registry = setup_gemm_dispatcher([kernel])
    result = run_and_validate(lib, M, N, K, seed=42 + i)  # Different seed per kernel
    print(f"Kernel {i}: {result.max_err:.6e} {'PASS' if result.passed else 'FAIL'}")
```

**Features:**

- 48 unique kernel configurations
- Various tile sizes, pipelines, and schedulers
- Per-kernel validation with unique random seeds
- Performance reporting

### 08_heuristics.py - Heuristic Selection

Custom kernel selection based on problem characteristics:

```python
# Define kernel pools for different strategies
SMALL_KERNELS = [KernelConfig(tile_m=64, tile_n=64, ...), ...]
LARGE_KERNELS = [KernelConfig(tile_m=256, tile_n=256, ...), ...]
COMPUTE_KERNELS = [KernelConfig(pipeline="compv4", ...), ...]
MEMORY_KERNELS = [KernelConfig(pipeline="compv3", ...), ...]

# Size-based heuristic
def size_based_heuristic(M, N, K):
    if M * N < 512 * 512:
        return SMALL_KERNELS
    else:
        return LARGE_KERNELS

# Strategy-based selection; all strategies take (M, N, K) so they
# can be called interchangeably in the loop below
def compute_strategy(M, N, K):
    return COMPUTE_KERNELS  # Optimized for compute-bound problems

def memory_strategy(M, N, K):
    return MEMORY_KERNELS  # Optimized for memory-bound problems

# Test different strategies
for strategy in [size_based_heuristic, compute_strategy, memory_strategy]:
    kernels = strategy(M, N, K)
    lib, dispatcher, registry = setup_gemm_dispatcher(kernels)
    elapsed_ms = run_gemm(lib, M, N, K, ...)
```

**Features:**

- 24 kernel configurations across 6 categories
- Size-based heuristic (small vs large)
- Optimization strategies (compute, memory, latency)
- Performance comparison across strategies

### 09_multi_registry.py - Multiple Registries

Separate registries for different workloads:

- Compute-optimized registry
- Latency-optimized registry
- Dynamic registry selection

### 10_advanced_benchmark.py - Advanced Benchmark

Full control over benchmark parameters:

- Warmup iterations
- Benchmark iterations
- Statistical analysis
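
The statistical summary over benchmark iterations can be sketched with the standard library; the warmup split and the sample timings below are illustrative:

```python
import statistics

timings_ms = [1.31, 1.25, 1.24, 1.26, 1.24, 1.25, 1.27]  # illustrative samples
warmup = 2                      # discard warmup iterations from the stats
samples = timings_ms[warmup:]

mean = statistics.mean(samples)
median = statistics.median(samples)
stdev = statistics.stdev(samples)
print(f"mean={mean:.3f} ms  median={median:.3f} ms  stdev={stdev:.3f} ms")
```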

### 11_json_import.py - JSON Import

Import kernel configurations from JSON:

- External configuration files
- Dynamic kernel loading
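
Loading configurations back can be sketched as follows. This assumes a top-level `kernels` list like the export sketch above; the schema is illustrative, and `KernelConfig(**cfg)` only works if the dict keys match its keyword arguments:

```python
import json

# An external file would be read with open(...); an inline string
# keeps the sketch self-contained
text = '''
{"kernels": [
    {"tile_m": 128, "tile_n": 128, "tile_k": 32,
     "pipeline": "compv3", "scheduler": "intrawave"}
]}
'''

configs = json.loads(text)["kernels"]
for cfg in configs:
    # With the real utility module this would be KernelConfig(**cfg)
    assert {"tile_m", "tile_n", "tile_k"} <= cfg.keys()
print(f"loaded {len(configs)} kernel config(s)")
```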

## Utility Module: ctypes_utils.py

```python
from ctypes_utils import (
    KernelConfig,               # Single kernel configuration
    setup_gemm_dispatcher,      # Set up dispatcher with kernels
    print_kernel_config_table,  # Display kernel configurations
    Dispatcher,                 # High-level dispatcher
    Registry,                   # Kernel registry
    Validator,                  # Validation utilities
)
```

### KernelConfig

```python
config = KernelConfig(
    # Tile sizes
    tile_m=256, tile_n=256, tile_k=32,
    # Wave configuration
    wave_m=2, wave_n=2, wave_k=1,
    # Warp tile sizes
    warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
    # Pipeline and scheduler
    pipeline="compv4",      # "compv3" or "compv4"
    scheduler="intrawave",  # "intrawave" or "interwave"
    # Optional
    epilogue="default",
    padding=True,
    double_buffer=True,
)
```

### setup_gemm_dispatcher

```python
# Single kernel
lib, dispatcher, registry = setup_gemm_dispatcher(config)

# Multiple kernels
lib, dispatcher, registry = setup_gemm_dispatcher([config1, config2, ...])

# With auto-rebuild
lib, dispatcher, registry = setup_gemm_dispatcher(config, auto_rebuild=True)
```

### print_kernel_config_table

```python
kernels = [config1, config2, config3]
print_kernel_config_table(kernels)
# Output:
# +----+------------+-------+----------+--------+-----------+
# | #  | Tile       | Wave  | Warp     | Pipe   | Scheduler |
# +----+------------+-------+----------+--------+-----------+
# | 1  | 128x128x32 | 2x2x1 | 32x32x16 | compv3 | intrawave |
# | 2  | 256x256x32 | 2x2x1 | 32x32x16 | compv4 | intrawave |
# | 3  | 128x256x64 | 2x2x1 | 32x32x16 | compv3 | interwave |
# +----+------------+-------+----------+--------+-----------+
```

### GPU Memory Management

```python
import ctypes
import numpy as np

# Load the HIP runtime library
hip = ctypes.CDLL("libamdhip64.so")

host_array = np.random.rand(1024).astype(np.float32)
size = host_array.nbytes

# Allocate GPU memory
gpu_ptr = ctypes.c_void_p()
hip.hipMalloc(ctypes.byref(gpu_ptr), ctypes.c_size_t(size))

# Copy to GPU (1 = hipMemcpyHostToDevice); wrap raw addresses in c_void_p
# so ctypes does not truncate them to 32-bit ints
hip.hipMemcpy(gpu_ptr, ctypes.c_void_p(host_array.ctypes.data), ctypes.c_size_t(size), 1)

# Copy back (2 = hipMemcpyDeviceToHost)
hip.hipMemcpy(ctypes.c_void_p(host_array.ctypes.data), gpu_ptr, ctypes.c_size_t(size), 2)

# Free GPU memory
hip.hipFree(gpu_ptr)
```

## Performance Testing

Test compilation performance with different kernel counts:

```bash
# Test with 10 kernels (~15s compile time)
python3 01_basic_gemm.py --num-kernels 10

# Test with 20 kernels (~25s compile time)
python3 01_basic_gemm.py --num-kernels 20

# Test with 48 kernels (~50s compile time)
python3 01_basic_gemm.py --num-kernels 48
```

Compilation time scales roughly linearly with kernel count.

## Related Documentation

- [C++ GEMM Examples](../cpp/README.md)
- [Python Conv Examples](../../conv/python/README.md)
- [Main Dispatcher README](../../../README.md)